mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 02:16:28 +02:00
Optimize identity provider lookup with new cache
This commit is contained in:
parent
c8b6b8f1a9
commit
8df3028533
12 changed files with 179 additions and 77 deletions
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
@ -52,6 +53,14 @@ func New(ctx context.Context, cfg *config.Config, options ...Option) (*Authentic
|
|||
state: atomicutil.NewValue(newAuthenticateState()),
|
||||
}
|
||||
|
||||
if authenticateConfig.getIdentityProvider == nil {
|
||||
idpCache, err := config.NewIdentityProviderCache(cfg.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authenticateConfig.getIdentityProvider = authenticateflow.IdentityProviderLookupFromCache(idpCache)
|
||||
}
|
||||
|
||||
a.options.Store(cfg.Options)
|
||||
|
||||
state, err := newAuthenticateStateFromConfig(ctx, cfg, authenticateConfig)
|
||||
|
|
|
@ -18,7 +18,6 @@ type Option func(*authenticateConfig)
|
|||
|
||||
func getAuthenticateConfig(options ...Option) *authenticateConfig {
|
||||
cfg := new(authenticateConfig)
|
||||
WithGetIdentityProvider(defaultGetIdentityProvider)(cfg)
|
||||
for _, option := range options {
|
||||
option(cfg)
|
||||
}
|
||||
|
|
|
@ -407,7 +407,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
idp, _ := new(config.Options).GetIdentityProviderForID("")
|
||||
idp, _ := new(config.Options).GetIdentityProviderForPolicy(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
|
@ -53,6 +53,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
|||
return nil, err
|
||||
}
|
||||
a.state = atomicutil.NewValue(state)
|
||||
a.currentOptions.Store(cfg.Options) // FIXME: this is refactored out in a different branch
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/config/envoyconfig"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
@ -49,8 +48,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
hreq := getHTTPRequestFromCheckRequest(in)
|
||||
requestID := requestid.FromHTTPHeader(hreq.Header)
|
||||
ctx = requestid.WithValue(ctx, requestID)
|
||||
routeID := envoyconfig.ExtAuthzContextExtensionsRouteID(in.GetAttributes().GetContextExtensions())
|
||||
|
||||
sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq)
|
||||
sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq, routeID)
|
||||
|
||||
var s sessionOrServiceAccount
|
||||
var u *user.User
|
||||
|
@ -120,23 +120,10 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
|
|||
ID: sessionState.ID,
|
||||
}
|
||||
}
|
||||
req.Policy = a.getMatchingPolicy(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions()))
|
||||
req.Policy, _ = a.state.Load().idpCache.GetPolicyByID(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions()))
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy {
|
||||
options := a.currentOptions.Load()
|
||||
|
||||
for p := range options.GetAllPolicies() {
|
||||
id, _ := p.RouteID()
|
||||
if id == routeID {
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHTTPRequestFromCheckRequest(req *envoy_service_auth_v3.CheckRequest) *http.Request {
|
||||
hattrs := req.GetAttributes().GetRequest().GetHttp()
|
||||
u := getCheckRequestURL(req)
|
||||
|
|
|
@ -17,9 +17,10 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/config/envoyconfig"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
@ -49,15 +50,25 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
|
|||
-----END CERTIFICATE-----`
|
||||
|
||||
func Test_getEvaluatorRequest(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.currentOptions.Store(&config.Options{
|
||||
Policies: []config.Policy{{
|
||||
From: "https://example.com",
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
}},
|
||||
policies := []config.Policy{{
|
||||
From: "https://example.com",
|
||||
To: mustParseWeightedURLs(t, "https://foo.bar"),
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
}},
|
||||
}}
|
||||
|
||||
policy0RouteID, err := policies[0].RouteID()
|
||||
require.NoError(t, err)
|
||||
|
||||
a, err := New(context.Background(), &config.Config{
|
||||
Options: &config.Options{
|
||||
SharedKey: cryptutil.NewBase64Key(),
|
||||
CookieSecret: cryptutil.NewBase64Key(),
|
||||
Policies: policies,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
|
||||
&envoy_service_auth_v3.CheckRequest{
|
||||
|
@ -76,6 +87,7 @@ func Test_getEvaluatorRequest(t *testing.T) {
|
|||
Body: "BODY",
|
||||
},
|
||||
},
|
||||
ContextExtensions: envoyconfig.MakeExtAuthzContextExtensions(false, policy0RouteID),
|
||||
MetadataContext: &envoy_config_core_v3.Metadata{
|
||||
FilterMetadata: map[string]*structpb.Struct{
|
||||
"com.pomerium.client-certificate-info": {
|
||||
|
@ -94,7 +106,7 @@ func Test_getEvaluatorRequest(t *testing.T) {
|
|||
)
|
||||
require.NoError(t, err)
|
||||
expect := &evaluator.Request{
|
||||
Policy: &a.currentOptions.Load().Policies[0],
|
||||
Policy: &policies[0],
|
||||
Session: evaluator.RequestSession{
|
||||
ID: "SESSION_ID",
|
||||
},
|
||||
|
@ -117,16 +129,24 @@ func Test_getEvaluatorRequest(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.currentOptions.Store(&config.Options{
|
||||
Policies: []config.Policy{{
|
||||
From: "https://example.com",
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
}},
|
||||
policies := []config.Policy{{
|
||||
From: "https://example.com",
|
||||
To: mustParseWeightedURLs(t, "https://foo.bar"),
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
}},
|
||||
})
|
||||
}}
|
||||
policy0RouteID, err := policies[0].RouteID()
|
||||
require.NoError(t, err)
|
||||
|
||||
a, err := New(context.Background(), &config.Config{
|
||||
Options: &config.Options{
|
||||
SharedKey: cryptutil.NewBase64Key(),
|
||||
CookieSecret: cryptutil.NewBase64Key(),
|
||||
Policies: policies,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
|
||||
&envoy_service_auth_v3.CheckRequest{
|
||||
Attributes: &envoy_service_auth_v3.AttributeContext{
|
||||
|
@ -144,11 +164,12 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
|||
Body: "BODY",
|
||||
},
|
||||
},
|
||||
ContextExtensions: envoyconfig.MakeExtAuthzContextExtensions(false, policy0RouteID),
|
||||
},
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
expect := &evaluator.Request{
|
||||
Policy: &a.currentOptions.Load().Policies[0],
|
||||
Policy: &policies[0],
|
||||
Session: evaluator.RequestSession{},
|
||||
HTTP: evaluator.NewRequestHTTP(
|
||||
http.MethodGet,
|
||||
|
|
|
@ -29,6 +29,7 @@ type authorizeState struct {
|
|||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
auditEncryptor *protoutil.Encryptor
|
||||
sessionStore *config.SessionStore
|
||||
idpCache *config.IdentityProviderCache
|
||||
authenticateFlow authenticateFlow
|
||||
}
|
||||
|
||||
|
@ -79,13 +80,20 @@ func newAuthorizeStateFromConfig(
|
|||
state.auditEncryptor = protoutil.NewEncryptor(auditKey)
|
||||
}
|
||||
|
||||
state.sessionStore, err = config.NewSessionStore(cfg.Options)
|
||||
idpCache, err := config.NewIdentityProviderCache(cfg.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state.idpCache = idpCache
|
||||
|
||||
state.sessionStore, err = config.NewSessionStore(cfg.Options, state.idpCache)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: invalid session store: %w", err)
|
||||
}
|
||||
|
||||
if cfg.Options.UseStatelessAuthenticateFlow() {
|
||||
state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, nil, nil, nil, nil)
|
||||
state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, nil,
|
||||
authenticateflow.IdentityProviderLookupFromCache(idpCache), nil, nil)
|
||||
} else {
|
||||
state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, nil)
|
||||
}
|
||||
|
|
|
@ -1,26 +1,12 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
)
|
||||
|
||||
// GetIdentityProviderForID returns the identity provider associated with the given IDP id.
|
||||
// If none is found the default provider is returned.
|
||||
func (o *Options) GetIdentityProviderForID(idpID string) (*identity.Provider, error) {
|
||||
for p := range o.GetAllPolicies() {
|
||||
idp, err := o.GetIdentityProviderForPolicy(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if idp.GetId() == idpID {
|
||||
return idp, nil
|
||||
}
|
||||
}
|
||||
|
||||
return o.GetIdentityProviderForPolicy(nil)
|
||||
}
|
||||
|
||||
// GetIdentityProviderForPolicy gets the identity provider associated with the given policy.
|
||||
// If policy is nil, or changes none of the default settings, the default provider is returned.
|
||||
func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provider, error) {
|
||||
|
@ -69,3 +55,71 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity
|
|||
}
|
||||
return o.GetIdentityProviderForPolicy(nil)
|
||||
}
|
||||
|
||||
type IdentityProviderCache struct {
|
||||
idpsByRouteID map[uint64]*identity.Provider
|
||||
policiesByRouteID map[uint64]*Policy
|
||||
idpsByID map[string]*identity.Provider
|
||||
}
|
||||
|
||||
func NewIdentityProviderCache(opts *Options) (*IdentityProviderCache, error) {
|
||||
rt := &IdentityProviderCache{
|
||||
idpsByRouteID: make(map[uint64]*identity.Provider, opts.NumPolicies()),
|
||||
policiesByRouteID: make(map[uint64]*Policy, opts.NumPolicies()),
|
||||
idpsByID: make(map[string]*identity.Provider),
|
||||
}
|
||||
|
||||
for policy := range opts.GetAllPolicies() {
|
||||
id, err := policy.RouteID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idp, err := opts.GetIdentityProviderForPolicy(policy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rt.idpsByRouteID[id] = idp
|
||||
rt.policiesByRouteID[id] = policy
|
||||
|
||||
if _, ok := rt.idpsByID[idp.Id]; !ok {
|
||||
rt.idpsByID[idp.Id] = idp
|
||||
}
|
||||
}
|
||||
return rt, nil
|
||||
}
|
||||
|
||||
func (rt *IdentityProviderCache) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provider, error) {
|
||||
routeID, err := policy.RouteID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idp, ok := rt.idpsByRouteID[routeID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no identity provider found for route %d", routeID)
|
||||
}
|
||||
return idp, nil
|
||||
}
|
||||
|
||||
func (rt *IdentityProviderCache) GetIdentityProviderForRouteID(routeID uint64) (*identity.Provider, error) {
|
||||
idp, ok := rt.idpsByRouteID[routeID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no identity provider found for route %d", routeID)
|
||||
}
|
||||
return idp, nil
|
||||
}
|
||||
|
||||
func (rt *IdentityProviderCache) GetIdentityProviderByID(idpID string) (*identity.Provider, error) {
|
||||
idp, ok := rt.idpsByID[idpID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no identity provider found for id %s", idpID)
|
||||
}
|
||||
return idp, nil
|
||||
}
|
||||
|
||||
func (rt *IdentityProviderCache) GetPolicyByID(routeID uint64) (*Policy, error) {
|
||||
policy, ok := rt.policiesByRouteID[routeID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no policy found for route %d", routeID)
|
||||
}
|
||||
return policy, nil
|
||||
}
|
||||
|
|
|
@ -15,18 +15,20 @@ import (
|
|||
|
||||
// A SessionStore saves and loads sessions based on the options.
|
||||
type SessionStore struct {
|
||||
store sessions.SessionStore
|
||||
loader sessions.SessionLoader
|
||||
options *Options
|
||||
encoder encoding.MarshalUnmarshaler
|
||||
store sessions.SessionStore
|
||||
loader sessions.SessionLoader
|
||||
options *Options
|
||||
encoder encoding.MarshalUnmarshaler
|
||||
idpCache *IdentityProviderCache
|
||||
}
|
||||
|
||||
var _ sessions.SessionStore = (*SessionStore)(nil)
|
||||
|
||||
// NewSessionStore creates a new SessionStore from the Options.
|
||||
func NewSessionStore(options *Options) (*SessionStore, error) {
|
||||
func NewSessionStore(options *Options, idpCache *IdentityProviderCache) (*SessionStore, error) {
|
||||
store := &SessionStore{
|
||||
options: options,
|
||||
options: options,
|
||||
idpCache: idpCache,
|
||||
}
|
||||
|
||||
sharedKey, err := options.GetSharedKey()
|
||||
|
@ -86,7 +88,7 @@ func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, e
|
|||
}
|
||||
|
||||
// LoadSessionStateAndCheckIDP loads the session state from a request and checks that the idp id matches.
|
||||
func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request) (*sessions.State, error) {
|
||||
func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request, routeID uint64) (*sessions.State, error) {
|
||||
state, err := store.LoadSessionState(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -94,7 +96,7 @@ func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request) (*sessio
|
|||
|
||||
// confirm that the identity provider id matches the state
|
||||
if state.IdentityProviderID != "" {
|
||||
idp, err := store.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String())
|
||||
idp, err := store.idpCache.GetIdentityProviderForRouteID(routeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -42,7 +42,10 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, options.Validate())
|
||||
|
||||
store, err := NewSessionStore(options)
|
||||
idpCache, err := NewIdentityProviderCache(options)
|
||||
require.NoError(t, err)
|
||||
|
||||
store, err := NewSessionStore(options, idpCache)
|
||||
require.NoError(t, err)
|
||||
|
||||
idp1, err := options.GetIdentityProviderForPolicy(nil)
|
||||
|
@ -57,6 +60,11 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.NotNil(t, idp3)
|
||||
|
||||
policy0Id, err := options.Policies[0].RouteID()
|
||||
require.NoError(t, err)
|
||||
policy1Id, err := options.Policies[1].RouteID()
|
||||
require.NoError(t, err)
|
||||
|
||||
makeJWS := func(t *testing.T, state *sessions.State) string {
|
||||
e, err := jws.NewHS256Signer(sharedKey)
|
||||
require.NoError(t, err)
|
||||
|
@ -70,7 +78,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
|
|||
t.Run("mssing", func(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodGet, "https://p1.example.com", nil)
|
||||
require.NoError(t, err)
|
||||
s, err := store.LoadSessionStateAndCheckIDP(r)
|
||||
s, err := store.LoadSessionStateAndCheckIDP(r, 0)
|
||||
assert.ErrorIs(t, err, sessions.ErrNoSessionFound)
|
||||
assert.Nil(t, s)
|
||||
})
|
||||
|
@ -85,7 +93,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
|
|||
urlutil.QuerySession: {rawJWS},
|
||||
}.Encode(), nil)
|
||||
require.NoError(t, err)
|
||||
s, err := store.LoadSessionStateAndCheckIDP(r)
|
||||
s, err := store.LoadSessionStateAndCheckIDP(r, policy0Id)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, cmp.Diff(&sessions.State{
|
||||
Issuer: "authenticate.example.com",
|
||||
|
@ -103,7 +111,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
|
|||
r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil)
|
||||
require.NoError(t, err)
|
||||
r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS)
|
||||
s, err := store.LoadSessionStateAndCheckIDP(r)
|
||||
s, err := store.LoadSessionStateAndCheckIDP(r, policy1Id)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, cmp.Diff(&sessions.State{
|
||||
Issuer: "authenticate.example.com",
|
||||
|
@ -121,7 +129,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
|
|||
r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil)
|
||||
require.NoError(t, err)
|
||||
r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS)
|
||||
s, err := store.LoadSessionStateAndCheckIDP(r)
|
||||
s, err := store.LoadSessionStateAndCheckIDP(r, policy1Id)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, s)
|
||||
})
|
||||
|
@ -134,7 +142,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
|
|||
r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil)
|
||||
require.NoError(t, err)
|
||||
r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS)
|
||||
s, err := store.LoadSessionStateAndCheckIDP(r)
|
||||
s, err := store.LoadSessionStateAndCheckIDP(r, policy1Id)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, cmp.Diff(&sessions.State{
|
||||
Issuer: "authenticate.example.com",
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
package authenticate
|
||||
package authenticateflow
|
||||
|
||||
import (
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
"github.com/pomerium/pomerium/pkg/identity"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||
)
|
||||
|
||||
func defaultGetIdentityProvider(options *config.Options, idpID string) (identity.Authenticator, error) {
|
||||
func NewAuthenticator(options *config.Options, idp *identitypb.Provider) (identity.Authenticator, error) {
|
||||
authenticateURL, err := options.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -19,10 +20,6 @@ func defaultGetIdentityProvider(options *config.Options, idpID string) (identity
|
|||
}
|
||||
redirectURL.Path = options.AuthenticateCallbackPath
|
||||
|
||||
idp, err := options.GetIdentityProviderForID(idpID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return identity.NewAuthenticator(oauth.Options{
|
||||
RedirectURL: redirectURL,
|
||||
ProviderName: idp.GetType(),
|
||||
|
@ -33,3 +30,13 @@ func defaultGetIdentityProvider(options *config.Options, idpID string) (identity
|
|||
AuthCodeOptions: idp.GetRequestParams(),
|
||||
})
|
||||
}
|
||||
|
||||
func IdentityProviderLookupFromCache(idpCache *config.IdentityProviderCache) func(*config.Options, string) (identity.Authenticator, error) {
|
||||
return func(options *config.Options, idpID string) (identity.Authenticator, error) {
|
||||
idp, err := idpCache.GetIdentityProviderByID(idpID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewAuthenticator(options, idp)
|
||||
}
|
||||
}
|
|
@ -26,6 +26,7 @@ type proxyState struct {
|
|||
|
||||
sharedKey []byte
|
||||
sessionStore *config.SessionStore
|
||||
idpCache *config.IdentityProviderCache
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
programmaticRedirectDomainWhitelist []string
|
||||
authenticateFlow authenticateFlow
|
||||
|
@ -47,12 +48,17 @@ func newProxyStateFromConfig(ctx context.Context, cfg *config.Config) (*proxySta
|
|||
state.authenticateSigninURL = state.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
|
||||
state.authenticateRefreshURL = state.authenticateURL.ResolveReference(&url.URL{Path: refreshURL})
|
||||
|
||||
state.idpCache, err = config.NewIdentityProviderCache(cfg.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.sharedKey, err = cfg.Options.GetSharedKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.sessionStore, err = config.NewSessionStore(cfg.Options)
|
||||
state.sessionStore, err = config.NewSessionStore(cfg.Options, state.idpCache)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -71,8 +77,8 @@ func newProxyStateFromConfig(ctx context.Context, cfg *config.Config) (*proxySta
|
|||
state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist
|
||||
|
||||
if cfg.Options.UseStatelessAuthenticateFlow() {
|
||||
state.authenticateFlow, err = authenticateflow.NewStateless(ctx,
|
||||
cfg, state.sessionStore, nil, nil, nil)
|
||||
state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, state.sessionStore,
|
||||
authenticateflow.IdentityProviderLookupFromCache(state.idpCache), nil, nil)
|
||||
} else {
|
||||
state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, state.sessionStore)
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue