mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-07 21:32:54 +02:00
more tests
This commit is contained in:
parent
41454b2156
commit
1908ca2697
3 changed files with 86 additions and 2 deletions
|
@ -135,6 +135,7 @@ type IncomingIDPTokenSessionCreator interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type incomingIDPTokenSessionCreator struct {
|
type incomingIDPTokenSessionCreator struct {
|
||||||
|
timeNow func() time.Time
|
||||||
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error)
|
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error)
|
||||||
putRecords func(ctx context.Context, records []*databroker.Record) error
|
putRecords func(ctx context.Context, records []*databroker.Record) error
|
||||||
}
|
}
|
||||||
|
@ -143,7 +144,7 @@ func NewIncomingIDPTokenSessionCreator(
|
||||||
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error),
|
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error),
|
||||||
putRecords func(ctx context.Context, records []*databroker.Record) error,
|
putRecords func(ctx context.Context, records []*databroker.Record) error,
|
||||||
) IncomingIDPTokenSessionCreator {
|
) IncomingIDPTokenSessionCreator {
|
||||||
return &incomingIDPTokenSessionCreator{getRecord: getRecord, putRecords: putRecords}
|
return &incomingIDPTokenSessionCreator{timeNow: time.Now, getRecord: getRecord, putRecords: putRecords}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateSession attempts to create a session for incoming idp access and
|
// CreateSession attempts to create a session for incoming idp access and
|
||||||
|
@ -265,7 +266,7 @@ func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(
|
||||||
sessionID string,
|
sessionID string,
|
||||||
claims jwtutil.Claims,
|
claims jwtutil.Claims,
|
||||||
) *session.Session {
|
) *session.Session {
|
||||||
now := time.Now()
|
now := c.timeNow()
|
||||||
s := new(session.Session)
|
s := new(session.Session)
|
||||||
s.Id = sessionID
|
s.Id = sessionID
|
||||||
if userID, ok := claims.GetUserID(); ok {
|
if userID, ok := claims.GetUserID(); ok {
|
||||||
|
|
|
@ -5,10 +5,12 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
@ -17,6 +19,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/testutil"
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
"github.com/pomerium/pomerium/pkg/identity"
|
"github.com/pomerium/pomerium/pkg/identity"
|
||||||
)
|
)
|
||||||
|
@ -350,6 +353,66 @@ func TestGetIncomingIDPIdentityTokenForPolicy(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_newSessionFromIDPClaims(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tm1 := time.Date(2025, 2, 18, 8, 6, 0, 0, time.UTC)
|
||||||
|
tm2 := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
tm3 := tm2.Add(time.Hour)
|
||||||
|
|
||||||
|
for _, tc := range []struct {
|
||||||
|
name string
|
||||||
|
sessionID string
|
||||||
|
claims jwtutil.Claims
|
||||||
|
expect *session.Session
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"empty claims", "S1",
|
||||||
|
nil,
|
||||||
|
&session.Session{
|
||||||
|
Id: "S1",
|
||||||
|
AccessedAt: timestamppb.New(tm1),
|
||||||
|
ExpiresAt: timestamppb.New(tm1.Add(time.Hour * 14)),
|
||||||
|
IssuedAt: timestamppb.New(tm1),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"full claims", "S2",
|
||||||
|
jwtutil.Claims{
|
||||||
|
"aud": "https://www.example.com",
|
||||||
|
"sub": "U1",
|
||||||
|
"iat": tm2.Unix(),
|
||||||
|
"exp": tm3.Unix(),
|
||||||
|
},
|
||||||
|
&session.Session{
|
||||||
|
Id: "S2",
|
||||||
|
UserId: "U1",
|
||||||
|
AccessedAt: timestamppb.New(tm1),
|
||||||
|
ExpiresAt: timestamppb.New(tm3),
|
||||||
|
IssuedAt: timestamppb.New(tm2),
|
||||||
|
Audience: []string{"https://www.example.com"},
|
||||||
|
Claims: identity.FlattenedClaims{
|
||||||
|
"aud": {"https://www.example.com"},
|
||||||
|
"sub": {"U1"},
|
||||||
|
"iat": {tm2.Unix()},
|
||||||
|
"exp": {tm3.Unix()},
|
||||||
|
}.ToPB(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cfg := &Config{Options: NewDefaultOptions()}
|
||||||
|
c := &incomingIDPTokenSessionCreator{
|
||||||
|
timeNow: func() time.Time { return tm1 },
|
||||||
|
}
|
||||||
|
actual := c.newSessionFromIDPClaims(cfg, tc.sessionID, tc.claims)
|
||||||
|
testutil.AssertProtoEqual(t, tc.expect, actual)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_newUserFromIDPClaims(t *testing.T) {
|
func Test_newUserFromIDPClaims(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
|
@ -89,10 +89,30 @@ func (claims Claims) GetNumericDate(name string) (tm time.Time, ok bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch v := raw.(type) {
|
switch v := raw.(type) {
|
||||||
|
case float32:
|
||||||
|
return time.Unix(int64(v), 0), true
|
||||||
case float64:
|
case float64:
|
||||||
return time.Unix(int64(v), 0), true
|
return time.Unix(int64(v), 0), true
|
||||||
case int64:
|
case int64:
|
||||||
return time.Unix(v, 0), true
|
return time.Unix(v, 0), true
|
||||||
|
case int32:
|
||||||
|
return time.Unix(int64(v), 0), true
|
||||||
|
case int16:
|
||||||
|
return time.Unix(int64(v), 0), true
|
||||||
|
case int8:
|
||||||
|
return time.Unix(int64(v), 0), true
|
||||||
|
case int:
|
||||||
|
return time.Unix(int64(v), 0), true
|
||||||
|
case uint64:
|
||||||
|
return time.Unix(int64(v), 0), true
|
||||||
|
case uint32:
|
||||||
|
return time.Unix(int64(v), 0), true
|
||||||
|
case uint16:
|
||||||
|
return time.Unix(int64(v), 0), true
|
||||||
|
case uint8:
|
||||||
|
return time.Unix(int64(v), 0), true
|
||||||
|
case uint:
|
||||||
|
return time.Unix(int64(v), 0), true
|
||||||
case json.Number:
|
case json.Number:
|
||||||
i, err := v.Int64()
|
i, err := v.Int64()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue