core/authorize: cache prepared query building

This commit is contained in:
Caleb Doxsey 2023-11-01 17:19:42 -06:00
parent 0b79a28328
commit 42233223cb
11 changed files with 143 additions and 67 deletions

View file

@ -26,7 +26,7 @@ import (
// Authorize struct holds // Authorize struct holds
type Authorize struct { type Authorize struct {
state *atomicutil.Value[*authorizeState] state *atomicutil.Value[*authorizeState]
store *store.Store compiler *evaluator.RegoCompiler
currentOptions *atomicutil.Value[*config.Options] currentOptions *atomicutil.Value[*config.Options]
accessTracker *AccessTracker accessTracker *AccessTracker
globalCache storage.Cache globalCache storage.Cache
@ -41,12 +41,12 @@ type Authorize struct {
func New(cfg *config.Config) (*Authorize, error) { func New(cfg *config.Config) (*Authorize, error) {
a := &Authorize{ a := &Authorize{
currentOptions: config.NewAtomicOptions(), currentOptions: config.NewAtomicOptions(),
store: store.New(), compiler: evaluator.NewRegoCompiler(store.New()),
globalCache: storage.NewGlobalCache(time.Minute), globalCache: storage.NewGlobalCache(time.Minute),
} }
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod) a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
state, err := newAuthorizeStateFromConfig(cfg, a.store) state, err := newAuthorizeStateFromConfig(cfg, a.compiler)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -86,14 +86,17 @@ func validateOptions(o *config.Options) error {
} }
// newPolicyEvaluator returns an policy evaluator. // newPolicyEvaluator returns an policy evaluator.
func newPolicyEvaluator(opts *config.Options, store *store.Store) (*evaluator.Evaluator, error) { func newPolicyEvaluator(opts *config.Options, compiler *evaluator.RegoCompiler) (*evaluator.Evaluator, error) {
metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 {
return int64(len(opts.GetAllPolicies()))
})
ctx := context.Background() ctx := context.Background()
ctx, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator") ctx, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator")
defer span.End() defer span.End()
allPolicies := opts.GetAllPolicies()
metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 {
return int64(len(allPolicies))
})
clientCA, err := opts.DownstreamMTLS.GetCA() clientCA, err := opts.DownstreamMTLS.GetCA()
if err != nil { if err != nil {
return nil, fmt.Errorf("authorize: invalid client CA: %w", err) return nil, fmt.Errorf("authorize: invalid client CA: %w", err)
@ -126,8 +129,8 @@ func newPolicyEvaluator(opts *config.Options, store *store.Store) (*evaluator.Ev
"authorize: internal error: couldn't build client cert constraints: %w", err) "authorize: internal error: couldn't build client cert constraints: %w", err)
} }
return evaluator.New(ctx, store, return evaluator.New(ctx, compiler,
evaluator.WithPolicies(opts.GetAllPolicies()), evaluator.WithPolicies(allPolicies),
evaluator.WithClientCA(clientCA), evaluator.WithClientCA(clientCA),
evaluator.WithAddDefaultClientCertificateRule(addDefaultClientCertificateRule), evaluator.WithAddDefaultClientCertificateRule(addDefaultClientCertificateRule),
evaluator.WithClientCRL(clientCRL), evaluator.WithClientCRL(clientCRL),
@ -142,7 +145,7 @@ func newPolicyEvaluator(opts *config.Options, store *store.Store) (*evaluator.Ev
// OnConfigChange updates internal structures based on config.Options // OnConfigChange updates internal structures based on config.Options
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) { func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
a.currentOptions.Store(cfg.Options) a.currentOptions.Store(cfg.Options)
if state, err := newAuthorizeStateFromConfig(cfg, a.store); err != nil { if state, err := newAuthorizeStateFromConfig(cfg, a.compiler); err != nil {
log.Error(ctx).Err(err).Msg("authorize: error updating state") log.Error(ctx).Err(err).Msg("authorize: error updating state")
} else { } else {
a.state.Store(state) a.state.Store(state)

View file

@ -176,10 +176,11 @@ func TestNewPolicyEvaluator_addDefaultClientCertificateRule(t *testing.T) {
c := &cases[i] c := &cases[i]
t.Run(c.label, func(t *testing.T) { t.Run(c.label, func(t *testing.T) {
store := store.New() store := store.New()
compiler := evaluator.NewRegoCompiler(store)
c.opts.Policies = []config.Policy{{ c.opts.Policies = []config.Policy{{
To: mustParseWeightedURLs(t, "http://example.com"), To: mustParseWeightedURLs(t, "http://example.com"),
}} }}
e, err := newPolicyEvaluator(c.opts, store) e, err := newPolicyEvaluator(c.opts, compiler)
require.NoError(t, err) require.NoError(t, err)
r, err := e.Evaluate(context.Background(), &evaluator.Request{ r, err := e.Evaluate(context.Background(), &evaluator.Request{

View file

@ -130,8 +130,8 @@ func TestAuthorize_okResponse(t *testing.T) {
} }
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(opt) a.currentOptions.Store(opt)
a.store = store.New() a.compiler = evaluator.NewRegoCompiler(store.New())
pe, err := newPolicyEvaluator(opt, a.store) pe, err := newPolicyEvaluator(opt, a.compiler)
require.NoError(t, err) require.NoError(t, err)
a.state.Load().evaluator = pe a.state.Load().evaluator = pe

View file

@ -12,7 +12,6 @@ import (
"github.com/open-policy-agent/opa/rego" "github.com/open-policy-agent/opa/rego"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
@ -89,7 +88,7 @@ type Result struct {
// An Evaluator evaluates policies. // An Evaluator evaluates policies.
type Evaluator struct { type Evaluator struct {
store *store.Store compiler *RegoCompiler
policyEvaluators map[uint64]*PolicyEvaluator policyEvaluators map[uint64]*PolicyEvaluator
headersEvaluators *HeadersEvaluator headersEvaluators *HeadersEvaluator
clientCA []byte clientCA []byte
@ -98,8 +97,8 @@ type Evaluator struct {
} }
// New creates a new Evaluator. // New creates a new Evaluator.
func New(ctx context.Context, store *store.Store, options ...Option) (*Evaluator, error) { func New(ctx context.Context, compiler *RegoCompiler, options ...Option) (*Evaluator, error) {
e := &Evaluator{store: store} e := &Evaluator{compiler: compiler}
cfg := getConfig(options...) cfg := getConfig(options...)
@ -108,7 +107,7 @@ func New(ctx context.Context, store *store.Store, options ...Option) (*Evaluator
return nil, err return nil, err
} }
e.headersEvaluators, err = NewHeadersEvaluator(ctx, store) e.headersEvaluators, err = NewHeadersEvaluator(ctx, compiler)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -124,8 +123,7 @@ func New(ctx context.Context, store *store.Store, options ...Option) (*Evaluator
if err != nil { if err != nil {
return nil, fmt.Errorf("authorize: error computing policy route id: %w", err) return nil, fmt.Errorf("authorize: error computing policy route id: %w", err)
} }
policyEvaluator, err := policyEvaluator, err := NewPolicyEvaluator(ctx, compiler, &configPolicy, cfg.addDefaultClientCertificateRule)
NewPolicyEvaluator(ctx, store, &configPolicy, cfg.addDefaultClientCertificateRule)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -257,12 +255,12 @@ func (e *Evaluator) updateStore(cfg *evaluatorConfig) error {
return fmt.Errorf("authorize: couldn't create signer: %w", err) return fmt.Errorf("authorize: couldn't create signer: %w", err)
} }
e.store.UpdateGoogleCloudServerlessAuthenticationServiceAccount( e.compiler.Store.UpdateGoogleCloudServerlessAuthenticationServiceAccount(
cfg.googleCloudServerlessAuthenticationServiceAccount, cfg.googleCloudServerlessAuthenticationServiceAccount,
) )
e.store.UpdateJWTClaimHeaders(cfg.jwtClaimsHeaders) e.compiler.Store.UpdateJWTClaimHeaders(cfg.jwtClaimsHeaders)
e.store.UpdateRoutePolicies(cfg.policies) e.compiler.Store.UpdateRoutePolicies(cfg.policies)
e.store.UpdateSigningKey(jwk) e.compiler.Store.UpdateSigningKey(jwk)
return nil return nil
} }

View file

@ -36,7 +36,8 @@ func TestEvaluator(t *testing.T) {
store := store.New() store := store.New()
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY")) store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
store.UpdateSigningKey(privateJWK) store.UpdateSigningKey(privateJWK)
e, err := New(ctx, store, options...) compiler := NewRegoCompiler(store)
e, err := New(ctx, compiler, options...)
require.NoError(t, err) require.NoError(t, err)
return e.Evaluate(ctx, req) return e.Evaluate(ctx, req)
} }

View file

@ -12,7 +12,6 @@ import (
"github.com/open-policy-agent/opa/types" "github.com/open-policy-agent/opa/types"
"github.com/pomerium/pomerium/authorize/evaluator/opa" "github.com/pomerium/pomerium/authorize/evaluator/opa"
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
) )
@ -100,17 +99,8 @@ type HeadersEvaluator struct {
} }
// NewHeadersEvaluator creates a new HeadersEvaluator. // NewHeadersEvaluator creates a new HeadersEvaluator.
func NewHeadersEvaluator(ctx context.Context, store *store.Store) (*HeadersEvaluator, error) { func NewHeadersEvaluator(ctx context.Context, compiler *RegoCompiler) (*HeadersEvaluator, error) {
r := rego.New( q, err := compiler.CompileHeadersQuery(ctx, opa.HeadersRego)
rego.Store(store),
rego.Module("pomerium.headers", opa.HeadersRego),
rego.Query("result = data.pomerium.headers"),
getGoogleCloudServerlessHeadersRegoOption,
variableSubstitutionFunctionRegoOption,
store.GetDataBrokerRecordOption(),
)
q, err := r.PrepareForEval(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -76,7 +76,8 @@ func TestHeadersEvaluator(t *testing.T) {
store := store.New() store := store.New()
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY")) store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
store.UpdateSigningKey(privateJWK) store.UpdateSigningKey(privateJWK)
e, err := NewHeadersEvaluator(ctx, store) compiler := NewRegoCompiler(store)
e, err := NewHeadersEvaluator(ctx, compiler)
require.NoError(t, err) require.NoError(t, err)
return e.Evaluate(ctx, input) return e.Evaluate(ctx, input)
} }

View file

@ -3,17 +3,15 @@ package evaluator
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"github.com/cespare/xxhash/v2"
"github.com/open-policy-agent/opa/rego" "github.com/open-policy-agent/opa/rego"
octrace "go.opencensus.io/trace" octrace "go.opencensus.io/trace"
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/contextutil" "github.com/pomerium/pomerium/pkg/contextutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/policy" "github.com/pomerium/pomerium/pkg/policy"
"github.com/pomerium/pomerium/pkg/policy/criteria" "github.com/pomerium/pomerium/pkg/policy/criteria"
) )
@ -98,7 +96,7 @@ type policyQuery struct {
} }
func (q policyQuery) checksum() string { func (q policyQuery) checksum() string {
return fmt.Sprintf("%x", cryptutil.Hash("script", []byte(q.script))) return fmt.Sprintf("%x", xxhash.Sum64String(q.script))
} }
// A PolicyEvaluator evaluates policies. // A PolicyEvaluator evaluates policies.
@ -108,7 +106,9 @@ type PolicyEvaluator struct {
// NewPolicyEvaluator creates a new PolicyEvaluator. // NewPolicyEvaluator creates a new PolicyEvaluator.
func NewPolicyEvaluator( func NewPolicyEvaluator(
ctx context.Context, store *store.Store, configPolicy *config.Policy, ctx context.Context,
compiler *RegoCompiler,
configPolicy *config.Policy,
addDefaultClientCertificateRule bool, addDefaultClientCertificateRule bool,
) (*PolicyEvaluator, error) { ) (*PolicyEvaluator, error) {
e := new(PolicyEvaluator) e := new(PolicyEvaluator)
@ -151,26 +151,7 @@ func NewPolicyEvaluator(
Interface("to", configPolicy.To). Interface("to", configPolicy.To).
Msg("authorize: rego script for policy evaluation") Msg("authorize: rego script for policy evaluation")
r := rego.New( q, err := compiler.CompilePolicyQuery(ctx, e.queries[i].script)
rego.Store(store),
rego.Module("pomerium.policy", e.queries[i].script),
rego.Query("result = data.pomerium.policy"),
getGoogleCloudServerlessHeadersRegoOption,
store.GetDataBrokerRecordOption(),
)
q, err := r.PrepareForEval(ctx)
// if no package is in the src, add it
if err != nil && strings.Contains(err.Error(), "package expected") {
r := rego.New(
rego.Store(store),
rego.Module("pomerium.policy", "package pomerium.policy\n\n"+e.queries[i].script),
rego.Query("result = data.pomerium.policy"),
getGoogleCloudServerlessHeadersRegoOption,
store.GetDataBrokerRecordOption(),
)
q, err = r.PrepareForEval(ctx)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -40,7 +40,8 @@ func TestPolicyEvaluator(t *testing.T) {
store := store.New() store := store.New()
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY")) store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
store.UpdateSigningKey(privateJWK) store.UpdateSigningKey(privateJWK)
e, err := NewPolicyEvaluator(ctx, store, policy, addDefaultClientCertificateRule) compiler := NewRegoCompiler(store)
e, err := NewPolicyEvaluator(ctx, compiler, policy, addDefaultClientCertificateRule)
require.NoError(t, err) require.NoError(t, err)
return e.Evaluate(ctx, input) return e.Evaluate(ctx, input)
} }

View file

@ -0,0 +1,101 @@
package evaluator
import (
"context"
"fmt"
"strings"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/open-policy-agent/opa/rego"
"github.com/pomerium/pomerium/authorize/evaluator/opa"
"github.com/pomerium/pomerium/authorize/internal/store"
)
// A RegoCompiler compiles rego scripts.
type RegoCompiler struct {
Store *store.Store
policyCache *lru.Cache[string, rego.PreparedEvalQuery]
headersCache *lru.Cache[string, rego.PreparedEvalQuery]
}
// NewRegoCompiler creates a new RegoCompiler using the given store.
func NewRegoCompiler(store *store.Store) *RegoCompiler {
policyCache, err := lru.New[string, rego.PreparedEvalQuery](10_000)
if err != nil {
panic(fmt.Errorf("failed to create lru cache for policy rego scripts: %w", err))
}
headersCache, err := lru.New[string, rego.PreparedEvalQuery](1)
if err != nil {
panic(fmt.Errorf("failed to create lru cache for headers rego scripts: %w", err))
}
return &RegoCompiler{
Store: store,
policyCache: policyCache,
headersCache: headersCache,
}
}
// CompileHeadersQuery compiles a headers query.
func (rc *RegoCompiler) CompileHeadersQuery(
ctx context.Context,
script string,
) (rego.PreparedEvalQuery, error) {
if q, ok := rc.headersCache.Get(script); ok {
return q, nil
}
r := rego.New(
rego.Store(rc.Store),
rego.Module("pomerium.headers", opa.HeadersRego),
rego.Query("result = data.pomerium.headers"),
getGoogleCloudServerlessHeadersRegoOption,
variableSubstitutionFunctionRegoOption,
rc.Store.GetDataBrokerRecordOption(),
)
q, err := r.PrepareForEval(ctx)
if err != nil {
return q, err
}
rc.headersCache.Add(script, q)
return q, nil
}
// CompilePolicyQuery compiles a policy query.
func (rc *RegoCompiler) CompilePolicyQuery(
ctx context.Context,
script string,
) (rego.PreparedEvalQuery, error) {
if q, ok := rc.policyCache.Get(script); ok {
return q, nil
}
r := rego.New(
rego.Store(rc.Store),
rego.Module("pomerium.policy", script),
rego.Query("result = data.pomerium.policy"),
getGoogleCloudServerlessHeadersRegoOption,
rc.Store.GetDataBrokerRecordOption(),
)
q, err := r.PrepareForEval(ctx)
// if no package is in the src, add it
if err != nil && strings.Contains(err.Error(), "package expected") {
r := rego.New(
rego.Store(rc.Store),
rego.Module("pomerium.policy", "package pomerium.policy\n\n"+script),
rego.Query("result = data.pomerium.policy"),
getGoogleCloudServerlessHeadersRegoOption,
rc.Store.GetDataBrokerRecordOption(),
)
q, err = r.PrepareForEval(ctx)
}
if err != nil {
return q, err
}
rc.policyCache.Add(script, q)
return q, nil
}

View file

@ -7,7 +7,6 @@ import (
googlegrpc "google.golang.org/grpc" googlegrpc "google.golang.org/grpc"
"github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
@ -28,7 +27,7 @@ type authorizeState struct {
authenticateKeyFetcher hpke.KeyFetcher authenticateKeyFetcher hpke.KeyFetcher
} }
func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*authorizeState, error) { func newAuthorizeStateFromConfig(cfg *config.Config, compiler *evaluator.RegoCompiler) (*authorizeState, error) {
if err := validateOptions(cfg.Options); err != nil { if err := validateOptions(cfg.Options); err != nil {
return nil, fmt.Errorf("authorize: bad options: %w", err) return nil, fmt.Errorf("authorize: bad options: %w", err)
} }
@ -37,7 +36,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*autho
var err error var err error
state.evaluator, err = newPolicyEvaluator(cfg.Options, store) state.evaluator, err = newPolicyEvaluator(cfg.Options, compiler)
if err != nil { if err != nil {
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err) return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
} }