mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-07 22:36:05 +02:00
config: preserve existing user when creating sessions from idp token (#5502)
* config: preserve existing user when creating sessions from idp token * fix
This commit is contained in:
parent
932db70d96
commit
cb5ee48323
2 changed files with 77 additions and 17 deletions
|
@ -2,6 +2,7 @@ package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -210,7 +211,15 @@ func (c *incomingIDPTokenSessionCreator) createSessionAccessToken(
|
||||||
AccessToken: rawAccessToken,
|
AccessToken: rawAccessToken,
|
||||||
ExpiresAt: s.ExpiresAt,
|
ExpiresAt: s.ExpiresAt,
|
||||||
}
|
}
|
||||||
u := c.newUserFromIDPClaims(res.Claims)
|
|
||||||
|
u, err := c.getUser(ctx, s.GetUserId())
|
||||||
|
if storage.IsNotFound(err) {
|
||||||
|
u = &user.User{Id: s.GetUserId()}
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, fmt.Errorf("error retrieving existing user: %w", err)
|
||||||
|
}
|
||||||
|
c.fillUserFromIDPClaims(u, res.Claims)
|
||||||
|
|
||||||
err = c.putSessionAndUser(ctx, s, u)
|
err = c.putSessionAndUser(ctx, s, u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error saving session and user: %w", err)
|
return nil, fmt.Errorf("error saving session and user: %w", err)
|
||||||
|
@ -261,7 +270,15 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
|
||||||
|
|
||||||
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
||||||
s.SetRawIDToken(rawIdentityToken)
|
s.SetRawIDToken(rawIdentityToken)
|
||||||
u := c.newUserFromIDPClaims(res.Claims)
|
|
||||||
|
u, err := c.getUser(ctx, s.GetUserId())
|
||||||
|
if errors.Is(err, storage.ErrNotFound) {
|
||||||
|
u = &user.User{Id: s.GetUserId()}
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, fmt.Errorf("error retrieving existing user: %w", err)
|
||||||
|
}
|
||||||
|
c.fillUserFromIDPClaims(u, res.Claims)
|
||||||
|
|
||||||
err = c.putSessionAndUser(ctx, s, u)
|
err = c.putSessionAndUser(ctx, s, u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error saving session and user: %w", err)
|
return nil, fmt.Errorf("error saving session and user: %w", err)
|
||||||
|
@ -305,10 +322,10 @@ func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *incomingIDPTokenSessionCreator) newUserFromIDPClaims(
|
func (c *incomingIDPTokenSessionCreator) fillUserFromIDPClaims(
|
||||||
|
u *user.User,
|
||||||
claims jwtutil.Claims,
|
claims jwtutil.Claims,
|
||||||
) *user.User {
|
) {
|
||||||
u := new(user.User)
|
|
||||||
if userID, ok := claims.GetUserID(); ok {
|
if userID, ok := claims.GetUserID(); ok {
|
||||||
u.Id = userID
|
u.Id = userID
|
||||||
}
|
}
|
||||||
|
@ -318,8 +335,7 @@ func (c *incomingIDPTokenSessionCreator) newUserFromIDPClaims(
|
||||||
if email, ok := claims.GetString("email"); ok {
|
if email, ok := claims.GetString("email"); ok {
|
||||||
u.Email = email
|
u.Email = email
|
||||||
}
|
}
|
||||||
u.Claims = identity.Claims(claims).Flatten().ToPB()
|
u.AddClaims(identity.Claims(claims).Flatten())
|
||||||
return u
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *incomingIDPTokenSessionCreator) getSession(ctx context.Context, sessionID string) (*session.Session, error) {
|
func (c *incomingIDPTokenSessionCreator) getSession(ctx context.Context, sessionID string) (*session.Session, error) {
|
||||||
|
@ -341,6 +357,25 @@ func (c *incomingIDPTokenSessionCreator) getSession(ctx context.Context, session
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *incomingIDPTokenSessionCreator) getUser(ctx context.Context, userID string) (*user.User, error) {
|
||||||
|
record, err := c.getRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := record.GetData().UnmarshalNew()
|
||||||
|
if err != nil {
|
||||||
|
return nil, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
u, ok := msg.(*user.User)
|
||||||
|
if !ok {
|
||||||
|
return nil, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *incomingIDPTokenSessionCreator) putSessionAndUser(ctx context.Context, s *session.Session, u *user.User) error {
|
func (c *incomingIDPTokenSessionCreator) putSessionAndUser(ctx context.Context, s *session.Session, u *user.User) error {
|
||||||
var records []*databroker.Record
|
var records []*databroker.Record
|
||||||
if id := s.GetId(); id != "" {
|
if id := s.GetId(); id != "" {
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||||
|
@ -416,19 +417,41 @@ func Test_newSessionFromIDPClaims(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_newUserFromIDPClaims(t *testing.T) {
|
func Test_fillUserFromIDPClaims(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
for _, tc := range []struct {
|
for _, tc := range []struct {
|
||||||
name string
|
name string
|
||||||
claims jwtutil.Claims
|
claims jwtutil.Claims
|
||||||
expect *user.User
|
current *user.User
|
||||||
|
expect *user.User
|
||||||
}{
|
}{
|
||||||
{"empty claims", nil, &user.User{}},
|
{"empty claims", nil, nil, &user.User{}},
|
||||||
{"full claims", jwtutil.Claims{
|
{"full claims", jwtutil.Claims{
|
||||||
"sub": "USER_ID",
|
"sub": "USER_ID",
|
||||||
"name": "NAME",
|
"name": "NAME",
|
||||||
"email": "EMAIL",
|
"email": "EMAIL",
|
||||||
|
}, nil, &user.User{
|
||||||
|
Id: "USER_ID",
|
||||||
|
Name: "NAME",
|
||||||
|
Email: "EMAIL",
|
||||||
|
Claims: identity.FlattenedClaims{
|
||||||
|
"sub": {"USER_ID"},
|
||||||
|
"name": {"NAME"},
|
||||||
|
"email": {"EMAIL"},
|
||||||
|
}.ToPB(),
|
||||||
|
}},
|
||||||
|
{"existing claims", jwtutil.Claims{
|
||||||
|
"sub": "USER_ID",
|
||||||
|
}, &user.User{
|
||||||
|
Id: "USER_ID",
|
||||||
|
Name: "NAME",
|
||||||
|
Email: "EMAIL",
|
||||||
|
Claims: identity.FlattenedClaims{
|
||||||
|
"sub": {"USER_ID"},
|
||||||
|
"name": {"NAME"},
|
||||||
|
"email": {"EMAIL"},
|
||||||
|
}.ToPB(),
|
||||||
}, &user.User{
|
}, &user.User{
|
||||||
Id: "USER_ID",
|
Id: "USER_ID",
|
||||||
Name: "NAME",
|
Name: "NAME",
|
||||||
|
@ -443,7 +466,11 @@ func Test_newUserFromIDPClaims(t *testing.T) {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
actual := new(incomingIDPTokenSessionCreator).newUserFromIDPClaims(tc.claims)
|
actual := new(user.User)
|
||||||
|
if tc.current != nil {
|
||||||
|
actual = proto.Clone(tc.current).(*user.User)
|
||||||
|
}
|
||||||
|
new(incomingIDPTokenSessionCreator).fillUserFromIDPClaims(actual, tc.claims)
|
||||||
testutil.AssertProtoEqual(t, tc.expect, actual)
|
testutil.AssertProtoEqual(t, tc.expect, actual)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -476,8 +503,7 @@ func TestIncomingIDPTokenSessionCreator_CreateSession(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
req.Header.Set(httputil.HeaderPomeriumIDPAccessToken, "ACCESS_TOKEN")
|
req.Header.Set(httputil.HeaderPomeriumIDPAccessToken, "ACCESS_TOKEN")
|
||||||
c := NewIncomingIDPTokenSessionCreator(
|
c := NewIncomingIDPTokenSessionCreator(
|
||||||
func(_ context.Context, recordType, _ string) (*databroker.Record, error) {
|
func(_ context.Context, _, _ string) (*databroker.Record, error) {
|
||||||
assert.Equal(t, "type.googleapis.com/session.Session", recordType)
|
|
||||||
return nil, storage.ErrNotFound
|
return nil, storage.ErrNotFound
|
||||||
},
|
},
|
||||||
func(_ context.Context, records []*databroker.Record) error {
|
func(_ context.Context, records []*databroker.Record) error {
|
||||||
|
@ -518,8 +544,7 @@ func TestIncomingIDPTokenSessionCreator_CreateSession(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
req.Header.Set(httputil.HeaderPomeriumIDPIdentityToken, "IDENTITY_TOKEN")
|
req.Header.Set(httputil.HeaderPomeriumIDPIdentityToken, "IDENTITY_TOKEN")
|
||||||
c := NewIncomingIDPTokenSessionCreator(
|
c := NewIncomingIDPTokenSessionCreator(
|
||||||
func(_ context.Context, recordType, _ string) (*databroker.Record, error) {
|
func(_ context.Context, _, _ string) (*databroker.Record, error) {
|
||||||
assert.Equal(t, "type.googleapis.com/session.Session", recordType)
|
|
||||||
return nil, storage.ErrNotFound
|
return nil, storage.ErrNotFound
|
||||||
},
|
},
|
||||||
func(_ context.Context, records []*databroker.Record) error {
|
func(_ context.Context, records []*databroker.Record) error {
|
||||||
|
|
Loading…
Add table
Reference in a new issue