mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-30 06:51:30 +02:00
databroker: update identity manager to use route credentials (#5728)
## Summary Currently when we refresh sessions we always use the global IdP credentials. This PR updates the identity manager to use route settings when defined. To do this a new `idp_id` field is added to the session stored in the databroker. ## Related issues - [ENG-2595](https://linear.app/pomerium/issue/ENG-2595/refresh-using-custom-idp-uses-wrong-credentials) - https://github.com/pomerium/pomerium/issues/4759 ## Checklist - [x] reference any related issues - [x] updated unit tests - [x] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [x] ready for review
This commit is contained in:
parent
e5e799a868
commit
622519e901
14 changed files with 185 additions and 123 deletions
|
@ -1,15 +1,19 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
"github.com/pomerium/pomerium/pkg/identity"
|
||||
)
|
||||
|
||||
// GetIdentityProviderForID returns the identity provider associated with the given IDP id.
|
||||
// If none is found the default provider is returned.
|
||||
func (o *Options) GetIdentityProviderForID(idpID string) (*identity.Provider, error) {
|
||||
func (o *Options) GetIdentityProviderForID(idpID string) (*identitypb.Provider, error) {
|
||||
for p := range o.GetAllPolicies() {
|
||||
idp, err := o.GetIdentityProviderForPolicy(p)
|
||||
if err != nil {
|
||||
|
@ -25,7 +29,7 @@ func (o *Options) GetIdentityProviderForID(idpID string) (*identity.Provider, er
|
|||
|
||||
// GetIdentityProviderForPolicy gets the identity provider associated with the given policy.
|
||||
// If policy is nil, or changes none of the default settings, the default provider is returned.
|
||||
func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provider, error) {
|
||||
func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identitypb.Provider, error) {
|
||||
clientSecret, err := o.GetClientSecret()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -36,7 +40,7 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
|
|||
return nil, err
|
||||
}
|
||||
|
||||
idp := &identity.Provider{
|
||||
idp := &identitypb.Provider{
|
||||
AuthenticateServiceUrl: authenticateURL.String(),
|
||||
ClientId: o.ClientID,
|
||||
ClientSecret: clientSecret,
|
||||
|
@ -46,7 +50,7 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
|
|||
RequestParams: o.RequestParams,
|
||||
}
|
||||
if v := o.IDPAccessTokenAllowedAudiences; v != nil {
|
||||
idp.AccessTokenAllowedAudiences = &identity.Provider_StringList{
|
||||
idp.AccessTokenAllowedAudiences = &identitypb.Provider_StringList{
|
||||
Values: slices.Clone(*v),
|
||||
}
|
||||
}
|
||||
|
@ -58,7 +62,7 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
|
|||
idp.ClientSecret = policy.IDPClientSecret
|
||||
}
|
||||
if v := policy.IDPAccessTokenAllowedAudiences; v != nil {
|
||||
idp.AccessTokenAllowedAudiences = &identity.Provider_StringList{
|
||||
idp.AccessTokenAllowedAudiences = &identitypb.Provider_StringList{
|
||||
Values: slices.Clone(*v),
|
||||
}
|
||||
}
|
||||
|
@ -68,7 +72,7 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
|
|||
}
|
||||
|
||||
// GetIdentityProviderForRequestURL gets the identity provider associated with the given request URL.
|
||||
func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity.Provider, error) {
|
||||
func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identitypb.Provider, error) {
|
||||
u, err := urlutil.ParseAndValidateURL(requestURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -81,3 +85,18 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity
|
|||
}
|
||||
return o.GetIdentityProviderForPolicy(nil)
|
||||
}
|
||||
|
||||
// GetAuthenticator gets the authenticator for the given IDP id.
|
||||
func (o *Options) GetAuthenticator(ctx context.Context, tracerProvider oteltrace.TracerProvider, idpID string) (identity.Authenticator, error) {
|
||||
redirectURL, err := o.GetAuthenticateRedirectURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idp, err := o.GetIdentityProviderForID(idpID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return identity.GetIdentityProvider(ctx, tracerProvider, idp, redirectURL)
|
||||
}
|
||||
|
|
|
@ -851,13 +851,14 @@ func (o *Options) GetAuthenticateRedirectURL() (*url.URL, error) {
|
|||
// flow should be used (i.e. for hosted authenticate).
|
||||
func (o *Options) UseStatelessAuthenticateFlow() bool {
|
||||
if flow := os.Getenv("DEBUG_FORCE_AUTHENTICATE_FLOW"); flow != "" {
|
||||
if flow == "stateless" {
|
||||
switch flow {
|
||||
case "stateless":
|
||||
return true
|
||||
} else if flow == "stateful" {
|
||||
case "stateful":
|
||||
return false
|
||||
default:
|
||||
log.Error().Msgf("ignoring unknown DEBUG_FORCE_AUTHENTICATE_FLOW setting %q", flow)
|
||||
}
|
||||
log.Error().
|
||||
Msgf("ignoring unknown DEBUG_FORCE_AUTHENTICATE_FLOW setting %q", flow)
|
||||
}
|
||||
u, err := o.GetInternalAuthenticateURL()
|
||||
if err != nil {
|
||||
|
|
|
@ -240,7 +240,7 @@ func (c *incomingIDPTokenSessionCreator) createSessionAccessToken(
|
|||
return nil, fmt.Errorf("%w: invalid access token", sessions.ErrInvalidSession)
|
||||
}
|
||||
|
||||
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
||||
s = c.newSessionFromIDPClaims(cfg, idp.Id, sessionID, res.Claims)
|
||||
s.OauthToken = &session.OAuthToken{
|
||||
TokenType: "Bearer",
|
||||
AccessToken: rawAccessToken,
|
||||
|
@ -309,7 +309,7 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
|
|||
return nil, fmt.Errorf("%w: invalid identity token", sessions.ErrInvalidSession)
|
||||
}
|
||||
|
||||
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
||||
s = c.newSessionFromIDPClaims(cfg, idp.Id, sessionID, res.Claims)
|
||||
s.SetRawIDToken(rawIdentityToken)
|
||||
|
||||
u, err := c.getUser(ctx, s.GetUserId())
|
||||
|
@ -338,11 +338,12 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
|
|||
|
||||
func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(
|
||||
cfg *Config,
|
||||
idpID string,
|
||||
sessionID string,
|
||||
claims jwtutil.Claims,
|
||||
) *session.Session {
|
||||
now := c.timeNow()
|
||||
s := new(session.Session)
|
||||
s := session.New(idpID, sessionID)
|
||||
s.Id = sessionID
|
||||
if userID, ok := claims.GetUserID(); ok {
|
||||
s.UserId = userID
|
||||
|
|
|
@ -375,7 +375,7 @@ func Test_newSessionFromIDPClaims(t *testing.T) {
|
|||
c := &incomingIDPTokenSessionCreator{
|
||||
timeNow: func() time.Time { return tm1 },
|
||||
}
|
||||
actual := c.newSessionFromIDPClaims(cfg, tc.sessionID, tc.claims)
|
||||
actual := c.newSessionFromIDPClaims(cfg, "", tc.sessionID, tc.claims)
|
||||
testutil.AssertProtoEqual(t, tc.expect, actual)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -171,7 +171,7 @@ func (c *DataBroker) Run(ctx context.Context) error {
|
|||
return eg.Wait()
|
||||
}
|
||||
|
||||
func (c *DataBroker) update(ctx context.Context, cfg *config.Config) error {
|
||||
func (c *DataBroker) update(_ context.Context, cfg *config.Config) error {
|
||||
if err := validate(cfg.Options); err != nil {
|
||||
return fmt.Errorf("databroker: bad option: %w", err)
|
||||
}
|
||||
|
@ -182,29 +182,19 @@ func (c *DataBroker) update(ctx context.Context, cfg *config.Config) error {
|
|||
}
|
||||
c.sharedKey.Store(sharedKey)
|
||||
|
||||
oauthOptions, err := cfg.Options.GetOauthOptions()
|
||||
if err != nil {
|
||||
return fmt.Errorf("databroker: invalid oauth options: %w", err)
|
||||
}
|
||||
|
||||
dataBrokerClient := databroker.NewDataBrokerServiceClient(c.localGRPCConnection)
|
||||
|
||||
options := append([]manager.Option{
|
||||
manager.WithDataBrokerClient(dataBrokerClient),
|
||||
manager.WithEventManager(c.eventsMgr),
|
||||
manager.WithCachedGetAuthenticator(func(ctx context.Context, idpID string) (identity.Authenticator, error) {
|
||||
if !cfg.Options.SupportsUserRefresh() {
|
||||
return nil, fmt.Errorf("disabling refresh of user sessions")
|
||||
}
|
||||
return cfg.Options.GetAuthenticator(ctx, c.tracerProvider, idpID)
|
||||
}),
|
||||
}, c.managerOptions...)
|
||||
|
||||
if cfg.Options.SupportsUserRefresh() {
|
||||
authenticator, err := identity.NewAuthenticator(ctx, c.tracerProvider, oauthOptions)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("databroker: failed to create authenticator")
|
||||
} else {
|
||||
options = append(options, manager.WithAuthenticator(authenticator))
|
||||
}
|
||||
} else {
|
||||
log.Ctx(ctx).Info().Msg("databroker: disabling refresh of user sessions")
|
||||
}
|
||||
|
||||
if c.manager == nil {
|
||||
c.manager = manager.New(options...)
|
||||
} else {
|
||||
|
|
|
@ -189,15 +189,13 @@ func (s *Stateful) PersistSession(
|
|||
now := timeNow()
|
||||
sessionExpiry := timestamppb.New(now.Add(s.sessionDuration))
|
||||
|
||||
sess := &session.Session{
|
||||
Id: sessionState.ID,
|
||||
UserId: sessionState.UserID(),
|
||||
IssuedAt: timestamppb.New(now),
|
||||
AccessedAt: timestamppb.New(now),
|
||||
ExpiresAt: sessionExpiry,
|
||||
OauthToken: manager.ToOAuthToken(accessToken),
|
||||
Audience: sessionState.Audience,
|
||||
}
|
||||
sess := session.New(sessionState.IdentityProviderID, sessionState.ID)
|
||||
sess.UserId = sessionState.UserID()
|
||||
sess.IssuedAt = timestamppb.New(now)
|
||||
sess.AccessedAt = timestamppb.New(now)
|
||||
sess.ExpiresAt = sessionExpiry
|
||||
sess.OauthToken = manager.ToOAuthToken(accessToken)
|
||||
sess.Audience = sessionState.Audience
|
||||
sess.SetRawIDToken(claims.RawIDToken)
|
||||
sess.AddClaims(claims.Flatten())
|
||||
|
||||
|
@ -236,9 +234,7 @@ func (s *Stateful) GetUserInfoData(
|
|||
isImpersonated = true
|
||||
}
|
||||
if err != nil {
|
||||
pbSession = &session.Session{
|
||||
Id: sessionState.ID,
|
||||
}
|
||||
pbSession = session.New(sessionState.IdentityProviderID, sessionState.ID)
|
||||
}
|
||||
|
||||
pbUser, err := user.Get(r.Context(), s.dataBrokerClient, pbSession.GetUserId())
|
||||
|
|
|
@ -415,7 +415,7 @@ func (s *Stateless) Callback(w http.ResponseWriter, r *http.Request) error {
|
|||
ss := newSessionStateFromProfile(profile)
|
||||
sess, err := session.Get(r.Context(), s.dataBrokerClient, ss.ID)
|
||||
if err != nil {
|
||||
sess = &session.Session{Id: ss.ID}
|
||||
sess = session.New(ss.IdentityProviderID, ss.ID)
|
||||
}
|
||||
populateSessionFromProfile(sess, profile, ss, s.options.CookieExpire)
|
||||
u, err := user.Get(r.Context(), s.dataBrokerClient, ss.UserID())
|
||||
|
|
|
@ -83,6 +83,14 @@ func Patch(
|
|||
return res, err
|
||||
}
|
||||
|
||||
// New creates a new Session.
|
||||
func New(idpID, id string) *Session {
|
||||
return &Session{
|
||||
IdpId: idpID,
|
||||
Id: id,
|
||||
}
|
||||
}
|
||||
|
||||
// AddClaims adds the flattened claims to the session.
|
||||
func (x *Session) AddClaims(claims identity.FlattenedClaims) {
|
||||
if x.Claims == nil {
|
||||
|
|
|
@ -190,6 +190,7 @@ type Session struct {
|
|||
Claims map[string]*structpb.ListValue `protobuf:"bytes,9,rep,name=claims,proto3" json:"claims,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
|
||||
Audience []string `protobuf:"bytes,10,rep,name=audience,proto3" json:"audience,omitempty"`
|
||||
RefreshDisabled bool `protobuf:"varint,19,opt,name=refresh_disabled,json=refreshDisabled,proto3" json:"refresh_disabled,omitempty"`
|
||||
IdpId string `protobuf:"bytes,20,opt,name=idp_id,json=idpId,proto3" json:"idp_id,omitempty"`
|
||||
ImpersonateSessionId *string `protobuf:"bytes,15,opt,name=impersonate_session_id,json=impersonateSessionId,proto3,oneof" json:"impersonate_session_id,omitempty"`
|
||||
}
|
||||
|
||||
|
@ -309,6 +310,13 @@ func (x *Session) GetRefreshDisabled() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (x *Session) GetIdpId() string {
|
||||
if x != nil {
|
||||
return x.IdpId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Session) GetImpersonateSessionId() string {
|
||||
if x != nil && x.ImpersonateSessionId != nil {
|
||||
return *x.ImpersonateSessionId
|
||||
|
@ -438,7 +446,7 @@ var file_session_proto_rawDesc = []byte{
|
|||
0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72,
|
||||
0x65, 0x73, 0x41, 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f,
|
||||
0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, 0x66,
|
||||
0x72, 0x65, 0x73, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xe6, 0x06, 0x0a, 0x07, 0x53, 0x65,
|
||||
0x72, 0x65, 0x73, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xfd, 0x06, 0x0a, 0x07, 0x53, 0x65,
|
||||
0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e,
|
||||
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12,
|
||||
0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12,
|
||||
|
@ -473,30 +481,32 @@ var file_session_proto_rawDesc = []byte{
|
|||
0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65,
|
||||
0x12, 0x29, 0x0a, 0x10, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x64, 0x69, 0x73, 0x61,
|
||||
0x62, 0x6c, 0x65, 0x64, 0x18, 0x13, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x72, 0x65, 0x66, 0x72,
|
||||
0x65, 0x73, 0x68, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x39, 0x0a, 0x16, 0x69,
|
||||
0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69,
|
||||
0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x14, 0x69,
|
||||
0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f,
|
||||
0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x1a, 0x87, 0x01, 0x0a, 0x10, 0x44, 0x65, 0x76, 0x69, 0x63,
|
||||
0x65, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x12, 0x17, 0x0a, 0x07, 0x74,
|
||||
0x79, 0x70, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, 0x79,
|
||||
0x70, 0x65, 0x49, 0x64, 0x12, 0x3a, 0x0a, 0x0b, 0x75, 0x6e, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61,
|
||||
0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67,
|
||||
0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74,
|
||||
0x79, 0x48, 0x00, 0x52, 0x0b, 0x75, 0x6e, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65,
|
||||
0x12, 0x10, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x02,
|
||||
0x69, 0x64, 0x42, 0x0c, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c,
|
||||
0x1a, 0x55, 0x0a, 0x0b, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12,
|
||||
0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65,
|
||||
0x79, 0x12, 0x30, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b,
|
||||
0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62,
|
||||
0x75, 0x66, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x05, 0x76, 0x61,
|
||||
0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x19, 0x0a, 0x17, 0x5f, 0x69, 0x6d, 0x70, 0x65,
|
||||
0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f,
|
||||
0x69, 0x64, 0x42, 0x2f, 0x5a, 0x2d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d,
|
||||
0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
|
||||
0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x73, 0x65, 0x73, 0x73,
|
||||
0x69, 0x6f, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
0x65, 0x73, 0x68, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x15, 0x0a, 0x06, 0x69,
|
||||
0x64, 0x70, 0x5f, 0x69, 0x64, 0x18, 0x14, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x69, 0x64, 0x70,
|
||||
0x49, 0x64, 0x12, 0x39, 0x0a, 0x16, 0x69, 0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74,
|
||||
0x65, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20, 0x01,
|
||||
0x28, 0x09, 0x48, 0x00, 0x52, 0x14, 0x69, 0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74,
|
||||
0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x1a, 0x87, 0x01,
|
||||
0x0a, 0x10, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69,
|
||||
0x61, 0x6c, 0x12, 0x17, 0x0a, 0x07, 0x74, 0x79, 0x70, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x06, 0x74, 0x79, 0x70, 0x65, 0x49, 0x64, 0x12, 0x3a, 0x0a, 0x0b, 0x75,
|
||||
0x6e, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b,
|
||||
0x32, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62,
|
||||
0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x48, 0x00, 0x52, 0x0b, 0x75, 0x6e, 0x61, 0x76,
|
||||
0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x10, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x03, 0x20,
|
||||
0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x02, 0x69, 0x64, 0x42, 0x0c, 0x0a, 0x0a, 0x63, 0x72, 0x65,
|
||||
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x1a, 0x55, 0x0a, 0x0b, 0x43, 0x6c, 0x61, 0x69, 0x6d,
|
||||
0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75,
|
||||
0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65,
|
||||
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x56, 0x61,
|
||||
0x6c, 0x75, 0x65, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x19,
|
||||
0x0a, 0x17, 0x5f, 0x69, 0x6d, 0x70, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x61, 0x74, 0x65, 0x5f, 0x73,
|
||||
0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x42, 0x2f, 0x5a, 0x2d, 0x67, 0x69, 0x74,
|
||||
0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d,
|
||||
0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72,
|
||||
0x70, 0x63, 0x2f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
|
|
|
@ -43,6 +43,7 @@ message Session {
|
|||
map<string, google.protobuf.ListValue> claims = 9;
|
||||
repeated string audience = 10;
|
||||
bool refresh_disabled = 19;
|
||||
string idp_id = 20;
|
||||
|
||||
optional string impersonate_session_id = 15;
|
||||
}
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/events"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/identity"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -15,7 +19,6 @@ var (
|
|||
)
|
||||
|
||||
type config struct {
|
||||
authenticator Authenticator
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
sessionRefreshGracePeriod time.Duration
|
||||
sessionRefreshCoolOffDuration time.Duration
|
||||
|
@ -23,10 +26,14 @@ type config struct {
|
|||
leaseTTL time.Duration
|
||||
now func() time.Time
|
||||
eventMgr *events.Manager
|
||||
getAuthenticator func(ctx context.Context, idpID string) (identity.Authenticator, error)
|
||||
}
|
||||
|
||||
func newConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
WithGetAuthenticator(func(_ context.Context, _ string) (identity.Authenticator, error) {
|
||||
return nil, fmt.Errorf("authenticator not configured")
|
||||
})
|
||||
WithSessionRefreshGracePeriod(defaultSessionRefreshGracePeriod)(cfg)
|
||||
WithSessionRefreshCoolOffDuration(defaultSessionRefreshCoolOffDuration)(cfg)
|
||||
WithNow(time.Now)(cfg)
|
||||
|
@ -41,10 +48,27 @@ func newConfig(options ...Option) *config {
|
|||
// An Option customizes the configuration used for the identity manager.
|
||||
type Option func(*config)
|
||||
|
||||
// WithAuthenticator sets the authenticator in the config.
|
||||
func WithAuthenticator(authenticator Authenticator) Option {
|
||||
// WithCachedGetAuthenticator sets the get authenticator function in the config.
|
||||
func WithCachedGetAuthenticator(getAuthenticator func(ctx context.Context, idpID string) (identity.Authenticator, error)) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.authenticator = authenticator
|
||||
var mu sync.Mutex
|
||||
type state struct {
|
||||
authenticator identity.Authenticator
|
||||
err error
|
||||
}
|
||||
lookup := map[string]state{}
|
||||
cfg.getAuthenticator = func(ctx context.Context, idpID string) (identity.Authenticator, error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
s, ok := lookup[idpID]
|
||||
if !ok {
|
||||
s.authenticator, s.err = getAuthenticator(ctx, idpID)
|
||||
lookup[idpID] = s
|
||||
}
|
||||
|
||||
return s.authenticator, s.err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -55,6 +79,13 @@ func WithDataBrokerClient(dataBrokerClient databroker.DataBrokerServiceClient) O
|
|||
}
|
||||
}
|
||||
|
||||
// WithGetAuthenticator sets the get authenticator function in the config.
|
||||
func WithGetAuthenticator(getAuthenticator func(ctx context.Context, idpID string) (identity.Authenticator, error)) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.getAuthenticator = getAuthenticator
|
||||
}
|
||||
}
|
||||
|
||||
// WithSessionRefreshGracePeriod sets the session refresh grace period used by the manager.
|
||||
func WithSessionRefreshGracePeriod(dur time.Duration) Option {
|
||||
return func(cfg *config) {
|
||||
|
|
|
@ -94,7 +94,7 @@ func (mgr *Manager) RunLeased(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (mgr *Manager) onDeleteAllSessions(ctx context.Context) {
|
||||
log.Ctx(ctx).Debug().Msg("all session deleted")
|
||||
log.Ctx(ctx).Debug().Msg("all sessions deleted")
|
||||
|
||||
mgr.mu.Lock()
|
||||
mgr.dataStore.deleteAllSessions()
|
||||
|
@ -197,9 +197,10 @@ func (mgr *Manager) refreshSession(ctx context.Context, sessionID string) {
|
|||
return
|
||||
}
|
||||
|
||||
authenticator := mgr.cfg.Load().authenticator
|
||||
if authenticator == nil {
|
||||
authenticator, err := mgr.cfg.Load().getAuthenticator(ctx, s.GetIdpId())
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Info().
|
||||
Err(err).
|
||||
Str("user-id", s.GetUserId()).
|
||||
Str("session-id", s.GetId()).
|
||||
Msg("no authenticator defined, deleting session")
|
||||
|
@ -275,12 +276,6 @@ func (mgr *Manager) refreshSession(ctx context.Context, sessionID string) {
|
|||
|
||||
func (mgr *Manager) updateUserInfo(ctx context.Context, userID string) {
|
||||
log.Ctx(ctx).Info().Str("user-id", userID).Msg("updating user info")
|
||||
|
||||
authenticator := mgr.cfg.Load().authenticator
|
||||
if authenticator == nil {
|
||||
return
|
||||
}
|
||||
|
||||
mgr.mu.Lock()
|
||||
u, ss := mgr.dataStore.getUserAndSessions(userID)
|
||||
mgr.mu.Unlock()
|
||||
|
@ -301,7 +296,12 @@ func (mgr *Manager) updateUserInfo(ctx context.Context, userID string) {
|
|||
continue
|
||||
}
|
||||
|
||||
err := authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.GetOauthToken()), newUserUnmarshaler(u))
|
||||
authenticator, err := mgr.cfg.Load().getAuthenticator(ctx, s.GetIdpId())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.GetOauthToken()), newUserUnmarshaler(u))
|
||||
metrics.RecordIdentityManagerUserRefresh(ctx, err)
|
||||
mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err)
|
||||
if isTemporaryError(err) {
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/identity"
|
||||
|
@ -202,7 +203,7 @@ func (a *Auth) handleLogin(
|
|||
querier KeyboardInteractiveQuerier,
|
||||
) error {
|
||||
// Initiate the IdP login flow.
|
||||
authenticator, err := a.getAuthenticator(ctx, hostname)
|
||||
idp, authenticator, err := a.getAuthenticator(ctx, hostname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -228,7 +229,7 @@ func (a *Auth) handleLogin(
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return a.saveSession(ctx, sessionID, &sessionClaims, token)
|
||||
return a.saveSession(ctx, idp.Id, sessionID, &sessionClaims, token)
|
||||
}
|
||||
|
||||
var errAccessDenied = status.Error(codes.PermissionDenied, "access denied")
|
||||
|
@ -320,6 +321,7 @@ func (a *Auth) DeleteSession(ctx context.Context, info StreamAuthInfo) error {
|
|||
|
||||
func (a *Auth) saveSession(
|
||||
ctx context.Context,
|
||||
idpID,
|
||||
id string,
|
||||
claims *identity.SessionClaims,
|
||||
token *oauth2.Token,
|
||||
|
@ -333,15 +335,13 @@ func (a *Auth) saveSession(
|
|||
return err
|
||||
}
|
||||
|
||||
sess := &session.Session{
|
||||
Id: id,
|
||||
UserId: state.UserID(),
|
||||
IssuedAt: nowpb,
|
||||
AccessedAt: nowpb,
|
||||
ExpiresAt: timestamppb.New(now.Add(sessionLifetime)),
|
||||
OauthToken: manager.ToOAuthToken(token),
|
||||
Audience: state.Audience,
|
||||
}
|
||||
sess := session.New(idpID, id)
|
||||
sess.UserId = state.UserID()
|
||||
sess.IssuedAt = nowpb
|
||||
sess.AccessedAt = nowpb
|
||||
sess.ExpiresAt = timestamppb.New(now.Add(sessionLifetime))
|
||||
sess.OauthToken = manager.ToOAuthToken(token)
|
||||
sess.Audience = state.Audience
|
||||
sess.SetRawIDToken(claims.RawIDToken)
|
||||
sess.AddClaims(claims.Flatten())
|
||||
|
||||
|
@ -367,20 +367,25 @@ func (a *Auth) saveSession(
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) getAuthenticator(ctx context.Context, hostname string) (identity.Authenticator, error) {
|
||||
func (a *Auth) getAuthenticator(ctx context.Context, hostname string) (*identitypb.Provider, identity.Authenticator, error) {
|
||||
opts := a.currentConfig.Load().Options
|
||||
|
||||
redirectURL, err := opts.GetAuthenticateRedirectURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
idp, err := opts.GetIdentityProviderForPolicy(opts.GetRouteForSSHHostname(hostname))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return identity.GetIdentityProvider(ctx, a.tracerProvider, idp, redirectURL)
|
||||
authenticator, err := identity.GetIdentityProvider(ctx, a.tracerProvider, idp, redirectURL)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return idp, authenticator, nil
|
||||
}
|
||||
|
||||
var _ AuthInterface = (*Auth)(nil)
|
||||
|
|
|
@ -52,7 +52,7 @@ func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
|||
if err == nil {
|
||||
data.Session, data.IsImpersonated, err = p.getSession(r.Context(), ss.ID)
|
||||
if err != nil {
|
||||
data.Session = &session.Session{Id: ss.ID}
|
||||
data.Session = session.New(ss.IdentityProviderID, ss.ID)
|
||||
}
|
||||
|
||||
data.User, err = p.getUser(r.Context(), data.Session.GetUserId())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue