diff --git a/authorize/ssh_grpc.go b/authorize/ssh_grpc.go index 545464d84..8ea28566d 100644 --- a/authorize/ssh_grpc.go +++ b/authorize/ssh_grpc.go @@ -31,6 +31,7 @@ import ( "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/identity" "github.com/pomerium/pomerium/pkg/identity/manager" @@ -667,10 +668,22 @@ func (a *Authorize) PersistSession( sess.SetRawIDToken(claims.RawIDToken) sess.AddClaims(claims.Flatten()) - // XXX: do we need to create a user record too? - // compare with Stateful.PersistSession() + client := a.GetDataBrokerServiceClient() - res, err := session.Put(ctx, a.GetDataBrokerServiceClient(), sess) + u, _ := user.Get(ctx, client, sess.GetUserId()) + if u == nil { + // if no user exists yet, create a new one + u = &user.User{ + Id: sess.GetUserId(), + } + } + u.PopulateFromClaims(claims.Claims) + _, err := databroker.Put(ctx, client, u) + if err != nil { + return nil, err + } + + res, err := session.Put(ctx, client, sess) if err != nil { return nil, err } diff --git a/pkg/grpc/user/user.go b/pkg/grpc/user/user.go index ecfc80413..9d399b997 100644 --- a/pkg/grpc/user/user.go +++ b/pkg/grpc/user/user.go @@ -58,6 +58,22 @@ func (x *User) AddClaims(claims identity.FlattenedClaims) { } } +// TODO: consolidate with AddClaims? +func (u *User) PopulateFromClaims(claims map[string]any) { + if v, ok := claims["name"]; ok { + u.Name = fmt.Sprint(v) + } + if v, ok := claims["email"]; ok { + u.Email = fmt.Sprint(v) + } + if u.Claims == nil { + u.Claims = make(map[string]*structpb.ListValue) + } + for k, vs := range identity.Claims(claims).Flatten().ToPB() { + u.Claims[k] = vs + } +} + // GetClaim returns a claim. // // This method is used by the dashboard template HTML to display claim data.