From 1908ca2697aae88cd87e9ce66cda6ea48687df40 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Tue, 18 Feb 2025 08:16:13 -0700 Subject: [PATCH] more tests --- config/session.go | 5 +-- config/session_test.go | 63 +++++++++++++++++++++++++++++++++++++ internal/jwtutil/jwtutil.go | 20 ++++++++++++ 3 files changed, 86 insertions(+), 2 deletions(-) diff --git a/config/session.go b/config/session.go index 1c8db4acd..7e6637212 100644 --- a/config/session.go +++ b/config/session.go @@ -135,6 +135,7 @@ type IncomingIDPTokenSessionCreator interface { } type incomingIDPTokenSessionCreator struct { + timeNow func() time.Time getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) putRecords func(ctx context.Context, records []*databroker.Record) error } @@ -143,7 +144,7 @@ func NewIncomingIDPTokenSessionCreator( getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error), putRecords func(ctx context.Context, records []*databroker.Record) error, ) IncomingIDPTokenSessionCreator { - return &incomingIDPTokenSessionCreator{getRecord: getRecord, putRecords: putRecords} + return &incomingIDPTokenSessionCreator{timeNow: time.Now, getRecord: getRecord, putRecords: putRecords} } // CreateSession attempts to create a session for incoming idp access and @@ -265,7 +266,7 @@ func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims( sessionID string, claims jwtutil.Claims, ) *session.Session { - now := time.Now() + now := c.timeNow() s := new(session.Session) s.Id = sessionID if userID, ok := claims.GetUserID(); ok { diff --git a/config/session_test.go b/config/session_test.go index 42142d1af..1880b5cdd 100644 --- a/config/session_test.go +++ b/config/session_test.go @@ -5,10 +5,12 @@ import ( "net/http" "net/url" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/httputil" @@ -17,6 +19,7 @@ import ( "github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/identity" ) @@ -350,6 +353,66 @@ func TestGetIncomingIDPIdentityTokenForPolicy(t *testing.T) { } } +func Test_newSessionFromIDPClaims(t *testing.T) { + t.Parallel() + + tm1 := time.Date(2025, 2, 18, 8, 6, 0, 0, time.UTC) + tm2 := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + tm3 := tm2.Add(time.Hour) + + for _, tc := range []struct { + name string + sessionID string + claims jwtutil.Claims + expect *session.Session + }{ + { + "empty claims", "S1", + nil, + &session.Session{ + Id: "S1", + AccessedAt: timestamppb.New(tm1), + ExpiresAt: timestamppb.New(tm1.Add(time.Hour * 14)), + IssuedAt: timestamppb.New(tm1), + }, + }, + { + "full claims", "S2", + jwtutil.Claims{ + "aud": "https://www.example.com", + "sub": "U1", + "iat": tm2.Unix(), + "exp": tm3.Unix(), + }, + &session.Session{ + Id: "S2", + UserId: "U1", + AccessedAt: timestamppb.New(tm1), + ExpiresAt: timestamppb.New(tm3), + IssuedAt: timestamppb.New(tm2), + Audience: []string{"https://www.example.com"}, + Claims: identity.FlattenedClaims{ + "aud": {"https://www.example.com"}, + "sub": {"U1"}, + "iat": {tm2.Unix()}, + "exp": {tm3.Unix()}, + }.ToPB(), + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + cfg := &Config{Options: NewDefaultOptions()} + c := &incomingIDPTokenSessionCreator{ + timeNow: func() time.Time { return tm1 }, + } + actual := c.newSessionFromIDPClaims(cfg, tc.sessionID, tc.claims) + testutil.AssertProtoEqual(t, tc.expect, actual) + }) + } +} + func Test_newUserFromIDPClaims(t *testing.T) { t.Parallel() diff --git a/internal/jwtutil/jwtutil.go b/internal/jwtutil/jwtutil.go index 7903dec2f..10f64e26b 100644 --- a/internal/jwtutil/jwtutil.go +++ b/internal/jwtutil/jwtutil.go @@ -89,10 +89,30 @@ func (claims Claims) GetNumericDate(name string) (tm time.Time, ok bool) { } switch v := raw.(type) { + case float32: + return time.Unix(int64(v), 0), true case float64: return time.Unix(int64(v), 0), true case int64: return time.Unix(v, 0), true + case int32: + return time.Unix(int64(v), 0), true + case int16: + return time.Unix(int64(v), 0), true + case int8: + return time.Unix(int64(v), 0), true + case int: + return time.Unix(int64(v), 0), true + case uint64: + return time.Unix(int64(v), 0), true + case uint32: + return time.Unix(int64(v), 0), true + case uint16: + return time.Unix(int64(v), 0), true + case uint8: + return time.Unix(int64(v), 0), true + case uint: + return time.Unix(int64(v), 0), true case json.Number: i, err := v.Int64() if err != nil {