Optimize identity provider lookup with new cache

This commit is contained in:
Joe Kralicky 2024-11-04 15:20:31 -05:00
parent c8b6b8f1a9
commit 8df3028533
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
12 changed files with 179 additions and 77 deletions

View file

@ -9,6 +9,7 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/authenticateflow"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
@ -52,6 +53,14 @@ func New(ctx context.Context, cfg *config.Config, options ...Option) (*Authentic
state: atomicutil.NewValue(newAuthenticateState()),
}
if authenticateConfig.getIdentityProvider == nil {
idpCache, err := config.NewIdentityProviderCache(cfg.Options)
if err != nil {
return nil, err
}
authenticateConfig.getIdentityProvider = authenticateflow.IdentityProviderLookupFromCache(idpCache)
}
a.options.Store(cfg.Options)
state, err := newAuthenticateStateFromConfig(ctx, cfg, authenticateConfig)

View file

@ -18,7 +18,6 @@ type Option func(*authenticateConfig)
func getAuthenticateConfig(options ...Option) *authenticateConfig {
cfg := new(authenticateConfig)
WithGetIdentityProvider(defaultGetIdentityProvider)(cfg)
for _, option := range options {
option(cfg)
}

View file

@ -407,7 +407,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
w.WriteHeader(http.StatusOK)
})
idp, _ := new(config.Options).GetIdentityProviderForID("")
idp, _ := new(config.Options).GetIdentityProviderForPolicy(nil)
tests := []struct {
name string

View file

@ -53,6 +53,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
return nil, err
}
a.state = atomicutil.NewValue(state)
a.currentOptions.Store(cfg.Options) // FIXME: this is refactored out in a different branch
return a, nil
}

View file

@ -14,7 +14,6 @@ import (
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/config/envoyconfig"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
@ -49,8 +48,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
hreq := getHTTPRequestFromCheckRequest(in)
requestID := requestid.FromHTTPHeader(hreq.Header)
ctx = requestid.WithValue(ctx, requestID)
routeID := envoyconfig.ExtAuthzContextExtensionsRouteID(in.GetAttributes().GetContextExtensions())
sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq)
sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq, routeID)
var s sessionOrServiceAccount
var u *user.User
@ -120,23 +120,10 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
ID: sessionState.ID,
}
}
req.Policy = a.getMatchingPolicy(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions()))
req.Policy, _ = a.state.Load().idpCache.GetPolicyByID(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions()))
return req, nil
}
func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy {
options := a.currentOptions.Load()
for p := range options.GetAllPolicies() {
id, _ := p.RouteID()
if id == routeID {
return p
}
}
return nil
}
func getHTTPRequestFromCheckRequest(req *envoy_service_auth_v3.CheckRequest) *http.Request {
hattrs := req.GetAttributes().GetRequest().GetHttp()
u := getCheckRequestURL(req)

View file

@ -17,9 +17,10 @@ import (
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/config/envoyconfig"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
@ -49,15 +50,25 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
-----END CERTIFICATE-----`
func Test_getEvaluatorRequest(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(&config.Options{
Policies: []config.Policy{{
From: "https://example.com",
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
}},
policies := []config.Policy{{
From: "https://example.com",
To: mustParseWeightedURLs(t, "https://foo.bar"),
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
}},
}}
policy0RouteID, err := policies[0].RouteID()
require.NoError(t, err)
a, err := New(context.Background(), &config.Config{
Options: &config.Options{
SharedKey: cryptutil.NewBase64Key(),
CookieSecret: cryptutil.NewBase64Key(),
Policies: policies,
},
})
require.NoError(t, err)
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
&envoy_service_auth_v3.CheckRequest{
@ -76,6 +87,7 @@ func Test_getEvaluatorRequest(t *testing.T) {
Body: "BODY",
},
},
ContextExtensions: envoyconfig.MakeExtAuthzContextExtensions(false, policy0RouteID),
MetadataContext: &envoy_config_core_v3.Metadata{
FilterMetadata: map[string]*structpb.Struct{
"com.pomerium.client-certificate-info": {
@ -94,7 +106,7 @@ func Test_getEvaluatorRequest(t *testing.T) {
)
require.NoError(t, err)
expect := &evaluator.Request{
Policy: &a.currentOptions.Load().Policies[0],
Policy: &policies[0],
Session: evaluator.RequestSession{
ID: "SESSION_ID",
},
@ -117,16 +129,24 @@ func Test_getEvaluatorRequest(t *testing.T) {
}
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(&config.Options{
Policies: []config.Policy{{
From: "https://example.com",
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
}},
policies := []config.Policy{{
From: "https://example.com",
To: mustParseWeightedURLs(t, "https://foo.bar"),
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
}},
})
}}
policy0RouteID, err := policies[0].RouteID()
require.NoError(t, err)
a, err := New(context.Background(), &config.Config{
Options: &config.Options{
SharedKey: cryptutil.NewBase64Key(),
CookieSecret: cryptutil.NewBase64Key(),
Policies: policies,
},
})
require.NoError(t, err)
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
&envoy_service_auth_v3.CheckRequest{
Attributes: &envoy_service_auth_v3.AttributeContext{
@ -144,11 +164,12 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
Body: "BODY",
},
},
ContextExtensions: envoyconfig.MakeExtAuthzContextExtensions(false, policy0RouteID),
},
}, nil)
require.NoError(t, err)
expect := &evaluator.Request{
Policy: &a.currentOptions.Load().Policies[0],
Policy: &policies[0],
Session: evaluator.RequestSession{},
HTTP: evaluator.NewRequestHTTP(
http.MethodGet,

View file

@ -29,6 +29,7 @@ type authorizeState struct {
dataBrokerClient databroker.DataBrokerServiceClient
auditEncryptor *protoutil.Encryptor
sessionStore *config.SessionStore
idpCache *config.IdentityProviderCache
authenticateFlow authenticateFlow
}
@ -79,13 +80,20 @@ func newAuthorizeStateFromConfig(
state.auditEncryptor = protoutil.NewEncryptor(auditKey)
}
state.sessionStore, err = config.NewSessionStore(cfg.Options)
idpCache, err := config.NewIdentityProviderCache(cfg.Options)
if err != nil {
return nil, err
}
state.idpCache = idpCache
state.sessionStore, err = config.NewSessionStore(cfg.Options, state.idpCache)
if err != nil {
return nil, fmt.Errorf("authorize: invalid session store: %w", err)
}
if cfg.Options.UseStatelessAuthenticateFlow() {
state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, nil, nil, nil, nil)
state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, nil,
authenticateflow.IdentityProviderLookupFromCache(idpCache), nil, nil)
} else {
state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, nil)
}

View file

@ -1,26 +1,12 @@
package config
import (
"fmt"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/grpc/identity"
)
// GetIdentityProviderForID returns the identity provider associated with the given IDP id.
// If none is found the default provider is returned.
func (o *Options) GetIdentityProviderForID(idpID string) (*identity.Provider, error) {
for p := range o.GetAllPolicies() {
idp, err := o.GetIdentityProviderForPolicy(p)
if err != nil {
return nil, err
}
if idp.GetId() == idpID {
return idp, nil
}
}
return o.GetIdentityProviderForPolicy(nil)
}
// GetIdentityProviderForPolicy gets the identity provider associated with the given policy.
// If policy is nil, or changes none of the default settings, the default provider is returned.
func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provider, error) {
@ -69,3 +55,71 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity
}
return o.GetIdentityProviderForPolicy(nil)
}
type IdentityProviderCache struct {
idpsByRouteID map[uint64]*identity.Provider
policiesByRouteID map[uint64]*Policy
idpsByID map[string]*identity.Provider
}
func NewIdentityProviderCache(opts *Options) (*IdentityProviderCache, error) {
rt := &IdentityProviderCache{
idpsByRouteID: make(map[uint64]*identity.Provider, opts.NumPolicies()),
policiesByRouteID: make(map[uint64]*Policy, opts.NumPolicies()),
idpsByID: make(map[string]*identity.Provider),
}
for policy := range opts.GetAllPolicies() {
id, err := policy.RouteID()
if err != nil {
return nil, err
}
idp, err := opts.GetIdentityProviderForPolicy(policy)
if err != nil {
return nil, err
}
rt.idpsByRouteID[id] = idp
rt.policiesByRouteID[id] = policy
if _, ok := rt.idpsByID[idp.Id]; !ok {
rt.idpsByID[idp.Id] = idp
}
}
return rt, nil
}
func (rt *IdentityProviderCache) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provider, error) {
routeID, err := policy.RouteID()
if err != nil {
return nil, err
}
idp, ok := rt.idpsByRouteID[routeID]
if !ok {
return nil, fmt.Errorf("no identity provider found for route %d", routeID)
}
return idp, nil
}
func (rt *IdentityProviderCache) GetIdentityProviderForRouteID(routeID uint64) (*identity.Provider, error) {
idp, ok := rt.idpsByRouteID[routeID]
if !ok {
return nil, fmt.Errorf("no identity provider found for route %d", routeID)
}
return idp, nil
}
func (rt *IdentityProviderCache) GetIdentityProviderByID(idpID string) (*identity.Provider, error) {
idp, ok := rt.idpsByID[idpID]
if !ok {
return nil, fmt.Errorf("no identity provider found for id %s", idpID)
}
return idp, nil
}
func (rt *IdentityProviderCache) GetPolicyByID(routeID uint64) (*Policy, error) {
policy, ok := rt.policiesByRouteID[routeID]
if !ok {
return nil, fmt.Errorf("no policy found for route %d", routeID)
}
return policy, nil
}

View file

@ -15,18 +15,20 @@ import (
// A SessionStore saves and loads sessions based on the options.
type SessionStore struct {
store sessions.SessionStore
loader sessions.SessionLoader
options *Options
encoder encoding.MarshalUnmarshaler
store sessions.SessionStore
loader sessions.SessionLoader
options *Options
encoder encoding.MarshalUnmarshaler
idpCache *IdentityProviderCache
}
var _ sessions.SessionStore = (*SessionStore)(nil)
// NewSessionStore creates a new SessionStore from the Options.
func NewSessionStore(options *Options) (*SessionStore, error) {
func NewSessionStore(options *Options, idpCache *IdentityProviderCache) (*SessionStore, error) {
store := &SessionStore{
options: options,
options: options,
idpCache: idpCache,
}
sharedKey, err := options.GetSharedKey()
@ -86,7 +88,7 @@ func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, e
}
// LoadSessionStateAndCheckIDP loads the session state from a request and checks that the idp id matches.
func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request) (*sessions.State, error) {
func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request, routeID uint64) (*sessions.State, error) {
state, err := store.LoadSessionState(r)
if err != nil {
return nil, err
@ -94,7 +96,7 @@ func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request) (*sessio
// confirm that the identity provider id matches the state
if state.IdentityProviderID != "" {
idp, err := store.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String())
idp, err := store.idpCache.GetIdentityProviderForRouteID(routeID)
if err != nil {
return nil, err
}

View file

@ -42,7 +42,10 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
})
require.NoError(t, options.Validate())
store, err := NewSessionStore(options)
idpCache, err := NewIdentityProviderCache(options)
require.NoError(t, err)
store, err := NewSessionStore(options, idpCache)
require.NoError(t, err)
idp1, err := options.GetIdentityProviderForPolicy(nil)
@ -57,6 +60,11 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, idp3)
policy0Id, err := options.Policies[0].RouteID()
require.NoError(t, err)
policy1Id, err := options.Policies[1].RouteID()
require.NoError(t, err)
makeJWS := func(t *testing.T, state *sessions.State) string {
e, err := jws.NewHS256Signer(sharedKey)
require.NoError(t, err)
@ -70,7 +78,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
t.Run("mssing", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://p1.example.com", nil)
require.NoError(t, err)
s, err := store.LoadSessionStateAndCheckIDP(r)
s, err := store.LoadSessionStateAndCheckIDP(r, 0)
assert.ErrorIs(t, err, sessions.ErrNoSessionFound)
assert.Nil(t, s)
})
@ -85,7 +93,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
urlutil.QuerySession: {rawJWS},
}.Encode(), nil)
require.NoError(t, err)
s, err := store.LoadSessionStateAndCheckIDP(r)
s, err := store.LoadSessionStateAndCheckIDP(r, policy0Id)
assert.NoError(t, err)
assert.Empty(t, cmp.Diff(&sessions.State{
Issuer: "authenticate.example.com",
@ -103,7 +111,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil)
require.NoError(t, err)
r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS)
s, err := store.LoadSessionStateAndCheckIDP(r)
s, err := store.LoadSessionStateAndCheckIDP(r, policy1Id)
assert.NoError(t, err)
assert.Empty(t, cmp.Diff(&sessions.State{
Issuer: "authenticate.example.com",
@ -121,7 +129,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil)
require.NoError(t, err)
r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS)
s, err := store.LoadSessionStateAndCheckIDP(r)
s, err := store.LoadSessionStateAndCheckIDP(r, policy1Id)
assert.Error(t, err)
assert.Nil(t, s)
})
@ -134,7 +142,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil)
require.NoError(t, err)
r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS)
s, err := store.LoadSessionStateAndCheckIDP(r)
s, err := store.LoadSessionStateAndCheckIDP(r, policy1Id)
assert.NoError(t, err)
assert.Empty(t, cmp.Diff(&sessions.State{
Issuer: "authenticate.example.com",

View file

@ -1,13 +1,14 @@
package authenticate
package authenticateflow
import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/urlutil"
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/identity"
"github.com/pomerium/pomerium/pkg/identity/oauth"
)
func defaultGetIdentityProvider(options *config.Options, idpID string) (identity.Authenticator, error) {
func NewAuthenticator(options *config.Options, idp *identitypb.Provider) (identity.Authenticator, error) {
authenticateURL, err := options.GetAuthenticateURL()
if err != nil {
return nil, err
@ -19,10 +20,6 @@ func defaultGetIdentityProvider(options *config.Options, idpID string) (identity
}
redirectURL.Path = options.AuthenticateCallbackPath
idp, err := options.GetIdentityProviderForID(idpID)
if err != nil {
return nil, err
}
return identity.NewAuthenticator(oauth.Options{
RedirectURL: redirectURL,
ProviderName: idp.GetType(),
@ -33,3 +30,13 @@ func defaultGetIdentityProvider(options *config.Options, idpID string) (identity
AuthCodeOptions: idp.GetRequestParams(),
})
}
func IdentityProviderLookupFromCache(idpCache *config.IdentityProviderCache) func(*config.Options, string) (identity.Authenticator, error) {
return func(options *config.Options, idpID string) (identity.Authenticator, error) {
idp, err := idpCache.GetIdentityProviderByID(idpID)
if err != nil {
return nil, err
}
return NewAuthenticator(options, idp)
}
}

View file

@ -26,6 +26,7 @@ type proxyState struct {
sharedKey []byte
sessionStore *config.SessionStore
idpCache *config.IdentityProviderCache
dataBrokerClient databroker.DataBrokerServiceClient
programmaticRedirectDomainWhitelist []string
authenticateFlow authenticateFlow
@ -47,12 +48,17 @@ func newProxyStateFromConfig(ctx context.Context, cfg *config.Config) (*proxySta
state.authenticateSigninURL = state.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
state.authenticateRefreshURL = state.authenticateURL.ResolveReference(&url.URL{Path: refreshURL})
state.idpCache, err = config.NewIdentityProviderCache(cfg.Options)
if err != nil {
return nil, err
}
state.sharedKey, err = cfg.Options.GetSharedKey()
if err != nil {
return nil, err
}
state.sessionStore, err = config.NewSessionStore(cfg.Options)
state.sessionStore, err = config.NewSessionStore(cfg.Options, state.idpCache)
if err != nil {
return nil, err
}
@ -71,8 +77,8 @@ func newProxyStateFromConfig(ctx context.Context, cfg *config.Config) (*proxySta
state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist
if cfg.Options.UseStatelessAuthenticateFlow() {
state.authenticateFlow, err = authenticateflow.NewStateless(ctx,
cfg, state.sessionStore, nil, nil, nil)
state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, state.sessionStore,
authenticateflow.IdentityProviderLookupFromCache(state.idpCache), nil, nil)
} else {
state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, state.sessionStore)
}