From 622519e901e16f257e52fe5b6dea26746aa56e20 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Tue, 15 Jul 2025 18:04:36 -0600 Subject: [PATCH] databroker: update identity manager to use route credentials (#5728) ## Summary Currently when we refresh sessions we always use the global IdP credentials. This PR updates the identity manager to use route settings when defined. To do this a new `idp_id` field is added to the session stored in the databroker. ## Related issues - [ENG-2595](https://linear.app/pomerium/issue/ENG-2595/refresh-using-custom-idp-uses-wrong-credentials) - https://github.com/pomerium/pomerium/issues/4759 ## Checklist - [x] reference any related issues - [x] updated unit tests - [x] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [x] ready for review --- config/identity.go | 33 +++++++++++--- config/options.go | 9 ++-- config/session.go | 7 +-- config/session_test.go | 2 +- databroker/cache.go | 24 +++-------- internal/authenticateflow/stateful.go | 20 ++++----- internal/authenticateflow/stateless.go | 2 +- pkg/grpc/session/session.go | 8 ++++ pkg/grpc/session/session.pb.go | 60 +++++++++++++++----------- pkg/grpc/session/session.proto | 47 ++++++++++---------- pkg/identity/manager/config.go | 39 +++++++++++++++-- pkg/identity/manager/manager.go | 20 ++++----- pkg/ssh/auth.go | 35 ++++++++------- proxy/data.go | 2 +- 14 files changed, 185 insertions(+), 123 deletions(-) diff --git a/config/identity.go b/config/identity.go index 644f58605..3ad0863fd 100644 --- a/config/identity.go +++ b/config/identity.go @@ -1,15 +1,19 @@ package config import ( + "context" "slices" + oteltrace "go.opentelemetry.io/otel/trace" + "github.com/pomerium/pomerium/internal/urlutil" - "github.com/pomerium/pomerium/pkg/grpc/identity" + identitypb "github.com/pomerium/pomerium/pkg/grpc/identity" + "github.com/pomerium/pomerium/pkg/identity" ) // GetIdentityProviderForID returns the identity provider associated with the given IDP id. // If none is found the default provider is returned. -func (o *Options) GetIdentityProviderForID(idpID string) (*identity.Provider, error) { +func (o *Options) GetIdentityProviderForID(idpID string) (*identitypb.Provider, error) { for p := range o.GetAllPolicies() { idp, err := o.GetIdentityProviderForPolicy(p) if err != nil { @@ -25,7 +29,7 @@ func (o *Options) GetIdentityProviderForID(idpID string) (*identity.Provider, er // GetIdentityProviderForPolicy gets the identity provider associated with the given policy. // If policy is nil, or changes none of the default settings, the default provider is returned. -func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provider, error) { +func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identitypb.Provider, error) { clientSecret, err := o.GetClientSecret() if err != nil { return nil, err @@ -36,7 +40,7 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid return nil, err } - idp := &identity.Provider{ + idp := &identitypb.Provider{ AuthenticateServiceUrl: authenticateURL.String(), ClientId: o.ClientID, ClientSecret: clientSecret, @@ -46,7 +50,7 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid RequestParams: o.RequestParams, } if v := o.IDPAccessTokenAllowedAudiences; v != nil { - idp.AccessTokenAllowedAudiences = &identity.Provider_StringList{ + idp.AccessTokenAllowedAudiences = &identitypb.Provider_StringList{ Values: slices.Clone(*v), } } @@ -58,7 +62,7 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid idp.ClientSecret = policy.IDPClientSecret } if v := policy.IDPAccessTokenAllowedAudiences; v != nil { - idp.AccessTokenAllowedAudiences = &identity.Provider_StringList{ + idp.AccessTokenAllowedAudiences = &identitypb.Provider_StringList{ Values: slices.Clone(*v), } } @@ -68,7 +72,7 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid } // GetIdentityProviderForRequestURL gets the identity provider associated with the given request URL. -func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity.Provider, error) { +func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identitypb.Provider, error) { u, err := urlutil.ParseAndValidateURL(requestURL) if err != nil { return nil, err @@ -81,3 +85,18 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity } return o.GetIdentityProviderForPolicy(nil) } + +// GetAuthenticator gets the authenticator for the given IDP id. +func (o *Options) GetAuthenticator(ctx context.Context, tracerProvider oteltrace.TracerProvider, idpID string) (identity.Authenticator, error) { + redirectURL, err := o.GetAuthenticateRedirectURL() + if err != nil { + return nil, err + } + + idp, err := o.GetIdentityProviderForID(idpID) + if err != nil { + return nil, err + } + + return identity.GetIdentityProvider(ctx, tracerProvider, idp, redirectURL) +} diff --git a/config/options.go b/config/options.go index 798dc7713..ece99a3ec 100644 --- a/config/options.go +++ b/config/options.go @@ -851,13 +851,14 @@ func (o *Options) GetAuthenticateRedirectURL() (*url.URL, error) { // flow should be used (i.e. for hosted authenticate). func (o *Options) UseStatelessAuthenticateFlow() bool { if flow := os.Getenv("DEBUG_FORCE_AUTHENTICATE_FLOW"); flow != "" { - if flow == "stateless" { + switch flow { + case "stateless": return true - } else if flow == "stateful" { + case "stateful": return false + default: + log.Error().Msgf("ignoring unknown DEBUG_FORCE_AUTHENTICATE_FLOW setting %q", flow) } - log.Error(). - Msgf("ignoring unknown DEBUG_FORCE_AUTHENTICATE_FLOW setting %q", flow) } u, err := o.GetInternalAuthenticateURL() if err != nil { diff --git a/config/session.go b/config/session.go index f498a9e7f..890d5c55a 100644 --- a/config/session.go +++ b/config/session.go @@ -240,7 +240,7 @@ func (c *incomingIDPTokenSessionCreator) createSessionAccessToken( return nil, fmt.Errorf("%w: invalid access token", sessions.ErrInvalidSession) } - s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims) + s = c.newSessionFromIDPClaims(cfg, idp.Id, sessionID, res.Claims) s.OauthToken = &session.OAuthToken{ TokenType: "Bearer", AccessToken: rawAccessToken, @@ -309,7 +309,7 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken( return nil, fmt.Errorf("%w: invalid identity token", sessions.ErrInvalidSession) } - s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims) + s = c.newSessionFromIDPClaims(cfg, idp.Id, sessionID, res.Claims) s.SetRawIDToken(rawIdentityToken) u, err := c.getUser(ctx, s.GetUserId()) @@ -338,11 +338,12 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken( func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims( cfg *Config, + idpID string, sessionID string, claims jwtutil.Claims, ) *session.Session { now := c.timeNow() - s := new(session.Session) + s := session.New(idpID, sessionID) s.Id = sessionID if userID, ok := claims.GetUserID(); ok { s.UserId = userID diff --git a/config/session_test.go b/config/session_test.go index a23d78eae..d5306a7a0 100644 --- a/config/session_test.go +++ b/config/session_test.go @@ -375,7 +375,7 @@ func Test_newSessionFromIDPClaims(t *testing.T) { c := &incomingIDPTokenSessionCreator{ timeNow: func() time.Time { return tm1 }, } - actual := c.newSessionFromIDPClaims(cfg, tc.sessionID, tc.claims) + actual := c.newSessionFromIDPClaims(cfg, "", tc.sessionID, tc.claims) testutil.AssertProtoEqual(t, tc.expect, actual) }) } diff --git a/databroker/cache.go b/databroker/cache.go index 52772af22..5dcee085f 100644 --- a/databroker/cache.go +++ b/databroker/cache.go @@ -171,7 +171,7 @@ func (c *DataBroker) Run(ctx context.Context) error { return eg.Wait() } -func (c *DataBroker) update(ctx context.Context, cfg *config.Config) error { +func (c *DataBroker) update(_ context.Context, cfg *config.Config) error { if err := validate(cfg.Options); err != nil { return fmt.Errorf("databroker: bad option: %w", err) } @@ -182,29 +182,19 @@ func (c *DataBroker) update(ctx context.Context, cfg *config.Config) error { } c.sharedKey.Store(sharedKey) - oauthOptions, err := cfg.Options.GetOauthOptions() - if err != nil { - return fmt.Errorf("databroker: invalid oauth options: %w", err) - } - dataBrokerClient := databroker.NewDataBrokerServiceClient(c.localGRPCConnection) options := append([]manager.Option{ manager.WithDataBrokerClient(dataBrokerClient), manager.WithEventManager(c.eventsMgr), + manager.WithCachedGetAuthenticator(func(ctx context.Context, idpID string) (identity.Authenticator, error) { + if !cfg.Options.SupportsUserRefresh() { + return nil, fmt.Errorf("disabling refresh of user sessions") + } + return cfg.Options.GetAuthenticator(ctx, c.tracerProvider, idpID) + }), }, c.managerOptions...) - if cfg.Options.SupportsUserRefresh() { - authenticator, err := identity.NewAuthenticator(ctx, c.tracerProvider, oauthOptions) - if err != nil { - log.Ctx(ctx).Error().Err(err).Msg("databroker: failed to create authenticator") - } else { - options = append(options, manager.WithAuthenticator(authenticator)) - } - } else { - log.Ctx(ctx).Info().Msg("databroker: disabling refresh of user sessions") - } - if c.manager == nil { c.manager = manager.New(options...) } else { diff --git a/internal/authenticateflow/stateful.go b/internal/authenticateflow/stateful.go index f534a0019..a96668856 100644 --- a/internal/authenticateflow/stateful.go +++ b/internal/authenticateflow/stateful.go @@ -189,15 +189,13 @@ func (s *Stateful) PersistSession( now := timeNow() sessionExpiry := timestamppb.New(now.Add(s.sessionDuration)) - sess := &session.Session{ - Id: sessionState.ID, - UserId: sessionState.UserID(), - IssuedAt: timestamppb.New(now), - AccessedAt: timestamppb.New(now), - ExpiresAt: sessionExpiry, - OauthToken: manager.ToOAuthToken(accessToken), - Audience: sessionState.Audience, - } + sess := session.New(sessionState.IdentityProviderID, sessionState.ID) + sess.UserId = sessionState.UserID() + sess.IssuedAt = timestamppb.New(now) + sess.AccessedAt = timestamppb.New(now) + sess.ExpiresAt = sessionExpiry + sess.OauthToken = manager.ToOAuthToken(accessToken) + sess.Audience = sessionState.Audience sess.SetRawIDToken(claims.RawIDToken) sess.AddClaims(claims.Flatten()) @@ -236,9 +234,7 @@ func (s *Stateful) GetUserInfoData( isImpersonated = true } if err != nil { - pbSession = &session.Session{ - Id: sessionState.ID, - } + pbSession = session.New(sessionState.IdentityProviderID, sessionState.ID) } pbUser, err := user.Get(r.Context(), s.dataBrokerClient, pbSession.GetUserId()) diff --git a/internal/authenticateflow/stateless.go b/internal/authenticateflow/stateless.go index e9d26b056..54a30fcf3 100644 --- a/internal/authenticateflow/stateless.go +++ b/internal/authenticateflow/stateless.go @@ -415,7 +415,7 @@ func (s *Stateless) Callback(w http.ResponseWriter, r *http.Request) error { ss := newSessionStateFromProfile(profile) sess, err := session.Get(r.Context(), s.dataBrokerClient, ss.ID) if err != nil { - sess = &session.Session{Id: ss.ID} + sess = session.New(ss.IdentityProviderID, ss.ID) } populateSessionFromProfile(sess, profile, ss, s.options.CookieExpire) u, err := user.Get(r.Context(), s.dataBrokerClient, ss.UserID()) diff --git a/pkg/grpc/session/session.go b/pkg/grpc/session/session.go index ff4594da7..527020676 100644 --- a/pkg/grpc/session/session.go +++ b/pkg/grpc/session/session.go @@ -83,6 +83,14 @@ func Patch( return res, err } +// New creates a new Session. +func New(idpID, id string) *Session { + return &Session{ + IdpId: idpID, + Id: id, + } +} + // AddClaims adds the flattened claims to the session. func (x *Session) AddClaims(claims identity.FlattenedClaims) { if x.Claims == nil { diff --git a/pkg/grpc/session/session.pb.go b/pkg/grpc/session/session.pb.go index d13933191..558ca0dd9 100644 --- a/pkg/grpc/session/session.pb.go +++ b/pkg/grpc/session/session.pb.go @@ -190,6 +190,7 @@ type Session struct { Claims map[string]*structpb.ListValue `protobuf:"bytes,9,rep,name=claims,proto3" json:"claims,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` Audience []string `protobuf:"bytes,10,rep,name=audience,proto3" json:"audience,omitempty"` RefreshDisabled bool `protobuf:"varint,19,opt,name=refresh_disabled,json=refreshDisabled,proto3" json:"refresh_disabled,omitempty"` + IdpId string `protobuf:"bytes,20,opt,name=idp_id,json=idpId,proto3" json:"idp_id,omitempty"` ImpersonateSessionId *string `protobuf:"bytes,15,opt,name=impersonate_session_id,json=impersonateSessionId,proto3,oneof" json:"impersonate_session_id,omitempty"` } @@ -309,6 +310,13 @@ func (x *Session) GetRefreshDisabled() bool { return false } +func (x *Session) GetIdpId() string { + if x != nil { + return x.IdpId + } + return "" +} + func (x *Session) GetImpersonateSessionId() string { if x != nil && x.ImpersonateSessionId != nil { return *x.ImpersonateSessionId @@ -438,7 +446,7 @@ var file_session_proto_rawDesc = []byte{ 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, 0x66, - 0x72, 0x65, 0x73, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xe6, 0x06, 0x0a, 0x07, 0x53, 0x65, + 0x72, 0x65, 0x73, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xfd, 0x06, 0x0a, 0x07, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, @@ -473,30 +481,32 @@ var file_session_proto_rawDesc = []byte{ 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x29, 0x0a, 0x10, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x13, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x72, 0x65, 0x66, 0x72, - 0x65, 0x73, 0x68, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x39, 0x0a, 0x16, 0x69, - 0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, - 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x14, 0x69, - 0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, - 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x1a, 0x87, 0x01, 0x0a, 0x10, 0x44, 0x65, 0x76, 0x69, 0x63, - 0x65, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x12, 0x17, 0x0a, 0x07, 0x74, - 0x79, 0x70, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, 0x79, - 0x70, 0x65, 0x49, 0x64, 0x12, 0x3a, 0x0a, 0x0b, 0x75, 0x6e, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, - 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, - 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, - 0x79, 0x48, 0x00, 0x52, 0x0b, 0x75, 0x6e, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, - 0x12, 0x10, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x02, - 0x69, 0x64, 0x42, 0x0c, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, - 0x1a, 0x55, 0x0a, 0x0b, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, - 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, - 0x79, 0x12, 0x30, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x75, 0x66, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x05, 0x76, 0x61, - 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x19, 0x0a, 0x17, 0x5f, 0x69, 0x6d, 0x70, 0x65, - 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, - 0x69, 0x64, 0x42, 0x2f, 0x5a, 0x2d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, - 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, - 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x73, 0x65, 0x73, 0x73, - 0x69, 0x6f, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x73, 0x68, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x15, 0x0a, 0x06, 0x69, + 0x64, 0x70, 0x5f, 0x69, 0x64, 0x18, 0x14, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x69, 0x64, 0x70, + 0x49, 0x64, 0x12, 0x39, 0x0a, 0x16, 0x69, 0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, + 0x65, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20, 0x01, + 0x28, 0x09, 0x48, 0x00, 0x52, 0x14, 0x69, 0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, + 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x1a, 0x87, 0x01, + 0x0a, 0x10, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, + 0x61, 0x6c, 0x12, 0x17, 0x0a, 0x07, 0x74, 0x79, 0x70, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, 0x79, 0x70, 0x65, 0x49, 0x64, 0x12, 0x3a, 0x0a, 0x0b, 0x75, + 0x6e, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x48, 0x00, 0x52, 0x0b, 0x75, 0x6e, 0x61, 0x76, + 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x10, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x02, 0x69, 0x64, 0x42, 0x0c, 0x0a, 0x0a, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x1a, 0x55, 0x0a, 0x0b, 0x43, 0x6c, 0x61, 0x69, 0x6d, + 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x56, 0x61, + 0x6c, 0x75, 0x65, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x19, + 0x0a, 0x17, 0x5f, 0x69, 0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x5f, 0x73, + 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x42, 0x2f, 0x5a, 0x2d, 0x67, 0x69, 0x74, + 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, + 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, + 0x70, 0x63, 0x2f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, } var ( diff --git a/pkg/grpc/session/session.proto b/pkg/grpc/session/session.proto index 8101046fa..8eb0bd394 100644 --- a/pkg/grpc/session/session.proto +++ b/pkg/grpc/session/session.proto @@ -1,48 +1,49 @@ syntax = "proto3"; package session; -option go_package = "github.com/pomerium/pomerium/pkg/grpc/session"; +option go_package = "github.com/pomerium/pomerium/pkg/grpc/session"; import "google/protobuf/empty.proto"; import "google/protobuf/struct.proto"; import "google/protobuf/timestamp.proto"; message IDToken { - string issuer = 1; - string subject = 2; + string issuer = 1; + string subject = 2; google.protobuf.Timestamp expires_at = 3; - google.protobuf.Timestamp issued_at = 4; - string raw = 5; + google.protobuf.Timestamp issued_at = 4; + string raw = 5; } message OAuthToken { - string access_token = 1; - string token_type = 2; - google.protobuf.Timestamp expires_at = 3; - string refresh_token = 4; + string access_token = 1; + string token_type = 2; + google.protobuf.Timestamp expires_at = 3; + string refresh_token = 4; } message Session { message DeviceCredential { string type_id = 1; - oneof credential { + oneof credential { google.protobuf.Empty unavailable = 2; - string id = 3; + string id = 3; } } - string version = 1; - string id = 2; - string user_id = 3; - repeated DeviceCredential device_credentials = 17; - google.protobuf.Timestamp issued_at = 14; - google.protobuf.Timestamp expires_at = 4; - google.protobuf.Timestamp accessed_at = 18; - IDToken id_token = 6; - OAuthToken oauth_token = 7; - map claims = 9; - repeated string audience = 10; - bool refresh_disabled = 19; + string version = 1; + string id = 2; + string user_id = 3; + repeated DeviceCredential device_credentials = 17; + google.protobuf.Timestamp issued_at = 14; + google.protobuf.Timestamp expires_at = 4; + google.protobuf.Timestamp accessed_at = 18; + IDToken id_token = 6; + OAuthToken oauth_token = 7; + map claims = 9; + repeated string audience = 10; + bool refresh_disabled = 19; + string idp_id = 20; optional string impersonate_session_id = 15; } diff --git a/pkg/identity/manager/config.go b/pkg/identity/manager/config.go index 71ebc30f7..f3f314283 100644 --- a/pkg/identity/manager/config.go +++ b/pkg/identity/manager/config.go @@ -1,10 +1,14 @@ package manager import ( + "context" + "fmt" + "sync" "time" "github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/identity" ) var ( @@ -15,7 +19,6 @@ var ( ) type config struct { - authenticator Authenticator dataBrokerClient databroker.DataBrokerServiceClient sessionRefreshGracePeriod time.Duration sessionRefreshCoolOffDuration time.Duration @@ -23,10 +26,14 @@ type config struct { leaseTTL time.Duration now func() time.Time eventMgr *events.Manager + getAuthenticator func(ctx context.Context, idpID string) (identity.Authenticator, error) } func newConfig(options ...Option) *config { cfg := new(config) + WithGetAuthenticator(func(_ context.Context, _ string) (identity.Authenticator, error) { + return nil, fmt.Errorf("authenticator not configured") + }) WithSessionRefreshGracePeriod(defaultSessionRefreshGracePeriod)(cfg) WithSessionRefreshCoolOffDuration(defaultSessionRefreshCoolOffDuration)(cfg) WithNow(time.Now)(cfg) @@ -41,10 +48,27 @@ func newConfig(options ...Option) *config { // An Option customizes the configuration used for the identity manager. type Option func(*config) -// WithAuthenticator sets the authenticator in the config. -func WithAuthenticator(authenticator Authenticator) Option { +// WithCachedGetAuthenticator sets the get authenticator function in the config. +func WithCachedGetAuthenticator(getAuthenticator func(ctx context.Context, idpID string) (identity.Authenticator, error)) Option { return func(cfg *config) { - cfg.authenticator = authenticator + var mu sync.Mutex + type state struct { + authenticator identity.Authenticator + err error + } + lookup := map[string]state{} + cfg.getAuthenticator = func(ctx context.Context, idpID string) (identity.Authenticator, error) { + mu.Lock() + defer mu.Unlock() + + s, ok := lookup[idpID] + if !ok { + s.authenticator, s.err = getAuthenticator(ctx, idpID) + lookup[idpID] = s + } + + return s.authenticator, s.err + } } } @@ -55,6 +79,13 @@ func WithDataBrokerClient(dataBrokerClient databroker.DataBrokerServiceClient) O } } +// WithGetAuthenticator sets the get authenticator function in the config. +func WithGetAuthenticator(getAuthenticator func(ctx context.Context, idpID string) (identity.Authenticator, error)) Option { + return func(cfg *config) { + cfg.getAuthenticator = getAuthenticator + } +} + // WithSessionRefreshGracePeriod sets the session refresh grace period used by the manager. func WithSessionRefreshGracePeriod(dur time.Duration) Option { return func(cfg *config) { diff --git a/pkg/identity/manager/manager.go b/pkg/identity/manager/manager.go index 3f3744125..d03947d4a 100644 --- a/pkg/identity/manager/manager.go +++ b/pkg/identity/manager/manager.go @@ -94,7 +94,7 @@ func (mgr *Manager) RunLeased(ctx context.Context) error { } func (mgr *Manager) onDeleteAllSessions(ctx context.Context) { - log.Ctx(ctx).Debug().Msg("all session deleted") + log.Ctx(ctx).Debug().Msg("all sessions deleted") mgr.mu.Lock() mgr.dataStore.deleteAllSessions() @@ -197,9 +197,10 @@ func (mgr *Manager) refreshSession(ctx context.Context, sessionID string) { return } - authenticator := mgr.cfg.Load().authenticator - if authenticator == nil { + authenticator, err := mgr.cfg.Load().getAuthenticator(ctx, s.GetIdpId()) + if err != nil { log.Ctx(ctx).Info(). + Err(err). Str("user-id", s.GetUserId()). Str("session-id", s.GetId()). Msg("no authenticator defined, deleting session") @@ -275,12 +276,6 @@ func (mgr *Manager) refreshSession(ctx context.Context, sessionID string) { func (mgr *Manager) updateUserInfo(ctx context.Context, userID string) { log.Ctx(ctx).Info().Str("user-id", userID).Msg("updating user info") - - authenticator := mgr.cfg.Load().authenticator - if authenticator == nil { - return - } - mgr.mu.Lock() u, ss := mgr.dataStore.getUserAndSessions(userID) mgr.mu.Unlock() @@ -301,7 +296,12 @@ func (mgr *Manager) updateUserInfo(ctx context.Context, userID string) { continue } - err := authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.GetOauthToken()), newUserUnmarshaler(u)) + authenticator, err := mgr.cfg.Load().getAuthenticator(ctx, s.GetIdpId()) + if err != nil { + continue + } + + err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.GetOauthToken()), newUserUnmarshaler(u)) metrics.RecordIdentityManagerUserRefresh(ctx, err) mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err) if isTemporaryError(err) { diff --git a/pkg/ssh/auth.go b/pkg/ssh/auth.go index d292d1d71..357665724 100644 --- a/pkg/ssh/auth.go +++ b/pkg/ssh/auth.go @@ -23,6 +23,7 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/pkg/grpc/databroker" + identitypb "github.com/pomerium/pomerium/pkg/grpc/identity" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/identity" @@ -202,7 +203,7 @@ func (a *Auth) handleLogin( querier KeyboardInteractiveQuerier, ) error { // Initiate the IdP login flow. - authenticator, err := a.getAuthenticator(ctx, hostname) + idp, authenticator, err := a.getAuthenticator(ctx, hostname) if err != nil { return err } @@ -228,7 +229,7 @@ func (a *Auth) handleLogin( if err != nil { return err } - return a.saveSession(ctx, sessionID, &sessionClaims, token) + return a.saveSession(ctx, idp.Id, sessionID, &sessionClaims, token) } var errAccessDenied = status.Error(codes.PermissionDenied, "access denied") @@ -320,6 +321,7 @@ func (a *Auth) DeleteSession(ctx context.Context, info StreamAuthInfo) error { func (a *Auth) saveSession( ctx context.Context, + idpID, id string, claims *identity.SessionClaims, token *oauth2.Token, @@ -333,15 +335,13 @@ func (a *Auth) saveSession( return err } - sess := &session.Session{ - Id: id, - UserId: state.UserID(), - IssuedAt: nowpb, - AccessedAt: nowpb, - ExpiresAt: timestamppb.New(now.Add(sessionLifetime)), - OauthToken: manager.ToOAuthToken(token), - Audience: state.Audience, - } + sess := session.New(idpID, id) + sess.UserId = state.UserID() + sess.IssuedAt = nowpb + sess.AccessedAt = nowpb + sess.ExpiresAt = timestamppb.New(now.Add(sessionLifetime)) + sess.OauthToken = manager.ToOAuthToken(token) + sess.Audience = state.Audience sess.SetRawIDToken(claims.RawIDToken) sess.AddClaims(claims.Flatten()) @@ -367,20 +367,25 @@ func (a *Auth) saveSession( return nil } -func (a *Auth) getAuthenticator(ctx context.Context, hostname string) (identity.Authenticator, error) { +func (a *Auth) getAuthenticator(ctx context.Context, hostname string) (*identitypb.Provider, identity.Authenticator, error) { opts := a.currentConfig.Load().Options redirectURL, err := opts.GetAuthenticateRedirectURL() if err != nil { - return nil, err + return nil, nil, err } idp, err := opts.GetIdentityProviderForPolicy(opts.GetRouteForSSHHostname(hostname)) if err != nil { - return nil, err + return nil, nil, err } - return identity.GetIdentityProvider(ctx, a.tracerProvider, idp, redirectURL) + authenticator, err := identity.GetIdentityProvider(ctx, a.tracerProvider, idp, redirectURL) + if err != nil { + return nil, nil, err + } + + return idp, authenticator, nil } var _ AuthInterface = (*Auth)(nil) diff --git a/proxy/data.go b/proxy/data.go index 3fad0f412..c82f50ee4 100644 --- a/proxy/data.go +++ b/proxy/data.go @@ -52,7 +52,7 @@ func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData { if err == nil { data.Session, data.IsImpersonated, err = p.getSession(r.Context(), ss.ID) if err != nil { - data.Session = &session.Session{Id: ss.ID} + data.Session = session.New(ss.IdentityProviderID, ss.ID) } data.User, err = p.getUser(r.Context(), data.Session.GetUserId())