Optimize identity provider lookup with new cache

This commit is contained in:
Joe Kralicky 2024-11-04 15:20:31 -05:00
parent c8b6b8f1a9
commit 8df3028533
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
12 changed files with 179 additions and 77 deletions

View file

@ -9,6 +9,7 @@ import (
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/authenticateflow"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
) )
@ -52,6 +53,14 @@ func New(ctx context.Context, cfg *config.Config, options ...Option) (*Authentic
state: atomicutil.NewValue(newAuthenticateState()), state: atomicutil.NewValue(newAuthenticateState()),
} }
if authenticateConfig.getIdentityProvider == nil {
idpCache, err := config.NewIdentityProviderCache(cfg.Options)
if err != nil {
return nil, err
}
authenticateConfig.getIdentityProvider = authenticateflow.IdentityProviderLookupFromCache(idpCache)
}
a.options.Store(cfg.Options) a.options.Store(cfg.Options)
state, err := newAuthenticateStateFromConfig(ctx, cfg, authenticateConfig) state, err := newAuthenticateStateFromConfig(ctx, cfg, authenticateConfig)

View file

@ -18,7 +18,6 @@ type Option func(*authenticateConfig)
func getAuthenticateConfig(options ...Option) *authenticateConfig { func getAuthenticateConfig(options ...Option) *authenticateConfig {
cfg := new(authenticateConfig) cfg := new(authenticateConfig)
WithGetIdentityProvider(defaultGetIdentityProvider)(cfg)
for _, option := range options { for _, option := range options {
option(cfg) option(cfg)
} }

View file

@ -407,7 +407,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })
idp, _ := new(config.Options).GetIdentityProviderForID("") idp, _ := new(config.Options).GetIdentityProviderForPolicy(nil)
tests := []struct { tests := []struct {
name string name string

View file

@ -53,6 +53,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
return nil, err return nil, err
} }
a.state = atomicutil.NewValue(state) a.state = atomicutil.NewValue(state)
a.currentOptions.Store(cfg.Options) // FIXME: this is refactored out in a different branch
return a, nil return a, nil
} }

View file

@ -14,7 +14,6 @@ import (
"google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/config/envoyconfig" "github.com/pomerium/pomerium/config/envoyconfig"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
@ -49,8 +48,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
hreq := getHTTPRequestFromCheckRequest(in) hreq := getHTTPRequestFromCheckRequest(in)
requestID := requestid.FromHTTPHeader(hreq.Header) requestID := requestid.FromHTTPHeader(hreq.Header)
ctx = requestid.WithValue(ctx, requestID) ctx = requestid.WithValue(ctx, requestID)
routeID := envoyconfig.ExtAuthzContextExtensionsRouteID(in.GetAttributes().GetContextExtensions())
sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq) sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq, routeID)
var s sessionOrServiceAccount var s sessionOrServiceAccount
var u *user.User var u *user.User
@ -120,23 +120,10 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
ID: sessionState.ID, ID: sessionState.ID,
} }
} }
req.Policy = a.getMatchingPolicy(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions())) req.Policy, _ = a.state.Load().idpCache.GetPolicyByID(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions()))
return req, nil return req, nil
} }
func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy {
options := a.currentOptions.Load()
for p := range options.GetAllPolicies() {
id, _ := p.RouteID()
if id == routeID {
return p
}
}
return nil
}
func getHTTPRequestFromCheckRequest(req *envoy_service_auth_v3.CheckRequest) *http.Request { func getHTTPRequestFromCheckRequest(req *envoy_service_auth_v3.CheckRequest) *http.Request {
hattrs := req.GetAttributes().GetRequest().GetHttp() hattrs := req.GetAttributes().GetRequest().GetHttp()
u := getCheckRequestURL(req) u := getCheckRequestURL(req)

View file

@ -17,9 +17,10 @@ import (
"github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/config/envoyconfig"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
) )
@ -49,15 +50,25 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
-----END CERTIFICATE-----` -----END CERTIFICATE-----`
func Test_getEvaluatorRequest(t *testing.T) { func Test_getEvaluatorRequest(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} policies := []config.Policy{{
a.currentOptions.Store(&config.Options{ From: "https://example.com",
Policies: []config.Policy{{ To: mustParseWeightedURLs(t, "https://foo.bar"),
From: "https://example.com", SubPolicies: []config.SubPolicy{{
SubPolicies: []config.SubPolicy{{ Rego: []string{"allow = true"},
Rego: []string{"allow = true"},
}},
}}, }},
}}
policy0RouteID, err := policies[0].RouteID()
require.NoError(t, err)
a, err := New(context.Background(), &config.Config{
Options: &config.Options{
SharedKey: cryptutil.NewBase64Key(),
CookieSecret: cryptutil.NewBase64Key(),
Policies: policies,
},
}) })
require.NoError(t, err)
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(), actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
&envoy_service_auth_v3.CheckRequest{ &envoy_service_auth_v3.CheckRequest{
@ -76,6 +87,7 @@ func Test_getEvaluatorRequest(t *testing.T) {
Body: "BODY", Body: "BODY",
}, },
}, },
ContextExtensions: envoyconfig.MakeExtAuthzContextExtensions(false, policy0RouteID),
MetadataContext: &envoy_config_core_v3.Metadata{ MetadataContext: &envoy_config_core_v3.Metadata{
FilterMetadata: map[string]*structpb.Struct{ FilterMetadata: map[string]*structpb.Struct{
"com.pomerium.client-certificate-info": { "com.pomerium.client-certificate-info": {
@ -94,7 +106,7 @@ func Test_getEvaluatorRequest(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
expect := &evaluator.Request{ expect := &evaluator.Request{
Policy: &a.currentOptions.Load().Policies[0], Policy: &policies[0],
Session: evaluator.RequestSession{ Session: evaluator.RequestSession{
ID: "SESSION_ID", ID: "SESSION_ID",
}, },
@ -117,16 +129,24 @@ func Test_getEvaluatorRequest(t *testing.T) {
} }
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) { func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} policies := []config.Policy{{
a.currentOptions.Store(&config.Options{ From: "https://example.com",
Policies: []config.Policy{{ To: mustParseWeightedURLs(t, "https://foo.bar"),
From: "https://example.com", SubPolicies: []config.SubPolicy{{
SubPolicies: []config.SubPolicy{{ Rego: []string{"allow = true"},
Rego: []string{"allow = true"},
}},
}}, }},
}) }}
policy0RouteID, err := policies[0].RouteID()
require.NoError(t, err)
a, err := New(context.Background(), &config.Config{
Options: &config.Options{
SharedKey: cryptutil.NewBase64Key(),
CookieSecret: cryptutil.NewBase64Key(),
Policies: policies,
},
})
require.NoError(t, err)
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(), actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
&envoy_service_auth_v3.CheckRequest{ &envoy_service_auth_v3.CheckRequest{
Attributes: &envoy_service_auth_v3.AttributeContext{ Attributes: &envoy_service_auth_v3.AttributeContext{
@ -144,11 +164,12 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
Body: "BODY", Body: "BODY",
}, },
}, },
ContextExtensions: envoyconfig.MakeExtAuthzContextExtensions(false, policy0RouteID),
}, },
}, nil) }, nil)
require.NoError(t, err) require.NoError(t, err)
expect := &evaluator.Request{ expect := &evaluator.Request{
Policy: &a.currentOptions.Load().Policies[0], Policy: &policies[0],
Session: evaluator.RequestSession{}, Session: evaluator.RequestSession{},
HTTP: evaluator.NewRequestHTTP( HTTP: evaluator.NewRequestHTTP(
http.MethodGet, http.MethodGet,

View file

@ -29,6 +29,7 @@ type authorizeState struct {
dataBrokerClient databroker.DataBrokerServiceClient dataBrokerClient databroker.DataBrokerServiceClient
auditEncryptor *protoutil.Encryptor auditEncryptor *protoutil.Encryptor
sessionStore *config.SessionStore sessionStore *config.SessionStore
idpCache *config.IdentityProviderCache
authenticateFlow authenticateFlow authenticateFlow authenticateFlow
} }
@ -79,13 +80,20 @@ func newAuthorizeStateFromConfig(
state.auditEncryptor = protoutil.NewEncryptor(auditKey) state.auditEncryptor = protoutil.NewEncryptor(auditKey)
} }
state.sessionStore, err = config.NewSessionStore(cfg.Options) idpCache, err := config.NewIdentityProviderCache(cfg.Options)
if err != nil {
return nil, err
}
state.idpCache = idpCache
state.sessionStore, err = config.NewSessionStore(cfg.Options, state.idpCache)
if err != nil { if err != nil {
return nil, fmt.Errorf("authorize: invalid session store: %w", err) return nil, fmt.Errorf("authorize: invalid session store: %w", err)
} }
if cfg.Options.UseStatelessAuthenticateFlow() { if cfg.Options.UseStatelessAuthenticateFlow() {
state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, nil, nil, nil, nil) state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, nil,
authenticateflow.IdentityProviderLookupFromCache(idpCache), nil, nil)
} else { } else {
state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, nil) state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, nil)
} }

View file

@ -1,26 +1,12 @@
package config package config
import ( import (
"fmt"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/grpc/identity" "github.com/pomerium/pomerium/pkg/grpc/identity"
) )
// GetIdentityProviderForID returns the identity provider associated with the given IDP id.
// If none is found the default provider is returned.
func (o *Options) GetIdentityProviderForID(idpID string) (*identity.Provider, error) {
for p := range o.GetAllPolicies() {
idp, err := o.GetIdentityProviderForPolicy(p)
if err != nil {
return nil, err
}
if idp.GetId() == idpID {
return idp, nil
}
}
return o.GetIdentityProviderForPolicy(nil)
}
// GetIdentityProviderForPolicy gets the identity provider associated with the given policy. // GetIdentityProviderForPolicy gets the identity provider associated with the given policy.
// If policy is nil, or changes none of the default settings, the default provider is returned. // If policy is nil, or changes none of the default settings, the default provider is returned.
func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provider, error) { func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provider, error) {
@ -69,3 +55,71 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity
} }
return o.GetIdentityProviderForPolicy(nil) return o.GetIdentityProviderForPolicy(nil)
} }
type IdentityProviderCache struct {
idpsByRouteID map[uint64]*identity.Provider
policiesByRouteID map[uint64]*Policy
idpsByID map[string]*identity.Provider
}
func NewIdentityProviderCache(opts *Options) (*IdentityProviderCache, error) {
rt := &IdentityProviderCache{
idpsByRouteID: make(map[uint64]*identity.Provider, opts.NumPolicies()),
policiesByRouteID: make(map[uint64]*Policy, opts.NumPolicies()),
idpsByID: make(map[string]*identity.Provider),
}
for policy := range opts.GetAllPolicies() {
id, err := policy.RouteID()
if err != nil {
return nil, err
}
idp, err := opts.GetIdentityProviderForPolicy(policy)
if err != nil {
return nil, err
}
rt.idpsByRouteID[id] = idp
rt.policiesByRouteID[id] = policy
if _, ok := rt.idpsByID[idp.Id]; !ok {
rt.idpsByID[idp.Id] = idp
}
}
return rt, nil
}
func (rt *IdentityProviderCache) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provider, error) {
routeID, err := policy.RouteID()
if err != nil {
return nil, err
}
idp, ok := rt.idpsByRouteID[routeID]
if !ok {
return nil, fmt.Errorf("no identity provider found for route %d", routeID)
}
return idp, nil
}
func (rt *IdentityProviderCache) GetIdentityProviderForRouteID(routeID uint64) (*identity.Provider, error) {
idp, ok := rt.idpsByRouteID[routeID]
if !ok {
return nil, fmt.Errorf("no identity provider found for route %d", routeID)
}
return idp, nil
}
func (rt *IdentityProviderCache) GetIdentityProviderByID(idpID string) (*identity.Provider, error) {
idp, ok := rt.idpsByID[idpID]
if !ok {
return nil, fmt.Errorf("no identity provider found for id %s", idpID)
}
return idp, nil
}
func (rt *IdentityProviderCache) GetPolicyByID(routeID uint64) (*Policy, error) {
policy, ok := rt.policiesByRouteID[routeID]
if !ok {
return nil, fmt.Errorf("no policy found for route %d", routeID)
}
return policy, nil
}

View file

@ -15,18 +15,20 @@ import (
// A SessionStore saves and loads sessions based on the options. // A SessionStore saves and loads sessions based on the options.
type SessionStore struct { type SessionStore struct {
store sessions.SessionStore store sessions.SessionStore
loader sessions.SessionLoader loader sessions.SessionLoader
options *Options options *Options
encoder encoding.MarshalUnmarshaler encoder encoding.MarshalUnmarshaler
idpCache *IdentityProviderCache
} }
var _ sessions.SessionStore = (*SessionStore)(nil) var _ sessions.SessionStore = (*SessionStore)(nil)
// NewSessionStore creates a new SessionStore from the Options. // NewSessionStore creates a new SessionStore from the Options.
func NewSessionStore(options *Options) (*SessionStore, error) { func NewSessionStore(options *Options, idpCache *IdentityProviderCache) (*SessionStore, error) {
store := &SessionStore{ store := &SessionStore{
options: options, options: options,
idpCache: idpCache,
} }
sharedKey, err := options.GetSharedKey() sharedKey, err := options.GetSharedKey()
@ -86,7 +88,7 @@ func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, e
} }
// LoadSessionStateAndCheckIDP loads the session state from a request and checks that the idp id matches. // LoadSessionStateAndCheckIDP loads the session state from a request and checks that the idp id matches.
func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request) (*sessions.State, error) { func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request, routeID uint64) (*sessions.State, error) {
state, err := store.LoadSessionState(r) state, err := store.LoadSessionState(r)
if err != nil { if err != nil {
return nil, err return nil, err
@ -94,7 +96,7 @@ func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request) (*sessio
// confirm that the identity provider id matches the state // confirm that the identity provider id matches the state
if state.IdentityProviderID != "" { if state.IdentityProviderID != "" {
idp, err := store.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String()) idp, err := store.idpCache.GetIdentityProviderForRouteID(routeID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -42,7 +42,10 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
}) })
require.NoError(t, options.Validate()) require.NoError(t, options.Validate())
store, err := NewSessionStore(options) idpCache, err := NewIdentityProviderCache(options)
require.NoError(t, err)
store, err := NewSessionStore(options, idpCache)
require.NoError(t, err) require.NoError(t, err)
idp1, err := options.GetIdentityProviderForPolicy(nil) idp1, err := options.GetIdentityProviderForPolicy(nil)
@ -57,6 +60,11 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, idp3) require.NotNil(t, idp3)
policy0Id, err := options.Policies[0].RouteID()
require.NoError(t, err)
policy1Id, err := options.Policies[1].RouteID()
require.NoError(t, err)
makeJWS := func(t *testing.T, state *sessions.State) string { makeJWS := func(t *testing.T, state *sessions.State) string {
e, err := jws.NewHS256Signer(sharedKey) e, err := jws.NewHS256Signer(sharedKey)
require.NoError(t, err) require.NoError(t, err)
@ -70,7 +78,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
t.Run("mssing", func(t *testing.T) { t.Run("mssing", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://p1.example.com", nil) r, err := http.NewRequest(http.MethodGet, "https://p1.example.com", nil)
require.NoError(t, err) require.NoError(t, err)
s, err := store.LoadSessionStateAndCheckIDP(r) s, err := store.LoadSessionStateAndCheckIDP(r, 0)
assert.ErrorIs(t, err, sessions.ErrNoSessionFound) assert.ErrorIs(t, err, sessions.ErrNoSessionFound)
assert.Nil(t, s) assert.Nil(t, s)
}) })
@ -85,7 +93,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
urlutil.QuerySession: {rawJWS}, urlutil.QuerySession: {rawJWS},
}.Encode(), nil) }.Encode(), nil)
require.NoError(t, err) require.NoError(t, err)
s, err := store.LoadSessionStateAndCheckIDP(r) s, err := store.LoadSessionStateAndCheckIDP(r, policy0Id)
assert.NoError(t, err) assert.NoError(t, err)
assert.Empty(t, cmp.Diff(&sessions.State{ assert.Empty(t, cmp.Diff(&sessions.State{
Issuer: "authenticate.example.com", Issuer: "authenticate.example.com",
@ -103,7 +111,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil)
require.NoError(t, err) require.NoError(t, err)
r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS)
s, err := store.LoadSessionStateAndCheckIDP(r) s, err := store.LoadSessionStateAndCheckIDP(r, policy1Id)
assert.NoError(t, err) assert.NoError(t, err)
assert.Empty(t, cmp.Diff(&sessions.State{ assert.Empty(t, cmp.Diff(&sessions.State{
Issuer: "authenticate.example.com", Issuer: "authenticate.example.com",
@ -121,7 +129,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil)
require.NoError(t, err) require.NoError(t, err)
r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS)
s, err := store.LoadSessionStateAndCheckIDP(r) s, err := store.LoadSessionStateAndCheckIDP(r, policy1Id)
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, s) assert.Nil(t, s)
}) })
@ -134,7 +142,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil)
require.NoError(t, err) require.NoError(t, err)
r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS)
s, err := store.LoadSessionStateAndCheckIDP(r) s, err := store.LoadSessionStateAndCheckIDP(r, policy1Id)
assert.NoError(t, err) assert.NoError(t, err)
assert.Empty(t, cmp.Diff(&sessions.State{ assert.Empty(t, cmp.Diff(&sessions.State{
Issuer: "authenticate.example.com", Issuer: "authenticate.example.com",

View file

@ -1,13 +1,14 @@
package authenticate package authenticateflow
import ( import (
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/identity" "github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/identity/oauth" "github.com/pomerium/pomerium/pkg/identity/oauth"
) )
func defaultGetIdentityProvider(options *config.Options, idpID string) (identity.Authenticator, error) { func NewAuthenticator(options *config.Options, idp *identitypb.Provider) (identity.Authenticator, error) {
authenticateURL, err := options.GetAuthenticateURL() authenticateURL, err := options.GetAuthenticateURL()
if err != nil { if err != nil {
return nil, err return nil, err
@ -19,10 +20,6 @@ func defaultGetIdentityProvider(options *config.Options, idpID string) (identity
} }
redirectURL.Path = options.AuthenticateCallbackPath redirectURL.Path = options.AuthenticateCallbackPath
idp, err := options.GetIdentityProviderForID(idpID)
if err != nil {
return nil, err
}
return identity.NewAuthenticator(oauth.Options{ return identity.NewAuthenticator(oauth.Options{
RedirectURL: redirectURL, RedirectURL: redirectURL,
ProviderName: idp.GetType(), ProviderName: idp.GetType(),
@ -33,3 +30,13 @@ func defaultGetIdentityProvider(options *config.Options, idpID string) (identity
AuthCodeOptions: idp.GetRequestParams(), AuthCodeOptions: idp.GetRequestParams(),
}) })
} }
func IdentityProviderLookupFromCache(idpCache *config.IdentityProviderCache) func(*config.Options, string) (identity.Authenticator, error) {
return func(options *config.Options, idpID string) (identity.Authenticator, error) {
idp, err := idpCache.GetIdentityProviderByID(idpID)
if err != nil {
return nil, err
}
return NewAuthenticator(options, idp)
}
}

View file

@ -26,6 +26,7 @@ type proxyState struct {
sharedKey []byte sharedKey []byte
sessionStore *config.SessionStore sessionStore *config.SessionStore
idpCache *config.IdentityProviderCache
dataBrokerClient databroker.DataBrokerServiceClient dataBrokerClient databroker.DataBrokerServiceClient
programmaticRedirectDomainWhitelist []string programmaticRedirectDomainWhitelist []string
authenticateFlow authenticateFlow authenticateFlow authenticateFlow
@ -47,12 +48,17 @@ func newProxyStateFromConfig(ctx context.Context, cfg *config.Config) (*proxySta
state.authenticateSigninURL = state.authenticateURL.ResolveReference(&url.URL{Path: signinURL}) state.authenticateSigninURL = state.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
state.authenticateRefreshURL = state.authenticateURL.ResolveReference(&url.URL{Path: refreshURL}) state.authenticateRefreshURL = state.authenticateURL.ResolveReference(&url.URL{Path: refreshURL})
state.idpCache, err = config.NewIdentityProviderCache(cfg.Options)
if err != nil {
return nil, err
}
state.sharedKey, err = cfg.Options.GetSharedKey() state.sharedKey, err = cfg.Options.GetSharedKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
state.sessionStore, err = config.NewSessionStore(cfg.Options) state.sessionStore, err = config.NewSessionStore(cfg.Options, state.idpCache)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -71,8 +77,8 @@ func newProxyStateFromConfig(ctx context.Context, cfg *config.Config) (*proxySta
state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist
if cfg.Options.UseStatelessAuthenticateFlow() { if cfg.Options.UseStatelessAuthenticateFlow() {
state.authenticateFlow, err = authenticateflow.NewStateless(ctx, state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, state.sessionStore,
cfg, state.sessionStore, nil, nil, nil) authenticateflow.IdentityProviderLookupFromCache(state.idpCache), nil, nil)
} else { } else {
state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, state.sessionStore) state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, state.sessionStore)
} }