diff --git a/config/session.go b/config/session.go index 40e893c47..d895f18eb 100644 --- a/config/session.go +++ b/config/session.go @@ -2,6 +2,7 @@ package config import ( "context" + "errors" "fmt" "net/http" "strings" @@ -210,7 +211,15 @@ func (c *incomingIDPTokenSessionCreator) createSessionAccessToken( AccessToken: rawAccessToken, ExpiresAt: s.ExpiresAt, } - u := c.newUserFromIDPClaims(res.Claims) + + u, err := c.getUser(ctx, s.GetUserId()) + if storage.IsNotFound(err) { + u = &user.User{Id: s.GetUserId()} + } else if err != nil { + return nil, fmt.Errorf("error retrieving existing user: %w", err) + } + c.fillUserFromIDPClaims(u, res.Claims) + err = c.putSessionAndUser(ctx, s, u) if err != nil { return nil, fmt.Errorf("error saving session and user: %w", err) @@ -261,7 +270,15 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken( s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims) s.SetRawIDToken(rawIdentityToken) - u := c.newUserFromIDPClaims(res.Claims) + + u, err := c.getUser(ctx, s.GetUserId()) + if errors.Is(err, storage.ErrNotFound) { + u = &user.User{Id: s.GetUserId()} + } else if err != nil { + return nil, fmt.Errorf("error retrieving existing user: %w", err) + } + c.fillUserFromIDPClaims(u, res.Claims) + err = c.putSessionAndUser(ctx, s, u) if err != nil { return nil, fmt.Errorf("error saving session and user: %w", err) @@ -305,10 +322,10 @@ func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims( return s } -func (c *incomingIDPTokenSessionCreator) newUserFromIDPClaims( +func (c *incomingIDPTokenSessionCreator) fillUserFromIDPClaims( + u *user.User, claims jwtutil.Claims, -) *user.User { - u := new(user.User) +) { if userID, ok := claims.GetUserID(); ok { u.Id = userID } @@ -318,8 +335,7 @@ func (c *incomingIDPTokenSessionCreator) newUserFromIDPClaims( if email, ok := claims.GetString("email"); ok { u.Email = email } - u.Claims = identity.Claims(claims).Flatten().ToPB() - return u + u.AddClaims(identity.Claims(claims).Flatten()) } func (c *incomingIDPTokenSessionCreator) getSession(ctx context.Context, sessionID string) (*session.Session, error) { @@ -341,6 +357,25 @@ func (c *incomingIDPTokenSessionCreator) getSession(ctx context.Context, session return s, nil } +func (c *incomingIDPTokenSessionCreator) getUser(ctx context.Context, userID string) (*user.User, error) { + record, err := c.getRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID) + if err != nil { + return nil, err + } + + msg, err := record.GetData().UnmarshalNew() + if err != nil { + return nil, storage.ErrNotFound + } + + u, ok := msg.(*user.User) + if !ok { + return nil, storage.ErrNotFound + } + + return u, nil +} + func (c *incomingIDPTokenSessionCreator) putSessionAndUser(ctx context.Context, s *session.Session, u *user.User) error { var records []*databroker.Record if id := s.GetId(); id != "" { diff --git a/config/session_test.go b/config/session_test.go index 745e85b31..bf349001b 100644 --- a/config/session_test.go +++ b/config/session_test.go @@ -13,6 +13,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/encoding/jws" @@ -416,19 +417,41 @@ func Test_newSessionFromIDPClaims(t *testing.T) { } } -func Test_newUserFromIDPClaims(t *testing.T) { +func Test_fillUserFromIDPClaims(t *testing.T) { t.Parallel() for _, tc := range []struct { - name string - claims jwtutil.Claims - expect *user.User + name string + claims jwtutil.Claims + current *user.User + expect *user.User }{ - {"empty claims", nil, &user.User{}}, + {"empty claims", nil, nil, &user.User{}}, {"full claims", jwtutil.Claims{ "sub": "USER_ID", "name": "NAME", "email": "EMAIL", + }, nil, &user.User{ + Id: "USER_ID", + Name: "NAME", + Email: "EMAIL", + Claims: identity.FlattenedClaims{ + "sub": {"USER_ID"}, + "name": {"NAME"}, + "email": {"EMAIL"}, + }.ToPB(), + }}, + {"existing claims", jwtutil.Claims{ + "sub": "USER_ID", + }, &user.User{ + Id: "USER_ID", + Name: "NAME", + Email: "EMAIL", + Claims: identity.FlattenedClaims{ + "sub": {"USER_ID"}, + "name": {"NAME"}, + "email": {"EMAIL"}, + }.ToPB(), }, &user.User{ Id: "USER_ID", Name: "NAME", @@ -443,7 +466,11 @@ func Test_newUserFromIDPClaims(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - actual := new(incomingIDPTokenSessionCreator).newUserFromIDPClaims(tc.claims) + actual := new(user.User) + if tc.current != nil { + actual = proto.Clone(tc.current).(*user.User) + } + new(incomingIDPTokenSessionCreator).fillUserFromIDPClaims(actual, tc.claims) testutil.AssertProtoEqual(t, tc.expect, actual) }) } @@ -476,8 +503,7 @@ func TestIncomingIDPTokenSessionCreator_CreateSession(t *testing.T) { require.NoError(t, err) req.Header.Set(httputil.HeaderPomeriumIDPAccessToken, "ACCESS_TOKEN") c := NewIncomingIDPTokenSessionCreator( - func(_ context.Context, recordType, _ string) (*databroker.Record, error) { - assert.Equal(t, "type.googleapis.com/session.Session", recordType) + func(_ context.Context, _, _ string) (*databroker.Record, error) { return nil, storage.ErrNotFound }, func(_ context.Context, records []*databroker.Record) error { @@ -518,8 +544,7 @@ func TestIncomingIDPTokenSessionCreator_CreateSession(t *testing.T) { require.NoError(t, err) req.Header.Set(httputil.HeaderPomeriumIDPIdentityToken, "IDENTITY_TOKEN") c := NewIncomingIDPTokenSessionCreator( - func(_ context.Context, recordType, _ string) (*databroker.Record, error) { - assert.Equal(t, "type.googleapis.com/session.Session", recordType) + func(_ context.Context, _, _ string) (*databroker.Record, error) { return nil, storage.ErrNotFound }, func(_ context.Context, records []*databroker.Record) error {