mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-23 19:49:13 +02:00
config: preserve existing user when creating sessions from idp token (#5502)
* config: preserve existing user when creating sessions from idp token * fix
This commit is contained in:
parent
932db70d96
commit
cb5ee48323
2 changed files with 77 additions and 17 deletions
|
@ -2,6 +2,7 @@ package config
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
@ -210,7 +211,15 @@ func (c *incomingIDPTokenSessionCreator) createSessionAccessToken(
|
|||
AccessToken: rawAccessToken,
|
||||
ExpiresAt: s.ExpiresAt,
|
||||
}
|
||||
u := c.newUserFromIDPClaims(res.Claims)
|
||||
|
||||
u, err := c.getUser(ctx, s.GetUserId())
|
||||
if storage.IsNotFound(err) {
|
||||
u = &user.User{Id: s.GetUserId()}
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving existing user: %w", err)
|
||||
}
|
||||
c.fillUserFromIDPClaims(u, res.Claims)
|
||||
|
||||
err = c.putSessionAndUser(ctx, s, u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error saving session and user: %w", err)
|
||||
|
@ -261,7 +270,15 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
|
|||
|
||||
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
||||
s.SetRawIDToken(rawIdentityToken)
|
||||
u := c.newUserFromIDPClaims(res.Claims)
|
||||
|
||||
u, err := c.getUser(ctx, s.GetUserId())
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
u = &user.User{Id: s.GetUserId()}
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving existing user: %w", err)
|
||||
}
|
||||
c.fillUserFromIDPClaims(u, res.Claims)
|
||||
|
||||
err = c.putSessionAndUser(ctx, s, u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error saving session and user: %w", err)
|
||||
|
@ -305,10 +322,10 @@ func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(
|
|||
return s
|
||||
}
|
||||
|
||||
func (c *incomingIDPTokenSessionCreator) newUserFromIDPClaims(
|
||||
func (c *incomingIDPTokenSessionCreator) fillUserFromIDPClaims(
|
||||
u *user.User,
|
||||
claims jwtutil.Claims,
|
||||
) *user.User {
|
||||
u := new(user.User)
|
||||
) {
|
||||
if userID, ok := claims.GetUserID(); ok {
|
||||
u.Id = userID
|
||||
}
|
||||
|
@ -318,8 +335,7 @@ func (c *incomingIDPTokenSessionCreator) newUserFromIDPClaims(
|
|||
if email, ok := claims.GetString("email"); ok {
|
||||
u.Email = email
|
||||
}
|
||||
u.Claims = identity.Claims(claims).Flatten().ToPB()
|
||||
return u
|
||||
u.AddClaims(identity.Claims(claims).Flatten())
|
||||
}
|
||||
|
||||
func (c *incomingIDPTokenSessionCreator) getSession(ctx context.Context, sessionID string) (*session.Session, error) {
|
||||
|
@ -341,6 +357,25 @@ func (c *incomingIDPTokenSessionCreator) getSession(ctx context.Context, session
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (c *incomingIDPTokenSessionCreator) getUser(ctx context.Context, userID string) (*user.User, error) {
|
||||
record, err := c.getRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := record.GetData().UnmarshalNew()
|
||||
if err != nil {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
u, ok := msg.(*user.User)
|
||||
if !ok {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (c *incomingIDPTokenSessionCreator) putSessionAndUser(ctx context.Context, s *session.Session, u *user.User) error {
|
||||
var records []*databroker.Record
|
||||
if id := s.GetId(); id != "" {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue