diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 79c6fdf01..bd0370c84 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -9,6 +9,7 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/atomicutil" + "github.com/pomerium/pomerium/internal/authenticateflow" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/pkg/cryptutil" ) @@ -52,6 +53,14 @@ func New(ctx context.Context, cfg *config.Config, options ...Option) (*Authentic state: atomicutil.NewValue(newAuthenticateState()), } + if authenticateConfig.getIdentityProvider == nil { + idpCache, err := config.NewIdentityProviderCache(cfg.Options) + if err != nil { + return nil, err + } + authenticateConfig.getIdentityProvider = authenticateflow.IdentityProviderLookupFromCache(idpCache) + } + a.options.Store(cfg.Options) state, err := newAuthenticateStateFromConfig(ctx, cfg, authenticateConfig) diff --git a/authenticate/config.go b/authenticate/config.go index ef293808b..765d2809e 100644 --- a/authenticate/config.go +++ b/authenticate/config.go @@ -18,7 +18,6 @@ type Option func(*authenticateConfig) func getAuthenticateConfig(options ...Option) *authenticateConfig { cfg := new(authenticateConfig) - WithGetIdentityProvider(defaultGetIdentityProvider)(cfg) for _, option := range options { option(cfg) } diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index c8f5342a5..6a4aaf79a 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -407,7 +407,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { w.WriteHeader(http.StatusOK) }) - idp, _ := new(config.Options).GetIdentityProviderForID("") + idp, _ := new(config.Options).GetIdentityProviderForPolicy(nil) tests := []struct { name string diff --git a/authorize/authorize.go b/authorize/authorize.go index d6c2b3ffa..6c6d09f0d 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -53,6 +53,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) { return nil, err } a.state = atomicutil.NewValue(state) + a.currentOptions.Store(cfg.Options) // FIXME: this is refactored out in a different branch return a, nil } diff --git a/authorize/grpc.go b/authorize/grpc.go index a1c51ca1a..fd7f82a73 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -14,7 +14,6 @@ import ( "google.golang.org/protobuf/types/known/structpb" "github.com/pomerium/pomerium/authorize/evaluator" - "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config/envoyconfig" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" @@ -49,8 +48,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe hreq := getHTTPRequestFromCheckRequest(in) requestID := requestid.FromHTTPHeader(hreq.Header) ctx = requestid.WithValue(ctx, requestID) + routeID := envoyconfig.ExtAuthzContextExtensionsRouteID(in.GetAttributes().GetContextExtensions()) - sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq) + sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq, routeID) var s sessionOrServiceAccount var u *user.User @@ -120,23 +120,10 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest( ID: sessionState.ID, } } - req.Policy = a.getMatchingPolicy(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions())) + req.Policy, _ = a.state.Load().idpCache.GetPolicyByID(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions())) return req, nil } -func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy { - options := a.currentOptions.Load() - - for p := range options.GetAllPolicies() { - id, _ := p.RouteID() - if id == routeID { - return p - } - } - - return nil -} - func getHTTPRequestFromCheckRequest(req *envoy_service_auth_v3.CheckRequest) *http.Request { hattrs := req.GetAttributes().GetRequest().GetHttp() u := getCheckRequestURL(req) diff --git a/authorize/grpc_test.go b/authorize/grpc_test.go index 8d239e4c7..d60d40f5c 100644 --- a/authorize/grpc_test.go +++ b/authorize/grpc_test.go @@ -17,9 +17,10 @@ import ( "github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/config" - "github.com/pomerium/pomerium/internal/atomicutil" + "github.com/pomerium/pomerium/config/envoyconfig" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" ) @@ -49,15 +50,25 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA== -----END CERTIFICATE-----` func Test_getEvaluatorRequest(t *testing.T) { - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} - a.currentOptions.Store(&config.Options{ - Policies: []config.Policy{{ - From: "https://example.com", - SubPolicies: []config.SubPolicy{{ - Rego: []string{"allow = true"}, - }}, + policies := []config.Policy{{ + From: "https://example.com", + To: mustParseWeightedURLs(t, "https://foo.bar"), + SubPolicies: []config.SubPolicy{{ + Rego: []string{"allow = true"}, }}, + }} + + policy0RouteID, err := policies[0].RouteID() + require.NoError(t, err) + + a, err := New(context.Background(), &config.Config{ + Options: &config.Options{ + SharedKey: cryptutil.NewBase64Key(), + CookieSecret: cryptutil.NewBase64Key(), + Policies: policies, + }, }) + require.NoError(t, err) actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(), &envoy_service_auth_v3.CheckRequest{ @@ -76,6 +87,7 @@ func Test_getEvaluatorRequest(t *testing.T) { Body: "BODY", }, }, + ContextExtensions: envoyconfig.MakeExtAuthzContextExtensions(false, policy0RouteID), MetadataContext: &envoy_config_core_v3.Metadata{ FilterMetadata: map[string]*structpb.Struct{ "com.pomerium.client-certificate-info": { @@ -94,7 +106,7 @@ func Test_getEvaluatorRequest(t *testing.T) { ) require.NoError(t, err) expect := &evaluator.Request{ - Policy: &a.currentOptions.Load().Policies[0], + Policy: &policies[0], Session: evaluator.RequestSession{ ID: "SESSION_ID", }, @@ -117,16 +129,24 @@ func Test_getEvaluatorRequest(t *testing.T) { } func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) { - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} - a.currentOptions.Store(&config.Options{ - Policies: []config.Policy{{ - From: "https://example.com", - SubPolicies: []config.SubPolicy{{ - Rego: []string{"allow = true"}, - }}, + policies := []config.Policy{{ + From: "https://example.com", + To: mustParseWeightedURLs(t, "https://foo.bar"), + SubPolicies: []config.SubPolicy{{ + Rego: []string{"allow = true"}, }}, - }) + }} + policy0RouteID, err := policies[0].RouteID() + require.NoError(t, err) + a, err := New(context.Background(), &config.Config{ + Options: &config.Options{ + SharedKey: cryptutil.NewBase64Key(), + CookieSecret: cryptutil.NewBase64Key(), + Policies: policies, + }, + }) + require.NoError(t, err) actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(), &envoy_service_auth_v3.CheckRequest{ Attributes: &envoy_service_auth_v3.AttributeContext{ @@ -144,11 +164,12 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) { Body: "BODY", }, }, + ContextExtensions: envoyconfig.MakeExtAuthzContextExtensions(false, policy0RouteID), }, }, nil) require.NoError(t, err) expect := &evaluator.Request{ - Policy: &a.currentOptions.Load().Policies[0], + Policy: &policies[0], Session: evaluator.RequestSession{}, HTTP: evaluator.NewRequestHTTP( http.MethodGet, diff --git a/authorize/state.go b/authorize/state.go index a94e0643f..9387db63e 100644 --- a/authorize/state.go +++ b/authorize/state.go @@ -29,6 +29,7 @@ type authorizeState struct { dataBrokerClient databroker.DataBrokerServiceClient auditEncryptor *protoutil.Encryptor sessionStore *config.SessionStore + idpCache *config.IdentityProviderCache authenticateFlow authenticateFlow } @@ -79,13 +80,20 @@ func newAuthorizeStateFromConfig( state.auditEncryptor = protoutil.NewEncryptor(auditKey) } - state.sessionStore, err = config.NewSessionStore(cfg.Options) + idpCache, err := config.NewIdentityProviderCache(cfg.Options) + if err != nil { + return nil, err + } + state.idpCache = idpCache + + state.sessionStore, err = config.NewSessionStore(cfg.Options, state.idpCache) if err != nil { return nil, fmt.Errorf("authorize: invalid session store: %w", err) } if cfg.Options.UseStatelessAuthenticateFlow() { - state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, nil, nil, nil, nil) + state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, nil, + authenticateflow.IdentityProviderLookupFromCache(idpCache), nil, nil) } else { state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, nil) } diff --git a/config/identity.go b/config/identity.go index 806a698ec..57e0da421 100644 --- a/config/identity.go +++ b/config/identity.go @@ -1,26 +1,12 @@ package config import ( + "fmt" + "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/grpc/identity" ) -// GetIdentityProviderForID returns the identity provider associated with the given IDP id. -// If none is found the default provider is returned. -func (o *Options) GetIdentityProviderForID(idpID string) (*identity.Provider, error) { - for p := range o.GetAllPolicies() { - idp, err := o.GetIdentityProviderForPolicy(p) - if err != nil { - return nil, err - } - if idp.GetId() == idpID { - return idp, nil - } - } - - return o.GetIdentityProviderForPolicy(nil) -} - // GetIdentityProviderForPolicy gets the identity provider associated with the given policy. // If policy is nil, or changes none of the default settings, the default provider is returned. func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provider, error) { @@ -69,3 +55,71 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity } return o.GetIdentityProviderForPolicy(nil) } + +type IdentityProviderCache struct { + idpsByRouteID map[uint64]*identity.Provider + policiesByRouteID map[uint64]*Policy + idpsByID map[string]*identity.Provider +} + +func NewIdentityProviderCache(opts *Options) (*IdentityProviderCache, error) { + rt := &IdentityProviderCache{ + idpsByRouteID: make(map[uint64]*identity.Provider, opts.NumPolicies()), + policiesByRouteID: make(map[uint64]*Policy, opts.NumPolicies()), + idpsByID: make(map[string]*identity.Provider), + } + + for policy := range opts.GetAllPolicies() { + id, err := policy.RouteID() + if err != nil { + return nil, err + } + idp, err := opts.GetIdentityProviderForPolicy(policy) + if err != nil { + return nil, err + } + rt.idpsByRouteID[id] = idp + rt.policiesByRouteID[id] = policy + + if _, ok := rt.idpsByID[idp.Id]; !ok { + rt.idpsByID[idp.Id] = idp + } + } + return rt, nil +} + +func (rt *IdentityProviderCache) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provider, error) { + routeID, err := policy.RouteID() + if err != nil { + return nil, err + } + idp, ok := rt.idpsByRouteID[routeID] + if !ok { + return nil, fmt.Errorf("no identity provider found for route %d", routeID) + } + return idp, nil +} + +func (rt *IdentityProviderCache) GetIdentityProviderForRouteID(routeID uint64) (*identity.Provider, error) { + idp, ok := rt.idpsByRouteID[routeID] + if !ok { + return nil, fmt.Errorf("no identity provider found for route %d", routeID) + } + return idp, nil +} + +func (rt *IdentityProviderCache) GetIdentityProviderByID(idpID string) (*identity.Provider, error) { + idp, ok := rt.idpsByID[idpID] + if !ok { + return nil, fmt.Errorf("no identity provider found for id %s", idpID) + } + return idp, nil +} + +func (rt *IdentityProviderCache) GetPolicyByID(routeID uint64) (*Policy, error) { + policy, ok := rt.policiesByRouteID[routeID] + if !ok { + return nil, fmt.Errorf("no policy found for route %d", routeID) + } + return policy, nil +} diff --git a/config/session.go b/config/session.go index 95f42bdf0..0a85e8c0c 100644 --- a/config/session.go +++ b/config/session.go @@ -15,18 +15,20 @@ import ( // A SessionStore saves and loads sessions based on the options. type SessionStore struct { - store sessions.SessionStore - loader sessions.SessionLoader - options *Options - encoder encoding.MarshalUnmarshaler + store sessions.SessionStore + loader sessions.SessionLoader + options *Options + encoder encoding.MarshalUnmarshaler + idpCache *IdentityProviderCache } var _ sessions.SessionStore = (*SessionStore)(nil) // NewSessionStore creates a new SessionStore from the Options. -func NewSessionStore(options *Options) (*SessionStore, error) { +func NewSessionStore(options *Options, idpCache *IdentityProviderCache) (*SessionStore, error) { store := &SessionStore{ - options: options, + options: options, + idpCache: idpCache, } sharedKey, err := options.GetSharedKey() @@ -86,7 +88,7 @@ func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, e } // LoadSessionStateAndCheckIDP loads the session state from a request and checks that the idp id matches. -func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request) (*sessions.State, error) { +func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request, routeID uint64) (*sessions.State, error) { state, err := store.LoadSessionState(r) if err != nil { return nil, err @@ -94,7 +96,7 @@ func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request) (*sessio // confirm that the identity provider id matches the state if state.IdentityProviderID != "" { - idp, err := store.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String()) + idp, err := store.idpCache.GetIdentityProviderForRouteID(routeID) if err != nil { return nil, err } diff --git a/config/session_test.go b/config/session_test.go index 936b45c7c..77e71c2f0 100644 --- a/config/session_test.go +++ b/config/session_test.go @@ -42,7 +42,10 @@ func TestSessionStore_LoadSessionState(t *testing.T) { }) require.NoError(t, options.Validate()) - store, err := NewSessionStore(options) + idpCache, err := NewIdentityProviderCache(options) + require.NoError(t, err) + + store, err := NewSessionStore(options, idpCache) require.NoError(t, err) idp1, err := options.GetIdentityProviderForPolicy(nil) @@ -57,6 +60,11 @@ func TestSessionStore_LoadSessionState(t *testing.T) { require.NoError(t, err) require.NotNil(t, idp3) + policy0Id, err := options.Policies[0].RouteID() + require.NoError(t, err) + policy1Id, err := options.Policies[1].RouteID() + require.NoError(t, err) + makeJWS := func(t *testing.T, state *sessions.State) string { e, err := jws.NewHS256Signer(sharedKey) require.NoError(t, err) @@ -70,7 +78,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) { t.Run("mssing", func(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://p1.example.com", nil) require.NoError(t, err) - s, err := store.LoadSessionStateAndCheckIDP(r) + s, err := store.LoadSessionStateAndCheckIDP(r, 0) assert.ErrorIs(t, err, sessions.ErrNoSessionFound) assert.Nil(t, s) }) @@ -85,7 +93,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) { urlutil.QuerySession: {rawJWS}, }.Encode(), nil) require.NoError(t, err) - s, err := store.LoadSessionStateAndCheckIDP(r) + s, err := store.LoadSessionStateAndCheckIDP(r, policy0Id) assert.NoError(t, err) assert.Empty(t, cmp.Diff(&sessions.State{ Issuer: "authenticate.example.com", @@ -103,7 +111,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) require.NoError(t, err) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) - s, err := store.LoadSessionStateAndCheckIDP(r) + s, err := store.LoadSessionStateAndCheckIDP(r, policy1Id) assert.NoError(t, err) assert.Empty(t, cmp.Diff(&sessions.State{ Issuer: "authenticate.example.com", @@ -121,7 +129,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) require.NoError(t, err) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) - s, err := store.LoadSessionStateAndCheckIDP(r) + s, err := store.LoadSessionStateAndCheckIDP(r, policy1Id) assert.Error(t, err) assert.Nil(t, s) }) @@ -134,7 +142,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) require.NoError(t, err) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) - s, err := store.LoadSessionStateAndCheckIDP(r) + s, err := store.LoadSessionStateAndCheckIDP(r, policy1Id) assert.NoError(t, err) assert.Empty(t, cmp.Diff(&sessions.State{ Issuer: "authenticate.example.com", diff --git a/authenticate/identity.go b/internal/authenticateflow/identity.go similarity index 57% rename from authenticate/identity.go rename to internal/authenticateflow/identity.go index 8ea432e15..2afd17471 100644 --- a/authenticate/identity.go +++ b/internal/authenticateflow/identity.go @@ -1,13 +1,14 @@ -package authenticate +package authenticateflow import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/urlutil" + identitypb "github.com/pomerium/pomerium/pkg/grpc/identity" "github.com/pomerium/pomerium/pkg/identity" "github.com/pomerium/pomerium/pkg/identity/oauth" ) -func defaultGetIdentityProvider(options *config.Options, idpID string) (identity.Authenticator, error) { +func NewAuthenticator(options *config.Options, idp *identitypb.Provider) (identity.Authenticator, error) { authenticateURL, err := options.GetAuthenticateURL() if err != nil { return nil, err @@ -19,10 +20,6 @@ func defaultGetIdentityProvider(options *config.Options, idpID string) (identity } redirectURL.Path = options.AuthenticateCallbackPath - idp, err := options.GetIdentityProviderForID(idpID) - if err != nil { - return nil, err - } return identity.NewAuthenticator(oauth.Options{ RedirectURL: redirectURL, ProviderName: idp.GetType(), @@ -33,3 +30,13 @@ func defaultGetIdentityProvider(options *config.Options, idpID string) (identity AuthCodeOptions: idp.GetRequestParams(), }) } + +func IdentityProviderLookupFromCache(idpCache *config.IdentityProviderCache) func(*config.Options, string) (identity.Authenticator, error) { + return func(options *config.Options, idpID string) (identity.Authenticator, error) { + idp, err := idpCache.GetIdentityProviderByID(idpID) + if err != nil { + return nil, err + } + return NewAuthenticator(options, idp) + } +} diff --git a/proxy/state.go b/proxy/state.go index 5a7727e13..a32b19948 100644 --- a/proxy/state.go +++ b/proxy/state.go @@ -26,6 +26,7 @@ type proxyState struct { sharedKey []byte sessionStore *config.SessionStore + idpCache *config.IdentityProviderCache dataBrokerClient databroker.DataBrokerServiceClient programmaticRedirectDomainWhitelist []string authenticateFlow authenticateFlow @@ -47,12 +48,17 @@ func newProxyStateFromConfig(ctx context.Context, cfg *config.Config) (*proxySta state.authenticateSigninURL = state.authenticateURL.ResolveReference(&url.URL{Path: signinURL}) state.authenticateRefreshURL = state.authenticateURL.ResolveReference(&url.URL{Path: refreshURL}) + state.idpCache, err = config.NewIdentityProviderCache(cfg.Options) + if err != nil { + return nil, err + } + state.sharedKey, err = cfg.Options.GetSharedKey() if err != nil { return nil, err } - state.sessionStore, err = config.NewSessionStore(cfg.Options) + state.sessionStore, err = config.NewSessionStore(cfg.Options, state.idpCache) if err != nil { return nil, err } @@ -71,8 +77,8 @@ func newProxyStateFromConfig(ctx context.Context, cfg *config.Config) (*proxySta state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist if cfg.Options.UseStatelessAuthenticateFlow() { - state.authenticateFlow, err = authenticateflow.NewStateless(ctx, - cfg, state.sessionStore, nil, nil, nil) + state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, state.sessionStore, + authenticateflow.IdentityProviderLookupFromCache(state.idpCache), nil, nil) } else { state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, state.sessionStore) }