databroker: update identity manager to use route credentials (#5728)

## Summary
Currently when we refresh sessions we always use the global IdP
credentials. This PR updates the identity manager to use route settings
when defined.

To do this a new `idp_id` field is added to the session stored in the
databroker.

## Related issues
-
[ENG-2595](https://linear.app/pomerium/issue/ENG-2595/refresh-using-custom-idp-uses-wrong-credentials)
- https://github.com/pomerium/pomerium/issues/4759

## Checklist

- [x] reference any related issues
- [x] updated unit tests
- [x] add appropriate label (`enhancement`, `bug`, `breaking`,
`dependencies`, `ci`)
- [x] ready for review
This commit is contained in:
Caleb Doxsey 2025-07-15 18:04:36 -06:00 committed by GitHub
parent e5e799a868
commit 622519e901
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 185 additions and 123 deletions

View file

@ -1,15 +1,19 @@
package config
import (
"context"
"slices"
oteltrace "go.opentelemetry.io/otel/trace"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/grpc/identity"
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/identity"
)
// GetIdentityProviderForID returns the identity provider associated with the given IDP id.
// If none is found the default provider is returned.
func (o *Options) GetIdentityProviderForID(idpID string) (*identity.Provider, error) {
func (o *Options) GetIdentityProviderForID(idpID string) (*identitypb.Provider, error) {
for p := range o.GetAllPolicies() {
idp, err := o.GetIdentityProviderForPolicy(p)
if err != nil {
@ -25,7 +29,7 @@ func (o *Options) GetIdentityProviderForID(idpID string) (*identity.Provider, er
// GetIdentityProviderForPolicy gets the identity provider associated with the given policy.
// If policy is nil, or changes none of the default settings, the default provider is returned.
func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provider, error) {
func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identitypb.Provider, error) {
clientSecret, err := o.GetClientSecret()
if err != nil {
return nil, err
@ -36,7 +40,7 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
return nil, err
}
idp := &identity.Provider{
idp := &identitypb.Provider{
AuthenticateServiceUrl: authenticateURL.String(),
ClientId: o.ClientID,
ClientSecret: clientSecret,
@ -46,7 +50,7 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
RequestParams: o.RequestParams,
}
if v := o.IDPAccessTokenAllowedAudiences; v != nil {
idp.AccessTokenAllowedAudiences = &identity.Provider_StringList{
idp.AccessTokenAllowedAudiences = &identitypb.Provider_StringList{
Values: slices.Clone(*v),
}
}
@ -58,7 +62,7 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
idp.ClientSecret = policy.IDPClientSecret
}
if v := policy.IDPAccessTokenAllowedAudiences; v != nil {
idp.AccessTokenAllowedAudiences = &identity.Provider_StringList{
idp.AccessTokenAllowedAudiences = &identitypb.Provider_StringList{
Values: slices.Clone(*v),
}
}
@ -68,7 +72,7 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
}
// GetIdentityProviderForRequestURL gets the identity provider associated with the given request URL.
func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity.Provider, error) {
func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identitypb.Provider, error) {
u, err := urlutil.ParseAndValidateURL(requestURL)
if err != nil {
return nil, err
@ -81,3 +85,18 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity
}
return o.GetIdentityProviderForPolicy(nil)
}
// GetAuthenticator gets the authenticator for the given IDP id.
func (o *Options) GetAuthenticator(ctx context.Context, tracerProvider oteltrace.TracerProvider, idpID string) (identity.Authenticator, error) {
redirectURL, err := o.GetAuthenticateRedirectURL()
if err != nil {
return nil, err
}
idp, err := o.GetIdentityProviderForID(idpID)
if err != nil {
return nil, err
}
return identity.GetIdentityProvider(ctx, tracerProvider, idp, redirectURL)
}

View file

@ -851,13 +851,14 @@ func (o *Options) GetAuthenticateRedirectURL() (*url.URL, error) {
// flow should be used (i.e. for hosted authenticate).
func (o *Options) UseStatelessAuthenticateFlow() bool {
if flow := os.Getenv("DEBUG_FORCE_AUTHENTICATE_FLOW"); flow != "" {
if flow == "stateless" {
switch flow {
case "stateless":
return true
} else if flow == "stateful" {
case "stateful":
return false
default:
log.Error().Msgf("ignoring unknown DEBUG_FORCE_AUTHENTICATE_FLOW setting %q", flow)
}
log.Error().
Msgf("ignoring unknown DEBUG_FORCE_AUTHENTICATE_FLOW setting %q", flow)
}
u, err := o.GetInternalAuthenticateURL()
if err != nil {

View file

@ -240,7 +240,7 @@ func (c *incomingIDPTokenSessionCreator) createSessionAccessToken(
return nil, fmt.Errorf("%w: invalid access token", sessions.ErrInvalidSession)
}
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s = c.newSessionFromIDPClaims(cfg, idp.Id, sessionID, res.Claims)
s.OauthToken = &session.OAuthToken{
TokenType: "Bearer",
AccessToken: rawAccessToken,
@ -309,7 +309,7 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
return nil, fmt.Errorf("%w: invalid identity token", sessions.ErrInvalidSession)
}
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
s = c.newSessionFromIDPClaims(cfg, idp.Id, sessionID, res.Claims)
s.SetRawIDToken(rawIdentityToken)
u, err := c.getUser(ctx, s.GetUserId())
@ -338,11 +338,12 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(
cfg *Config,
idpID string,
sessionID string,
claims jwtutil.Claims,
) *session.Session {
now := c.timeNow()
s := new(session.Session)
s := session.New(idpID, sessionID)
s.Id = sessionID
if userID, ok := claims.GetUserID(); ok {
s.UserId = userID

View file

@ -375,7 +375,7 @@ func Test_newSessionFromIDPClaims(t *testing.T) {
c := &incomingIDPTokenSessionCreator{
timeNow: func() time.Time { return tm1 },
}
actual := c.newSessionFromIDPClaims(cfg, tc.sessionID, tc.claims)
actual := c.newSessionFromIDPClaims(cfg, "", tc.sessionID, tc.claims)
testutil.AssertProtoEqual(t, tc.expect, actual)
})
}

View file

@ -171,7 +171,7 @@ func (c *DataBroker) Run(ctx context.Context) error {
return eg.Wait()
}
func (c *DataBroker) update(ctx context.Context, cfg *config.Config) error {
func (c *DataBroker) update(_ context.Context, cfg *config.Config) error {
if err := validate(cfg.Options); err != nil {
return fmt.Errorf("databroker: bad option: %w", err)
}
@ -182,29 +182,19 @@ func (c *DataBroker) update(ctx context.Context, cfg *config.Config) error {
}
c.sharedKey.Store(sharedKey)
oauthOptions, err := cfg.Options.GetOauthOptions()
if err != nil {
return fmt.Errorf("databroker: invalid oauth options: %w", err)
}
dataBrokerClient := databroker.NewDataBrokerServiceClient(c.localGRPCConnection)
options := append([]manager.Option{
manager.WithDataBrokerClient(dataBrokerClient),
manager.WithEventManager(c.eventsMgr),
manager.WithCachedGetAuthenticator(func(ctx context.Context, idpID string) (identity.Authenticator, error) {
if !cfg.Options.SupportsUserRefresh() {
return nil, fmt.Errorf("disabling refresh of user sessions")
}
return cfg.Options.GetAuthenticator(ctx, c.tracerProvider, idpID)
}),
}, c.managerOptions...)
if cfg.Options.SupportsUserRefresh() {
authenticator, err := identity.NewAuthenticator(ctx, c.tracerProvider, oauthOptions)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("databroker: failed to create authenticator")
} else {
options = append(options, manager.WithAuthenticator(authenticator))
}
} else {
log.Ctx(ctx).Info().Msg("databroker: disabling refresh of user sessions")
}
if c.manager == nil {
c.manager = manager.New(options...)
} else {

View file

@ -189,15 +189,13 @@ func (s *Stateful) PersistSession(
now := timeNow()
sessionExpiry := timestamppb.New(now.Add(s.sessionDuration))
sess := &session.Session{
Id: sessionState.ID,
UserId: sessionState.UserID(),
IssuedAt: timestamppb.New(now),
AccessedAt: timestamppb.New(now),
ExpiresAt: sessionExpiry,
OauthToken: manager.ToOAuthToken(accessToken),
Audience: sessionState.Audience,
}
sess := session.New(sessionState.IdentityProviderID, sessionState.ID)
sess.UserId = sessionState.UserID()
sess.IssuedAt = timestamppb.New(now)
sess.AccessedAt = timestamppb.New(now)
sess.ExpiresAt = sessionExpiry
sess.OauthToken = manager.ToOAuthToken(accessToken)
sess.Audience = sessionState.Audience
sess.SetRawIDToken(claims.RawIDToken)
sess.AddClaims(claims.Flatten())
@ -236,9 +234,7 @@ func (s *Stateful) GetUserInfoData(
isImpersonated = true
}
if err != nil {
pbSession = &session.Session{
Id: sessionState.ID,
}
pbSession = session.New(sessionState.IdentityProviderID, sessionState.ID)
}
pbUser, err := user.Get(r.Context(), s.dataBrokerClient, pbSession.GetUserId())

View file

@ -415,7 +415,7 @@ func (s *Stateless) Callback(w http.ResponseWriter, r *http.Request) error {
ss := newSessionStateFromProfile(profile)
sess, err := session.Get(r.Context(), s.dataBrokerClient, ss.ID)
if err != nil {
sess = &session.Session{Id: ss.ID}
sess = session.New(ss.IdentityProviderID, ss.ID)
}
populateSessionFromProfile(sess, profile, ss, s.options.CookieExpire)
u, err := user.Get(r.Context(), s.dataBrokerClient, ss.UserID())

View file

@ -83,6 +83,14 @@ func Patch(
return res, err
}
// New creates a new Session.
func New(idpID, id string) *Session {
return &Session{
IdpId: idpID,
Id: id,
}
}
// AddClaims adds the flattened claims to the session.
func (x *Session) AddClaims(claims identity.FlattenedClaims) {
if x.Claims == nil {

View file

@ -190,6 +190,7 @@ type Session struct {
Claims map[string]*structpb.ListValue `protobuf:"bytes,9,rep,name=claims,proto3" json:"claims,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
Audience []string `protobuf:"bytes,10,rep,name=audience,proto3" json:"audience,omitempty"`
RefreshDisabled bool `protobuf:"varint,19,opt,name=refresh_disabled,json=refreshDisabled,proto3" json:"refresh_disabled,omitempty"`
IdpId string `protobuf:"bytes,20,opt,name=idp_id,json=idpId,proto3" json:"idp_id,omitempty"`
ImpersonateSessionId *string `protobuf:"bytes,15,opt,name=impersonate_session_id,json=impersonateSessionId,proto3,oneof" json:"impersonate_session_id,omitempty"`
}
@ -309,6 +310,13 @@ func (x *Session) GetRefreshDisabled() bool {
return false
}
func (x *Session) GetIdpId() string {
if x != nil {
return x.IdpId
}
return ""
}
func (x *Session) GetImpersonateSessionId() string {
if x != nil && x.ImpersonateSessionId != nil {
return *x.ImpersonateSessionId
@ -438,7 +446,7 @@ var file_session_proto_rawDesc = []byte{
0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72,
0x65, 0x73, 0x41, 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f,
0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, 0x66,
0x72, 0x65, 0x73, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xe6, 0x06, 0x0a, 0x07, 0x53, 0x65,
0x72, 0x65, 0x73, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xfd, 0x06, 0x0a, 0x07, 0x53, 0x65,
0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e,
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12,
0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12,
@ -473,30 +481,32 @@ var file_session_proto_rawDesc = []byte{
0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65,
0x12, 0x29, 0x0a, 0x10, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x64, 0x69, 0x73, 0x61,
0x62, 0x6c, 0x65, 0x64, 0x18, 0x13, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x72, 0x65, 0x66, 0x72,
0x65, 0x73, 0x68, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x39, 0x0a, 0x16, 0x69,
0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69,
0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x14, 0x69,
0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f,
0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x1a, 0x87, 0x01, 0x0a, 0x10, 0x44, 0x65, 0x76, 0x69, 0x63,
0x65, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x12, 0x17, 0x0a, 0x07, 0x74,
0x79, 0x70, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, 0x79,
0x70, 0x65, 0x49, 0x64, 0x12, 0x3a, 0x0a, 0x0b, 0x75, 0x6e, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61,
0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67,
0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74,
0x79, 0x48, 0x00, 0x52, 0x0b, 0x75, 0x6e, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65,
0x12, 0x10, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x02,
0x69, 0x64, 0x42, 0x0c, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c,
0x1a, 0x55, 0x0a, 0x0b, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12,
0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65,
0x79, 0x12, 0x30, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b,
0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62,
0x75, 0x66, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x05, 0x76, 0x61,
0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x19, 0x0a, 0x17, 0x5f, 0x69, 0x6d, 0x70, 0x65,
0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f,
0x69, 0x64, 0x42, 0x2f, 0x5a, 0x2d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d,
0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x73, 0x65, 0x73, 0x73,
0x69, 0x6f, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x65, 0x73, 0x68, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x15, 0x0a, 0x06, 0x69,
0x64, 0x70, 0x5f, 0x69, 0x64, 0x18, 0x14, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x69, 0x64, 0x70,
0x49, 0x64, 0x12, 0x39, 0x0a, 0x16, 0x69, 0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74,
0x65, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20, 0x01,
0x28, 0x09, 0x48, 0x00, 0x52, 0x14, 0x69, 0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74,
0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x1a, 0x87, 0x01,
0x0a, 0x10, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69,
0x61, 0x6c, 0x12, 0x17, 0x0a, 0x07, 0x74, 0x79, 0x70, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20,
0x01, 0x28, 0x09, 0x52, 0x06, 0x74, 0x79, 0x70, 0x65, 0x49, 0x64, 0x12, 0x3a, 0x0a, 0x0b, 0x75,
0x6e, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b,
0x32, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62,
0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x48, 0x00, 0x52, 0x0b, 0x75, 0x6e, 0x61, 0x76,
0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x10, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x03, 0x20,
0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x02, 0x69, 0x64, 0x42, 0x0c, 0x0a, 0x0a, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x1a, 0x55, 0x0a, 0x0b, 0x43, 0x6c, 0x61, 0x69, 0x6d,
0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20,
0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75,
0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x56, 0x61,
0x6c, 0x75, 0x65, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x19,
0x0a, 0x17, 0x5f, 0x69, 0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x5f, 0x73,
0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x42, 0x2f, 0x5a, 0x2d, 0x67, 0x69, 0x74,
0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d,
0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72,
0x70, 0x63, 0x2f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x33,
}
var (

View file

@ -43,6 +43,7 @@ message Session {
map<string, google.protobuf.ListValue> claims = 9;
repeated string audience = 10;
bool refresh_disabled = 19;
string idp_id = 20;
optional string impersonate_session_id = 15;
}

View file

@ -1,10 +1,14 @@
package manager
import (
"context"
"fmt"
"sync"
"time"
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/identity"
)
var (
@ -15,7 +19,6 @@ var (
)
type config struct {
authenticator Authenticator
dataBrokerClient databroker.DataBrokerServiceClient
sessionRefreshGracePeriod time.Duration
sessionRefreshCoolOffDuration time.Duration
@ -23,10 +26,14 @@ type config struct {
leaseTTL time.Duration
now func() time.Time
eventMgr *events.Manager
getAuthenticator func(ctx context.Context, idpID string) (identity.Authenticator, error)
}
func newConfig(options ...Option) *config {
cfg := new(config)
WithGetAuthenticator(func(_ context.Context, _ string) (identity.Authenticator, error) {
return nil, fmt.Errorf("authenticator not configured")
})
WithSessionRefreshGracePeriod(defaultSessionRefreshGracePeriod)(cfg)
WithSessionRefreshCoolOffDuration(defaultSessionRefreshCoolOffDuration)(cfg)
WithNow(time.Now)(cfg)
@ -41,10 +48,27 @@ func newConfig(options ...Option) *config {
// An Option customizes the configuration used for the identity manager.
type Option func(*config)
// WithAuthenticator sets the authenticator in the config.
func WithAuthenticator(authenticator Authenticator) Option {
// WithCachedGetAuthenticator sets the get authenticator function in the config.
func WithCachedGetAuthenticator(getAuthenticator func(ctx context.Context, idpID string) (identity.Authenticator, error)) Option {
return func(cfg *config) {
cfg.authenticator = authenticator
var mu sync.Mutex
type state struct {
authenticator identity.Authenticator
err error
}
lookup := map[string]state{}
cfg.getAuthenticator = func(ctx context.Context, idpID string) (identity.Authenticator, error) {
mu.Lock()
defer mu.Unlock()
s, ok := lookup[idpID]
if !ok {
s.authenticator, s.err = getAuthenticator(ctx, idpID)
lookup[idpID] = s
}
return s.authenticator, s.err
}
}
}
@ -55,6 +79,13 @@ func WithDataBrokerClient(dataBrokerClient databroker.DataBrokerServiceClient) O
}
}
// WithGetAuthenticator sets the get authenticator function in the config.
func WithGetAuthenticator(getAuthenticator func(ctx context.Context, idpID string) (identity.Authenticator, error)) Option {
return func(cfg *config) {
cfg.getAuthenticator = getAuthenticator
}
}
// WithSessionRefreshGracePeriod sets the session refresh grace period used by the manager.
func WithSessionRefreshGracePeriod(dur time.Duration) Option {
return func(cfg *config) {

View file

@ -94,7 +94,7 @@ func (mgr *Manager) RunLeased(ctx context.Context) error {
}
func (mgr *Manager) onDeleteAllSessions(ctx context.Context) {
log.Ctx(ctx).Debug().Msg("all session deleted")
log.Ctx(ctx).Debug().Msg("all sessions deleted")
mgr.mu.Lock()
mgr.dataStore.deleteAllSessions()
@ -197,9 +197,10 @@ func (mgr *Manager) refreshSession(ctx context.Context, sessionID string) {
return
}
authenticator := mgr.cfg.Load().authenticator
if authenticator == nil {
authenticator, err := mgr.cfg.Load().getAuthenticator(ctx, s.GetIdpId())
if err != nil {
log.Ctx(ctx).Info().
Err(err).
Str("user-id", s.GetUserId()).
Str("session-id", s.GetId()).
Msg("no authenticator defined, deleting session")
@ -275,12 +276,6 @@ func (mgr *Manager) refreshSession(ctx context.Context, sessionID string) {
func (mgr *Manager) updateUserInfo(ctx context.Context, userID string) {
log.Ctx(ctx).Info().Str("user-id", userID).Msg("updating user info")
authenticator := mgr.cfg.Load().authenticator
if authenticator == nil {
return
}
mgr.mu.Lock()
u, ss := mgr.dataStore.getUserAndSessions(userID)
mgr.mu.Unlock()
@ -301,7 +296,12 @@ func (mgr *Manager) updateUserInfo(ctx context.Context, userID string) {
continue
}
err := authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.GetOauthToken()), newUserUnmarshaler(u))
authenticator, err := mgr.cfg.Load().getAuthenticator(ctx, s.GetIdpId())
if err != nil {
continue
}
err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.GetOauthToken()), newUserUnmarshaler(u))
metrics.RecordIdentityManagerUserRefresh(ctx, err)
mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err)
if isTemporaryError(err) {

View file

@ -23,6 +23,7 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/identity"
@ -202,7 +203,7 @@ func (a *Auth) handleLogin(
querier KeyboardInteractiveQuerier,
) error {
// Initiate the IdP login flow.
authenticator, err := a.getAuthenticator(ctx, hostname)
idp, authenticator, err := a.getAuthenticator(ctx, hostname)
if err != nil {
return err
}
@ -228,7 +229,7 @@ func (a *Auth) handleLogin(
if err != nil {
return err
}
return a.saveSession(ctx, sessionID, &sessionClaims, token)
return a.saveSession(ctx, idp.Id, sessionID, &sessionClaims, token)
}
var errAccessDenied = status.Error(codes.PermissionDenied, "access denied")
@ -320,6 +321,7 @@ func (a *Auth) DeleteSession(ctx context.Context, info StreamAuthInfo) error {
func (a *Auth) saveSession(
ctx context.Context,
idpID,
id string,
claims *identity.SessionClaims,
token *oauth2.Token,
@ -333,15 +335,13 @@ func (a *Auth) saveSession(
return err
}
sess := &session.Session{
Id: id,
UserId: state.UserID(),
IssuedAt: nowpb,
AccessedAt: nowpb,
ExpiresAt: timestamppb.New(now.Add(sessionLifetime)),
OauthToken: manager.ToOAuthToken(token),
Audience: state.Audience,
}
sess := session.New(idpID, id)
sess.UserId = state.UserID()
sess.IssuedAt = nowpb
sess.AccessedAt = nowpb
sess.ExpiresAt = timestamppb.New(now.Add(sessionLifetime))
sess.OauthToken = manager.ToOAuthToken(token)
sess.Audience = state.Audience
sess.SetRawIDToken(claims.RawIDToken)
sess.AddClaims(claims.Flatten())
@ -367,20 +367,25 @@ func (a *Auth) saveSession(
return nil
}
func (a *Auth) getAuthenticator(ctx context.Context, hostname string) (identity.Authenticator, error) {
func (a *Auth) getAuthenticator(ctx context.Context, hostname string) (*identitypb.Provider, identity.Authenticator, error) {
opts := a.currentConfig.Load().Options
redirectURL, err := opts.GetAuthenticateRedirectURL()
if err != nil {
return nil, err
return nil, nil, err
}
idp, err := opts.GetIdentityProviderForPolicy(opts.GetRouteForSSHHostname(hostname))
if err != nil {
return nil, err
return nil, nil, err
}
return identity.GetIdentityProvider(ctx, a.tracerProvider, idp, redirectURL)
authenticator, err := identity.GetIdentityProvider(ctx, a.tracerProvider, idp, redirectURL)
if err != nil {
return nil, nil, err
}
return idp, authenticator, nil
}
var _ AuthInterface = (*Auth)(nil)

View file

@ -52,7 +52,7 @@ func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
if err == nil {
data.Session, data.IsImpersonated, err = p.getSession(r.Context(), ss.ID)
if err != nil {
data.Session = &session.Session{Id: ss.ID}
data.Session = session.New(ss.IdentityProviderID, ss.ID)
}
data.User, err = p.getUser(r.Context(), data.Session.GetUserId())