Optimize evaluator

This optimizes the Evaluator in the Authorize service to scale to very
large numbers of routes. Additional caching was also added when building
rego policy query evaluators in parallel to allow sharing work and to
avoid building evaluators for scripts with the same contents.
This commit is contained in:
Joe Kralicky 2024-07-22 16:39:26 -04:00
parent 526e2a58d6
commit a396c2eab3
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
16 changed files with 1539 additions and 483 deletions

View file

@ -4,6 +4,7 @@ package authorize
import (
"context"
"errors"
"fmt"
"slices"
"sync"
@ -40,26 +41,29 @@ type Authorize struct {
}
// New validates and creates a new Authorize service from a set of config options.
func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
func New() *Authorize {
a := &Authorize{
currentOptions: config.NewAtomicOptions(),
state: atomicutil.NewValue[*authorizeState](nil),
store: store.New(),
globalCache: storage.NewGlobalCache(time.Minute),
}
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
state, err := newAuthorizeStateFromConfig(ctx, cfg, a.store, nil)
if err != nil {
return nil, err
return a
}
a.state = atomicutil.NewValue(state)
return a, nil
func (a *Authorize) HasValidState() bool {
return a.state.Load() != nil
}
// GetDataBrokerServiceClient returns the current DataBrokerServiceClient.
func (a *Authorize) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return a.state.Load().dataBrokerClient
state := a.state.Load()
if state == nil {
return nil
}
return state.dataBrokerClient
}
// Run runs the authorize service.
@ -70,7 +74,11 @@ func (a *Authorize) Run(ctx context.Context) error {
return nil
})
eg.Go(func() error {
_ = grpc.WaitForReady(ctx, a.state.Load().dataBrokerClientConnection, time.Second*10)
state := a.state.Load()
if state == nil {
return errors.New("authorize: invalid configuration")
}
_ = grpc.WaitForReady(ctx, state.dataBrokerClientConnection, time.Second*10)
return nil
})
return eg.Wait()
@ -151,7 +159,11 @@ func newPolicyEvaluator(
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
currentState := a.state.Load()
a.currentOptions.Store(cfg.Options)
if state, err := newAuthorizeStateFromConfig(ctx, cfg, a.store, currentState.evaluator); err != nil {
var prev *evaluator.Evaluator
if currentState != nil {
prev = currentState.evaluator
}
if state, err := newAuthorizeStateFromConfig(ctx, cfg, a.store, prev); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
} else {
a.state.Store(state)

View file

@ -82,7 +82,7 @@ func TestNew(t *testing.T) {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, err := New(context.Background(), &config.Config{Options: &tt.config})
_, err := newAuthorizeStateFromConfig(context.Background(), &config.Config{Options: &tt.config}, store.New(), nil)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
@ -114,12 +114,12 @@ func TestAuthorize_OnConfigChange(t *testing.T) {
SharedKey: tc.SharedKey,
Policies: tc.Policies,
}
a, err := New(context.Background(), &config.Config{Options: o})
require.NoError(t, err)
require.NotNil(t, a)
cfg := &config.Config{Options: o}
a := New()
a.OnConfigChange(context.Background(), cfg)
require.True(t, a.HasValidState())
oldPe := a.state.Load().evaluator
cfg := &config.Config{Options: o}
assertFunc := assert.True
o.SigningKey = "bad-share-key"
if tc.expectedChange {

View file

@ -34,8 +34,10 @@ func TestAuthorize_handleResult(t *testing.T) {
t.Cleanup(authnSrv.Close)
opt.AuthenticateURLString = authnSrv.URL
a, err := New(context.Background(), &config.Config{Options: opt})
require.NoError(t, err)
cfg := &config.Config{Options: opt}
a := New()
a.OnConfigChange(context.Background(), cfg)
require.True(t, a.HasValidState())
t.Run("user-unauthenticated", func(t *testing.T) {
res, err := a.handleResult(context.Background(),
@ -327,8 +329,10 @@ func TestRequireLogin(t *testing.T) {
t.Cleanup(authnSrv.Close)
opt.AuthenticateURLString = authnSrv.URL
a, err := New(context.Background(), &config.Config{Options: opt})
require.NoError(t, err)
cfg := &config.Config{Options: opt}
a := New()
a.OnConfigChange(context.Background(), cfg)
require.True(t, a.HasValidState())
t.Run("accept empty", func(t *testing.T) {
res, err := a.requireLoginResponse(context.Background(),

View file

@ -65,12 +65,13 @@ func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) {
t.Cleanup(clearTimeout)
opt := config.NewDefaultOptions()
a, err := New(context.Background(), &config.Config{Options: opt})
require.NoError(t, err)
a := New()
a.OnConfigChange(context.Background(), &config.Config{Options: opt})
require.True(t, a.HasValidState())
s1 := &session.Session{Id: "s1", ExpiresAt: timestamppb.New(time.Now().Add(-time.Second))}
sq := storage.NewStaticQuerier(s1)
qctx := storage.WithQuerier(ctx, sq)
_, err = a.getDataBrokerSessionOrServiceAccount(qctx, "s1", 0)
_, err := a.getDataBrokerSessionOrServiceAccount(qctx, "s1", 0)
assert.ErrorIs(t, err, session.ErrSessionExpired)
}

View file

@ -1,6 +1,8 @@
package evaluator
import (
"sync"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/hashutil"
)
@ -15,22 +17,26 @@ type evaluatorConfig struct {
AuthenticateURL string
GoogleCloudServerlessAuthenticationServiceAccount string
JWTClaimsHeaders config.JWTClaimHeaders
cacheKeyOnce sync.Once
computedCacheKey uint64
}
// cacheKey() returns a hash over the configuration, except for the policies.
func (e *evaluatorConfig) cacheKey() uint64 {
return hashutil.MustHash(e)
e.cacheKeyOnce.Do(func() {
e.computedCacheKey = hashutil.MustHash(e)
})
return e.computedCacheKey
}
// An Option customizes the evaluator config.
type Option func(*evaluatorConfig)
func getConfig(options ...Option) *evaluatorConfig {
cfg := new(evaluatorConfig)
for _, o := range options {
o(cfg)
func (e *evaluatorConfig) apply(options ...Option) {
for _, opt := range options {
opt(e)
}
return cfg
}
// WithPolicies sets the policies in the config.

View file

@ -2,21 +2,29 @@
package evaluator
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"maps"
"math/bits"
"net/http"
"net/url"
"runtime"
rttrace "runtime/trace"
"sync"
"sync/atomic"
"time"
"unsafe"
"github.com/go-jose/go-jose/v3"
"github.com/hashicorp/go-set/v3"
"github.com/open-policy-agent/opa/rego"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/singleflight"
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/errgrouputil"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/trace"
@ -90,51 +98,190 @@ type Result struct {
Traces []contextutil.PolicyEvaluationTrace
}
type PolicyEvaluatorCacheStats struct {
CacheHits int64
CacheMisses int64
}
type PolicyEvaluatorCache struct {
evalsMu sync.RWMutex
evaluatorsByRouteID map[uint64]*PolicyEvaluator
cacheHits atomic.Int64
cacheMisses atomic.Int64
}
type QueryCacheStats struct {
CacheHits int64
CacheMisses int64
BuildsSucceeded int64
BuildsFailed int64
BuildsShared int64
}
type QueryCache struct {
queriesMu sync.RWMutex
queriesByScriptChecksum map[string]rego.PreparedEvalQuery
sf singleflight.Group
cacheHits atomic.Int64
cacheMisses atomic.Int64
buildsSucceeded atomic.Int64
buildsFailed atomic.Int64
buildsShared atomic.Int64
}
func NewPolicyEvaluatorCache(initialSize int) *PolicyEvaluatorCache {
return &PolicyEvaluatorCache{
evaluatorsByRouteID: make(map[uint64]*PolicyEvaluator, initialSize),
}
}
func NewQueryCache(initialSize int) *QueryCache {
return &QueryCache{
queriesByScriptChecksum: make(map[string]rego.PreparedEvalQuery, initialSize),
}
}
func (c *PolicyEvaluatorCache) NumCachedEvaluators() int {
c.evalsMu.RLock()
defer c.evalsMu.RUnlock()
return len(c.evaluatorsByRouteID)
}
func (c *PolicyEvaluatorCache) StoreEvaluator(routeID uint64, eval *PolicyEvaluator) {
c.evalsMu.Lock()
defer c.evalsMu.Unlock()
c.evaluatorsByRouteID[routeID] = eval
}
func (c *PolicyEvaluatorCache) LookupEvaluator(routeID uint64) (*PolicyEvaluator, bool) {
c.evalsMu.RLock()
defer c.evalsMu.RUnlock()
eval, ok := c.evaluatorsByRouteID[routeID]
if ok {
c.cacheHits.Add(1)
} else {
c.cacheMisses.Add(1)
}
return eval, ok
}
func (c *PolicyEvaluatorCache) Stats() PolicyEvaluatorCacheStats {
return PolicyEvaluatorCacheStats{
CacheHits: c.cacheHits.Load(),
CacheMisses: c.cacheMisses.Load(),
}
}
func (c *QueryCache) NumCachedQueries() int {
c.queriesMu.RLock()
defer c.queriesMu.RUnlock()
return len(c.queriesByScriptChecksum)
}
func (c *QueryCache) Stats() QueryCacheStats {
return QueryCacheStats{
CacheHits: c.cacheHits.Load(),
CacheMisses: c.cacheMisses.Load(),
BuildsSucceeded: c.buildsSucceeded.Load(),
BuildsFailed: c.buildsFailed.Load(),
BuildsShared: c.buildsShared.Load(),
}
}
func (c *QueryCache) LookupOrBuild(q *policyQuery, builder func() (rego.PreparedEvalQuery, error)) (rego.PreparedEvalQuery, bool, error) {
checksum := q.checksum()
c.queriesMu.RLock()
cached, ok := c.queriesByScriptChecksum[checksum]
c.queriesMu.RUnlock()
if ok {
c.cacheHits.Add(1)
return cached, true, nil
}
c.cacheMisses.Add(1)
var ours bool
pq, err, shared := c.sf.Do(checksum, func() (any, error) {
ours = true
res, err := builder()
if err == nil {
c.queriesMu.Lock()
c.queriesByScriptChecksum[checksum] = res
c.queriesMu.Unlock()
c.buildsSucceeded.Add(1)
} else {
c.buildsFailed.Add(1)
}
return res, err
})
if err != nil {
return rego.PreparedEvalQuery{}, false, err
}
if shared && !ours {
c.buildsShared.Add(1)
}
return pq.(rego.PreparedEvalQuery), false, nil
}
// An Evaluator evaluates policies.
type Evaluator struct {
opts *evaluatorConfig
store *store.Store
policyEvaluators map[uint64]*PolicyEvaluator
headersEvaluators *HeadersEvaluator
clientCA []byte
clientCRL []byte
clientCertConstraints ClientCertConstraints
cfgCacheKey uint64
evalCache *PolicyEvaluatorCache
queryCache *QueryCache
headersEvaluator *HeadersEvaluator
}
// New creates a new Evaluator.
func New(
ctx context.Context, store *store.Store, previous *Evaluator, options ...Option,
ctx context.Context,
store *store.Store,
previous *Evaluator,
options ...Option,
) (*Evaluator, error) {
cfg := getConfig(options...)
ctx, task := rttrace.NewTask(ctx, "evaluator.New")
defer task.End()
defer rttrace.StartRegion(ctx, "evaluator.New").End()
err := updateStore(ctx, store, cfg)
var opts evaluatorConfig
opts.apply(options...)
e := &Evaluator{
opts: &opts,
store: store,
}
if previous == nil || opts.cacheKey() != previous.opts.cacheKey() || store != previous.store {
var err error
rttrace.WithRegion(ctx, "update store", func() {
err = updateStore(ctx, store, &opts, previous)
})
if err != nil {
return nil, err
}
e := &Evaluator{
store: store,
clientCA: cfg.ClientCA,
clientCRL: cfg.ClientCRL,
clientCertConstraints: cfg.ClientCertConstraints,
cfgCacheKey: cfg.cacheKey(),
rttrace.WithRegion(ctx, "create headers evaluator", func() {
e.headersEvaluator, err = NewHeadersEvaluator(ctx, store)
})
if err != nil {
return nil, err
}
e.evalCache = NewPolicyEvaluatorCache(len(opts.Policies))
e.queryCache = NewQueryCache(len(opts.Policies))
} else {
// If there is a previous Evaluator constructed from the same settings, we
// can reuse the HeadersEvaluator along with any PolicyEvaluators for
// unchanged policies.
var cachedPolicyEvaluators map[uint64]*PolicyEvaluator
if previous != nil && previous.cfgCacheKey == e.cfgCacheKey {
e.headersEvaluators = previous.headersEvaluators
cachedPolicyEvaluators = previous.policyEvaluators
} else {
e.headersEvaluators, err = NewHeadersEvaluator(ctx, store)
if err != nil {
return nil, err
e.headersEvaluator = previous.headersEvaluator
e.evalCache = previous.evalCache
e.queryCache = previous.queryCache
}
}
e.policyEvaluators, err = getOrCreatePolicyEvaluators(ctx, cfg, store, cachedPolicyEvaluators)
var err error
rttrace.WithRegion(ctx, "update policy evaluators", func() {
err = getOrCreatePolicyEvaluators(ctx, &opts, store, e.evalCache, e.queryCache)
})
if err != nil {
return nil, err
}
@ -142,62 +289,321 @@ func New(
return e, nil
}
var (
workerPoolSize = runtime.NumCPU() - 1
workerPoolMu sync.Mutex
workerPoolTaskQueue = make(chan func(), (workerPoolSize+1)*2)
)
func init() {
for i := 0; i < workerPoolSize; i++ {
// the worker function is separate so that it shows up in stack traces as
// 'worker' instead of an anonymous function in init()
go worker()
}
}
func worker() {
workerPoolMu.Lock()
queue := workerPoolTaskQueue
workerPoolMu.Unlock()
for fn := range queue {
fn()
}
}
type chunkSizes interface {
uint8 | uint16 | uint32 | uint64
}
// returns the size in bits for any allowed type in chunkSizes
func chunkSize[T chunkSizes]() int {
return int(unsafe.Sizeof(T(0))) * 8
}
type workerContext[T chunkSizes] struct {
context.Context
Cfg *evaluatorConfig
Store *store.Store
StatusBits []T
Evaluators []routeEvaluator
EvalCache *PolicyEvaluatorCache
QueryCache *QueryCache
Errs *sync.Map
}
type routeEvaluator struct {
id uint64
evaluator *PolicyEvaluator
ID uint64 // route id
Evaluator *PolicyEvaluator // the compiled evaluator
ComputedChecksum uint64 // cached evaluator checksum
}
// partition represents a range of chunks (fixed-size blocks of policies)
// corresponding to the Cfg.Policies, StatusBits, and Evaluators fields in
// the workerContext. It is the slice [Begin*chunkSize:End*chunkSize] w.r.t.
// those fields, but each index represents a unit of work that can be done in
// parallel with work on other chunks.
type partition struct{ Begin, End int }
// computeChecksums is a worker task that computes policy checksums and updates
// StatusBits to flag policies that need to be rebuilt. It operates on entire
// chunks (fixed-size blocks of policies), given by the start and end indexes
// in the partition argument, and updates the corresponding indexes of the
// StatusBits field of the worker context for those chunks.
func computeChecksums[T chunkSizes](wctx *workerContext[T], part partition) {
defer rttrace.StartRegion(wctx, "worker-checksum").End()
for chunkIdx := part.Begin; chunkIdx < part.End; chunkIdx++ {
var chunkStatus T
chunkSize := chunkSize[T]()
off := chunkIdx * chunkSize // chunk offset
// If there are fewer than chunkSize policies remaining in the actual list,
// don't go beyond the end
limit := min(chunkSize, len(wctx.Cfg.Policies)-off)
popcount := 0
for i := range limit {
p := wctx.Cfg.Policies[off+i]
// Compute the route id; this value is reused later as the route name
// when computing the checksum
id, err := p.RouteID()
if err != nil {
wctx.Errs.Store(off+i, fmt.Errorf("authorize: error computing policy route id: %w", err))
continue
}
eval := &wctx.Evaluators[off+i]
eval.ID = id
// Compute the policy checksum and cache it in the evaluator, reusing
// the route ID from before (to avoid needing to compute it again)
eval.ComputedChecksum = p.Checksum()
// eval.ComputedChecksum = p.ChecksumWithID(id) // TODO: update this when merged
// Check if there is an existing evaluator cached for the route ID
// NB: the route ID is composed of a subset of fields of the Policy; this
// means the cache will hit if the route ID fields are the same, even if
// other fields in the policy differ.
cached, ok := wctx.EvalCache.LookupEvaluator(id)
if !ok {
rttrace.Logf(wctx, "", "policy for route ID %d not found in cache", id)
chunkStatus |= T(1 << i)
popcount++
} else if cached.policyChecksum != eval.ComputedChecksum {
// Route ID is the same, but the full checksum differs
rttrace.Logf(wctx, "", "policy for route ID %d changed", id)
chunkStatus |= T(1 << i)
popcount++
}
// On a cache hit, chunkStatus for the ith bit stays at 0
}
// Set chunkStatus bitmask all at once (for better locality)
wctx.StatusBits[chunkIdx] = chunkStatus
rttrace.Logf(wctx, "", "chunk %d: %d/%d changed", chunkIdx, popcount, limit)
}
}
// buildEvaluators is a worker task that creates new policy evaluators. It
// operates on entire chunks (fixed-size blocks of policies), given by the start
// and end indexes in the partition argument, and updates the corresponding
// indexes of the Evaluators field of the worker context for those chunks.
func buildEvaluators[T chunkSizes](wctx *workerContext[T], part partition) {
chunkSize := chunkSize[T]()
defer rttrace.StartRegion(wctx, "worker-build").End()
addDefaultCert := wctx.Cfg.AddDefaultClientCertificateRule
var err error
for chunkIdx := part.Begin; chunkIdx < part.End; chunkIdx++ {
// Obtain the bitmask computed by computeChecksums for this chunk
stat := wctx.StatusBits[chunkIdx]
rttrace.Logf(wctx, "", "chunk %d: status: %0*b", chunkIdx, chunkSize, stat)
// Iterate over all the set bits in stat. This works by finding the
// lowest set bit, zeroing it, and repeating. The go compiler will
// replace [bits.TrailingZeros64] with intrinsics on most platforms.
for stat != 0 {
bit := bits.TrailingZeros64(uint64(stat)) // find the lowest set bit
stat &= (stat - 1) // clear the lowest set bit
idx := (chunkSize * chunkIdx) + bit
p := wctx.Cfg.Policies[idx]
eval := &wctx.Evaluators[idx]
eval.Evaluator, err = NewPolicyEvaluator(wctx, wctx.Store, p, eval.ComputedChecksum, addDefaultCert, wctx.QueryCache)
if err != nil {
wctx.Errs.Store(idx, err)
}
}
}
}
// bestChunkSize determines the chunk size (8, 16, 32, or 64) to use for the
// given number of policies and workers.
func bestChunkSize(numPolicies, numWorkers int) int {
// use the chunk size that results in the largest number of chunks without
// going past the number of workers. this results in the following behavior:
// - as the number of policies increases, chunk size tends to increase
// - as the number of workers increases, chunk size tends to decrease
sizes := []int{64, 32, 16, 8}
sizeIdx := 0
for i, size := range sizes {
if float64(numPolicies)/float64(size) > float64(numWorkers) {
break
}
sizeIdx = i
}
return sizes[sizeIdx]
}
func getOrCreatePolicyEvaluators(
ctx context.Context, cfg *evaluatorConfig, store *store.Store,
cachedPolicyEvaluators map[uint64]*PolicyEvaluator,
) (map[uint64]*PolicyEvaluator, error) {
ctx context.Context,
cfg *evaluatorConfig,
store *store.Store,
evalCache *PolicyEvaluatorCache,
queryCache *QueryCache,
) error {
chunkSize := bestChunkSize(len(cfg.Policies), workerPoolSize)
switch chunkSize {
case 8:
return getOrCreatePolicyEvaluatorsT[uint8](ctx, cfg, store, evalCache, queryCache)
case 16:
return getOrCreatePolicyEvaluatorsT[uint16](ctx, cfg, store, evalCache, queryCache)
case 32:
return getOrCreatePolicyEvaluatorsT[uint32](ctx, cfg, store, evalCache, queryCache)
case 64:
return getOrCreatePolicyEvaluatorsT[uint64](ctx, cfg, store, evalCache, queryCache)
}
panic("unreachable")
}
func getOrCreatePolicyEvaluatorsT[T chunkSizes](
ctx context.Context,
cfg *evaluatorConfig,
store *store.Store,
evalCache *PolicyEvaluatorCache,
queryCache *QueryCache,
) error {
chunkSize := bestChunkSize(len(cfg.Policies), workerPoolSize)
rttrace.Logf(ctx, "", "eval cache size: %d; query cache size: %d; chunk size: %d",
evalCache.NumCachedEvaluators(), queryCache.NumCachedQueries(), chunkSize)
now := time.Now()
var reusedCount int
m := make(map[uint64]*PolicyEvaluator)
var builders []errgrouputil.BuilderFunc[routeEvaluator]
for i := range cfg.Policies {
configPolicy := cfg.Policies[i]
id, err := configPolicy.RouteID()
if err != nil {
return nil, fmt.Errorf("authorize: error computing policy route id: %w", err)
// Split the policy list into chunks which can individually be operated on in
// parallel with other chunks. Each chunk has a corresponding bitmask in the
// statusBits list which is used to indicate to workers which policy
// evaluators (at indexes in the chunk corresponding to set bits) need to be
// built, or rebuilt due to changes.
numChunks := len(cfg.Policies) / chunkSize
if len(cfg.Policies)%chunkSize != 0 {
numChunks++
}
p := cachedPolicyEvaluators[id]
if p != nil && p.policyChecksum == configPolicy.Checksum() {
m[id] = p
reusedCount++
statusBits := make([]T, numChunks) // bits map directly to policy indexes
evaluators := make([]routeEvaluator, len(cfg.Policies))
if len(evaluators) == 0 {
return nil // nothing to do
}
// Limit the number of workers to the size of the worker pool; since we are
// manually distributing chunks between workers, we can avoid spawning more
// goroutines than we need, and instead giving each worker additional chunks.
numWorkers := min(workerPoolSize, numChunks)
// Each worker is given a minimum number of chunks, then the remainder are
// spread evenly between workers.
minChunksPerWorker := numChunks / numWorkers
overflow := numChunks % numWorkers // number of workers which get an additional chunk
wctx := &workerContext[T]{
Context: ctx,
Cfg: cfg,
Store: store,
StatusBits: statusBits,
Evaluators: evaluators,
EvalCache: evalCache,
QueryCache: queryCache,
Errs: &sync.Map{}, // policy index->error
}
// First, build a list of partitions (start/end chunk indexes) to send to
// each worker.
partitions := make([]partition, numWorkers)
rttrace.WithRegion(ctx, "partitioning", func() {
chunkIdx := 0
for workerIdx := range numWorkers {
chunkStart := chunkIdx
chunkEnd := chunkStart + minChunksPerWorker
if workerIdx < overflow {
chunkEnd++
}
chunkIdx = chunkEnd
partitions[workerIdx] = partition{chunkStart, chunkEnd}
}
})
// Compute all route checksums in parallel to determine which routes need to
// be rebuilt.
rttrace.WithRegion(ctx, "computing checksums", func() {
var wg sync.WaitGroup
for _, part := range partitions {
wg.Add(1)
workerPoolTaskQueue <- func() {
defer wg.Done()
computeChecksums(wctx, part)
}
}
wg.Wait()
})
hasErrs := false
wctx.Errs.Range(func(key, value any) bool {
log.Ctx(ctx).Error().Int("policy-index", key.(int)).Msg(value.(error).Error())
hasErrs = true
return true
})
if hasErrs {
return fmt.Errorf("authorize: error computing one or more policy route IDs")
}
// After all checksums are computed and status bits populated, build the
// required evaluators.
rttrace.WithRegion(ctx, "building evaluators", func() {
var wg sync.WaitGroup
for _, part := range partitions {
// Adjust the partition to skip over chunks with 0 bits set
for part.Begin < part.End && statusBits[part.Begin] == 0 {
part.Begin++
}
for part.Begin < (part.End)-1 && statusBits[part.End-1] == 0 {
part.End--
}
if part.Begin == part.End {
continue
}
builders = append(builders, func(ctx context.Context) (*routeEvaluator, error) {
evaluator, err := NewPolicyEvaluator(ctx, store, configPolicy, cfg.AddDefaultClientCertificateRule)
if err != nil {
return nil, fmt.Errorf("authorize: error building evaluator for route id=%s: %w", configPolicy.ID, err)
wg.Add(1)
workerPoolTaskQueue <- func() {
defer wg.Done()
buildEvaluators(wctx, part)
}
return &routeEvaluator{
id: id,
evaluator: evaluator,
}, nil
}
wg.Wait()
})
hasErrs = false
wctx.Errs.Range(func(key, value any) bool {
log.Ctx(ctx).Error().Int("policy-index", key.(int)).Msg(value.(error).Error())
hasErrs = true
return true
})
if hasErrs {
return fmt.Errorf("authorize: error building policy evaluators")
}
evals, errs := errgrouputil.Build(ctx, builders...)
if len(errs) > 0 {
for _, err := range errs {
log.Ctx(ctx).Error().Msg(err.Error())
// Store updated evaluators in the cache
updatedCount := 0
for _, p := range evaluators {
if p.Evaluator != nil { // these are only set when modified
updatedCount++
evalCache.StoreEvaluator(p.ID, p.Evaluator)
}
return nil, fmt.Errorf("authorize: error building policy evaluators")
}
for _, p := range evals {
m[p.id] = p.evaluator
}
log.Ctx(ctx).Debug().
Dur("duration", time.Since(now)).
Int("reused-policies", reusedCount).
Int("created-policies", len(cfg.Policies)-reusedCount).
Int("reused-policies", len(cfg.Policies)-updatedCount).
Int("created-policies", updatedCount).
Msg("updated policy evaluators")
return m, nil
return nil
}
// Evaluate evaluates the rego for the given policy and generates the identity headers.
@ -272,7 +678,7 @@ func (e *Evaluator) evaluatePolicy(ctx context.Context, req *Request) (*PolicyRe
return nil, fmt.Errorf("authorize: error computing policy route id: %w", err)
}
policyEvaluator, ok := e.policyEvaluators[id]
policyEvaluator, ok := e.evalCache.LookupEvaluator(id)
if !ok {
return &PolicyResponse{
Deny: NewRuleResult(true, criteria.ReasonRouteNotFound),
@ -285,7 +691,7 @@ func (e *Evaluator) evaluatePolicy(ctx context.Context, req *Request) (*PolicyRe
}
isValidClientCertificate, err := isValidClientCertificate(
clientCA, string(e.clientCRL), req.HTTP.ClientCertificate, e.clientCertConstraints)
clientCA, string(e.opts.ClientCRL), req.HTTP.ClientCertificate, e.opts.ClientCertConstraints)
if err != nil {
return nil, fmt.Errorf("authorize: error validating client certificate: %w", err)
}
@ -303,7 +709,7 @@ func (e *Evaluator) evaluateHeaders(ctx context.Context, req *Request) (*Headers
return nil, err
}
headersReq.Session = req.Session
res, err := e.headersEvaluators.Evaluate(ctx, headersReq)
res, err := e.headersEvaluator.Evaluate(ctx, headersReq)
if err != nil {
return nil, err
}
@ -322,29 +728,36 @@ func (e *Evaluator) getClientCA(policy *config.Policy) (string, error) {
return string(bs), nil
}
return string(e.clientCA), nil
return string(e.opts.ClientCA), nil
}
func updateStore(ctx context.Context, store *store.Store, cfg *evaluatorConfig) error {
jwk, err := getJWK(ctx, cfg)
func updateStore(ctx context.Context, store *store.Store, cfg *evaluatorConfig, previous *Evaluator) error {
if previous == nil || !bytes.Equal(cfg.SigningKey, previous.opts.SigningKey) {
jwk, err := getJWK(ctx, cfg.SigningKey)
if err != nil {
return fmt.Errorf("authorize: couldn't create signer: %w", err)
}
store.UpdateSigningKey(jwk)
}
if previous == nil || cfg.GoogleCloudServerlessAuthenticationServiceAccount != previous.opts.GoogleCloudServerlessAuthenticationServiceAccount {
store.UpdateGoogleCloudServerlessAuthenticationServiceAccount(
cfg.GoogleCloudServerlessAuthenticationServiceAccount,
)
store.UpdateJWTClaimHeaders(cfg.JWTClaimsHeaders)
store.UpdateRoutePolicies(cfg.Policies)
store.UpdateSigningKey(jwk)
}
if previous == nil || !maps.Equal(cfg.JWTClaimsHeaders, previous.opts.JWTClaimsHeaders) {
store.UpdateJWTClaimHeaders(cfg.JWTClaimsHeaders)
}
store.UpdateRoutePolicies(cfg.Policies)
return nil
}
func getJWK(ctx context.Context, cfg *evaluatorConfig) (*jose.JSONWebKey, error) {
func getJWK(ctx context.Context, signingKey []byte) (*jose.JSONWebKey, error) {
var decodedCert []byte
// if we don't have a signing key, generate one
if len(cfg.SigningKey) == 0 {
if len(signingKey) == 0 {
key, err := cryptutil.NewSigningKey()
if err != nil {
return nil, fmt.Errorf("couldn't generate signing key: %w", err)
@ -354,7 +767,7 @@ func getJWK(ctx context.Context, cfg *evaluatorConfig) (*jose.JSONWebKey, error)
return nil, fmt.Errorf("bad signing key: %w", err)
}
} else {
decodedCert = cfg.SigningKey
decodedCert = signingKey
}
jwk, err := cryptutil.PrivateJWKFromBytes(decodedCert)

View file

@ -0,0 +1,49 @@
package evaluator
import (
"maps"
)
func (c *PolicyEvaluatorCache) XClone() *PolicyEvaluatorCache {
c.evalsMu.Lock()
defer c.evalsMu.Unlock()
return &PolicyEvaluatorCache{
evaluatorsByRouteID: maps.Clone(c.evaluatorsByRouteID),
}
}
func (e *Evaluator) XEvaluatorCache() *PolicyEvaluatorCache {
return e.evalCache
}
func (e *Evaluator) XQueryCache() *QueryCache {
return e.queryCache
}
var (
XGetGoogleCloudServerlessTokenSource = getGoogleCloudServerlessTokenSource
XIsValidClientCertificate = isValidClientCertificate
XNormalizeServiceAccount = normalizeServiceAccount
XBestChunkSize = bestChunkSize
XGetUserPrincipalNamesFromSAN = getUserPrincipalNamesFromSAN
)
var OIDUserPrincipalName = oidUserPrincipalName
func XWorkerPoolSize() int {
return workerPoolSize
}
func OverrideWorkerPoolSizeForTesting(newSize int) {
if newSize == workerPoolSize {
return
}
workerPoolMu.Lock()
workerPoolSize = newSize
close(workerPoolTaskQueue) // this will stop existing workers
workerPoolTaskQueue = make(chan func(), (workerPoolSize+1)*2)
workerPoolMu.Unlock()
for i := 0; i < workerPoolSize; i++ {
go worker()
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,4 @@
package evaluator
package evaluator_test
import (
"encoding/asn1"
@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config"
)
@ -167,20 +168,20 @@ dgwikvJkMOfcuexx
func Test_isValidClientCertificate(t *testing.T) {
t.Parallel()
var noConstraints ClientCertConstraints
var noConstraints evaluator.ClientCertConstraints
t.Run("no ca", func(t *testing.T) {
valid, err := isValidClientCertificate(
"", "", ClientCertificateInfo{Leaf: "WHATEVER!"}, noConstraints)
valid, err := evaluator.XIsValidClientCertificate(
"", "", evaluator.ClientCertificateInfo{Leaf: "WHATEVER!"}, noConstraints)
assert.NoError(t, err, "should not return an error")
assert.True(t, valid, "should return true")
})
t.Run("no cert", func(t *testing.T) {
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{}, noConstraints)
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{}, noConstraints)
assert.NoError(t, err, "should not return an error")
assert.False(t, valid, "should return false")
})
t.Run("valid cert", func(t *testing.T) {
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCert,
}, noConstraints)
@ -188,7 +189,7 @@ func Test_isValidClientCertificate(t *testing.T) {
assert.True(t, valid, "should return true")
})
t.Run("valid cert with intermediate", func(t *testing.T) {
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidIntermediateCert,
Intermediates: testIntermediateCA,
@ -197,7 +198,7 @@ func Test_isValidClientCertificate(t *testing.T) {
assert.True(t, valid, "should return true")
})
t.Run("valid cert missing intermediate", func(t *testing.T) {
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidIntermediateCert,
Intermediates: "",
@ -206,7 +207,7 @@ func Test_isValidClientCertificate(t *testing.T) {
assert.False(t, valid, "should return false")
})
t.Run("intermediate CA as root", func(t *testing.T) {
valid, err := isValidClientCertificate(testIntermediateCA, "", ClientCertificateInfo{
valid, err := evaluator.XIsValidClientCertificate(testIntermediateCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidIntermediateCert,
}, noConstraints)
@ -214,7 +215,7 @@ func Test_isValidClientCertificate(t *testing.T) {
assert.True(t, valid, "should return true")
})
t.Run("unsigned cert", func(t *testing.T) {
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testUntrustedCert,
}, noConstraints)
@ -222,7 +223,7 @@ func Test_isValidClientCertificate(t *testing.T) {
assert.False(t, valid, "should return false")
})
t.Run("not a cert", func(t *testing.T) {
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: "WHATEVER!",
}, noConstraints)
@ -230,22 +231,22 @@ func Test_isValidClientCertificate(t *testing.T) {
assert.False(t, valid, "should return false")
})
t.Run("revoked cert", func(t *testing.T) {
revokedCertInfo := ClientCertificateInfo{
revokedCertInfo := evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testRevokedCert,
}
// The "revoked cert" should otherwise be valid (when no CRL is specified).
valid, err := isValidClientCertificate(testCA, "", revokedCertInfo, noConstraints)
valid, err := evaluator.XIsValidClientCertificate(testCA, "", revokedCertInfo, noConstraints)
assert.NoError(t, err, "should not return an error")
assert.True(t, valid, "should return true")
valid, err = isValidClientCertificate(testCA, testCRL, revokedCertInfo, noConstraints)
valid, err = evaluator.XIsValidClientCertificate(testCA, testCRL, revokedCertInfo, noConstraints)
assert.NoError(t, err, "should not return an error")
assert.False(t, valid, "should return false")
// Specifying a CRL containing the revoked cert should not affect other certs.
valid, err = isValidClientCertificate(testCA, testCRL, ClientCertificateInfo{
valid, err = evaluator.XIsValidClientCertificate(testCA, testCRL, evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCert,
}, noConstraints)
@ -253,51 +254,51 @@ func Test_isValidClientCertificate(t *testing.T) {
assert.True(t, valid, "should return true")
})
t.Run("chain too deep", func(t *testing.T) {
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidIntermediateCert,
Intermediates: testIntermediateCA,
}, ClientCertConstraints{MaxVerifyDepth: 1})
}, evaluator.ClientCertConstraints{MaxVerifyDepth: 1})
assert.NoError(t, err, "should not return an error")
assert.False(t, valid, "should return false")
})
t.Run("any SAN", func(t *testing.T) {
matchAnySAN := ClientCertConstraints{SANMatchers: SANMatchers{
matchAnySAN := evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
config.SANTypeDNS: regexp.MustCompile("^.*$"),
config.SANTypeEmail: regexp.MustCompile("^.*$"),
config.SANTypeIPAddress: regexp.MustCompile("^.*$"),
config.SANTypeURI: regexp.MustCompile("^.*$"),
}}
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCert, // no SANs
}, matchAnySAN)
assert.NoError(t, err, "should not return an error")
assert.False(t, valid, "should return false")
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCertWithDNSSANs,
}, matchAnySAN)
assert.NoError(t, err, "should not return an error")
assert.True(t, valid, "should return true")
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCertWithEmailSAN,
}, matchAnySAN)
assert.NoError(t, err, "should not return an error")
assert.True(t, valid, "should return true")
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCertWithIPSAN,
}, matchAnySAN)
assert.NoError(t, err, "should not return an error")
assert.True(t, valid, "should return true")
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCertWithURISAN,
}, matchAnySAN)
@ -305,95 +306,95 @@ func Test_isValidClientCertificate(t *testing.T) {
assert.True(t, valid, "should return true")
})
t.Run("DNS SAN", func(t *testing.T) {
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCertWithDNSSANs,
}, ClientCertConstraints{SANMatchers: SANMatchers{
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
config.SANTypeDNS: regexp.MustCompile(`^a\..*\.example\.com$`),
}})
assert.NoError(t, err, "should not return an error")
assert.True(t, valid, "should return true")
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCertWithDNSSANs,
}, ClientCertConstraints{SANMatchers: SANMatchers{
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
config.SANTypeEmail: regexp.MustCompile(`^a\..*\.example\.com$`), // mismatched type
}})
assert.NoError(t, err, "should not return an error")
assert.False(t, valid, "should return false")
})
t.Run("email SAN", func(t *testing.T) {
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCertWithEmailSAN,
}, ClientCertConstraints{SANMatchers: SANMatchers{
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
config.SANTypeEmail: regexp.MustCompile(`^.*@example\.com$`),
}})
assert.NoError(t, err, "should not return an error")
assert.True(t, valid, "should return true")
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCertWithEmailSAN,
}, ClientCertConstraints{SANMatchers: SANMatchers{
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
config.SANTypeIPAddress: regexp.MustCompile(`^.*@example\.com$`), // mismatched type
}})
assert.NoError(t, err, "should not return an error")
assert.False(t, valid, "should return false")
})
t.Run("IP address SAN", func(t *testing.T) {
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCertWithIPSAN,
}, ClientCertConstraints{SANMatchers: SANMatchers{
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
config.SANTypeIPAddress: regexp.MustCompile(`^192\.168\.10\..*$`),
}})
assert.NoError(t, err, "should not return an error")
assert.True(t, valid, "should return true")
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCertWithIPSAN,
}, ClientCertConstraints{SANMatchers: SANMatchers{
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
config.SANTypeURI: regexp.MustCompile(`^192\.168\.10\..*$`), // mismatched type
}})
assert.NoError(t, err, "should not return an error")
assert.False(t, valid, "should return false")
})
t.Run("URI SAN", func(t *testing.T) {
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCertWithURISAN,
}, ClientCertConstraints{SANMatchers: SANMatchers{
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
config.SANTypeURI: regexp.MustCompile(`^spiffe://example\.com/.*$`),
}})
assert.NoError(t, err, "should not return an error")
assert.True(t, valid, "should return true")
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCertWithURISAN,
}, ClientCertConstraints{SANMatchers: SANMatchers{
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
config.SANTypeDNS: regexp.MustCompile(`^spiffe://example\.com/.*$`), // mismatched type
}})
assert.NoError(t, err, "should not return an error")
assert.False(t, valid, "should return false")
})
t.Run("UserPrincipalName SAN", func(t *testing.T) {
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCertWithUPNSAN,
}, ClientCertConstraints{SANMatchers: SANMatchers{
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
config.SANTypeUserPrincipalName: regexp.MustCompile(`^test_device$`),
}})
assert.NoError(t, err, "should not return an error")
assert.True(t, valid, "should return true")
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
Presented: true,
Leaf: testValidCertWithURISAN,
}, ClientCertConstraints{SANMatchers: SANMatchers{
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
config.SANTypeDNS: regexp.MustCompile(`^test-device$`), // mismatched type
}})
assert.NoError(t, err, "should not return an error")
@ -406,26 +407,26 @@ func TestClientCertConstraintsFromConfig(t *testing.T) {
t.Run("default constraints", func(t *testing.T) {
var s config.DownstreamMTLSSettings
c, err := ClientCertConstraintsFromConfig(&s)
c, err := evaluator.ClientCertConstraintsFromConfig(&s)
assert.NoError(t, err)
assert.Equal(t, &ClientCertConstraints{MaxVerifyDepth: 1}, c)
assert.Equal(t, &evaluator.ClientCertConstraints{MaxVerifyDepth: 1}, c)
})
t.Run("no constraints", func(t *testing.T) {
s := config.DownstreamMTLSSettings{
MaxVerifyDepth: new(uint32),
}
c, err := ClientCertConstraintsFromConfig(&s)
c, err := evaluator.ClientCertConstraintsFromConfig(&s)
assert.NoError(t, err)
assert.Equal(t, &ClientCertConstraints{}, c)
assert.Equal(t, &evaluator.ClientCertConstraints{}, c)
})
t.Run("larger max depth", func(t *testing.T) {
depth := uint32(5)
s := config.DownstreamMTLSSettings{
MaxVerifyDepth: &depth,
}
c, err := ClientCertConstraintsFromConfig(&s)
c, err := evaluator.ClientCertConstraintsFromConfig(&s)
assert.NoError(t, err)
assert.Equal(t, &ClientCertConstraints{MaxVerifyDepth: 5}, c)
assert.Equal(t, &evaluator.ClientCertConstraints{MaxVerifyDepth: 5}, c)
})
t.Run("one SAN match", func(t *testing.T) {
s := config.DownstreamMTLSSettings{
@ -433,11 +434,11 @@ func TestClientCertConstraintsFromConfig(t *testing.T) {
{Type: config.SANTypeDNS, Pattern: `.*\.corp\.example\.com`},
},
}
c, err := ClientCertConstraintsFromConfig(&s)
c, err := evaluator.ClientCertConstraintsFromConfig(&s)
assert.NoError(t, err)
assert.Equal(t, &ClientCertConstraints{
assert.Equal(t, &evaluator.ClientCertConstraints{
MaxVerifyDepth: 1,
SANMatchers: SANMatchers{
SANMatchers: evaluator.SANMatchers{
config.SANTypeDNS: regexp.MustCompile(`^(.*\.corp\.example\.com)$`),
},
}, c)
@ -449,11 +450,11 @@ func TestClientCertConstraintsFromConfig(t *testing.T) {
{Type: config.SANTypeDNS, Pattern: `.*\.bar\.example\.com`},
},
}
c, err := ClientCertConstraintsFromConfig(&s)
c, err := evaluator.ClientCertConstraintsFromConfig(&s)
assert.NoError(t, err)
assert.Equal(t, &ClientCertConstraints{
assert.Equal(t, &evaluator.ClientCertConstraints{
MaxVerifyDepth: 1,
SANMatchers: SANMatchers{
SANMatchers: evaluator.SANMatchers{
config.SANTypeDNS: regexp.MustCompile(`^(.*\.foo\.example\.com)|(.*\.bar\.example\.com)$`),
},
}, c)
@ -465,11 +466,11 @@ func TestClientCertConstraintsFromConfig(t *testing.T) {
{Type: config.SANTypeEmail, Pattern: `.*@example\.com`},
},
}
c, err := ClientCertConstraintsFromConfig(&s)
c, err := evaluator.ClientCertConstraintsFromConfig(&s)
assert.NoError(t, err)
assert.Equal(t, &ClientCertConstraints{
assert.Equal(t, &evaluator.ClientCertConstraints{
MaxVerifyDepth: 1,
SANMatchers: SANMatchers{
SANMatchers: evaluator.SANMatchers{
config.SANTypeDNS: regexp.MustCompile(`^(.*\.foo\.example\.com)$`),
config.SANTypeEmail: regexp.MustCompile(`^(.*@example\.com)$`),
},
@ -481,7 +482,7 @@ func TestClientCertConstraintsFromConfig(t *testing.T) {
{Type: config.SANTypeDNS, Pattern: "["},
},
}
_, err := ClientCertConstraintsFromConfig(&s)
_, err := evaluator.ClientCertConstraintsFromConfig(&s)
require.Error(t, err)
})
}
@ -514,7 +515,7 @@ func TestGetUserPrincipalNamesFromSAN(t *testing.T) {
san, err := asn1.Marshal(SAN{upn("hello")})
require.NoError(t, err)
names, err := getUserPrincipalNamesFromSAN(san)
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
assert.NoError(t, err)
assert.Equal(t, []string{"hello"}, names)
})
@ -527,7 +528,7 @@ func TestGetUserPrincipalNamesFromSAN(t *testing.T) {
san, err := asn1.Marshal(SAN{upn("foo"), upn("bar"), upn("baz")})
require.NoError(t, err)
names, err := getUserPrincipalNamesFromSAN(san)
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
assert.NoError(t, err)
assert.Equal(t, []string{"foo", "bar", "baz"}, names)
})
@ -544,7 +545,7 @@ func TestGetUserPrincipalNamesFromSAN(t *testing.T) {
})
require.NoError(t, err)
names, err := getUserPrincipalNamesFromSAN(san)
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
assert.NoError(t, err)
assert.Empty(t, names)
})
@ -571,7 +572,7 @@ func TestGetUserPrincipalNamesFromSAN(t *testing.T) {
})
require.NoError(t, err)
names, err := getUserPrincipalNamesFromSAN(san)
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
assert.NoError(t, err)
assert.Equal(t, []string{"UserPrincipalName"}, names)
})
@ -587,24 +588,24 @@ func TestGetUserPrincipalNamesFromSAN(t *testing.T) {
})
require.NoError(t, err)
names, err := getUserPrincipalNamesFromSAN(san)
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
assert.ErrorContains(t, err, "expected UTF8String")
assert.Empty(t, names)
})
t.Run("EmptySAN", func(t *testing.T) {
names, err := getUserPrincipalNamesFromSAN(nil)
names, err := evaluator.XGetUserPrincipalNamesFromSAN(nil)
assert.ErrorContains(t, err, "error reading GeneralNames sequence")
assert.Empty(t, names)
})
t.Run("TruncatedGeneralName", func(t *testing.T) {
san := []byte{0x30, 0x02, 0x82, 0x05 /* 5 more bytes expected */}
names, err := getUserPrincipalNamesFromSAN(san)
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
assert.ErrorContains(t, err, "error reading GeneralName")
assert.Empty(t, names)
})
t.Run("OtherNameWrongTypeIDType", func(t *testing.T) {
san := []byte{0x30, 0x06, 0xa0, 0x04, 0x02 /* type Integer, not OID */, 0x02, 0x46, 0x01}
names, err := getUserPrincipalNamesFromSAN(san)
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
assert.ErrorContains(t, err, "error reading OtherName type ID")
assert.Empty(t, names)
})
@ -618,13 +619,13 @@ func TestGetUserPrincipalNamesFromSAN(t *testing.T) {
}
san, err := asn1.Marshal(SAN{
UPN: BadOtherName{
TypeID: oidUserPrincipalName,
TypeID: evaluator.OIDUserPrincipalName,
Value: UTF8String{"hello"},
},
})
require.NoError(t, err)
names, err := getUserPrincipalNamesFromSAN(san)
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
assert.ErrorContains(t, err, "error reading UserPrincipalName value")
assert.Empty(t, names)
})

View file

@ -1,4 +1,4 @@
package evaluator
package evaluator_test
import (
"net/http"
@ -7,17 +7,19 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/authorize/evaluator"
)
func withMockGCP(t *testing.T, f func()) {
originalGCPIdentityDocURL := GCPIdentityDocURL
originalGCPIdentityDocURL := evaluator.GCPIdentityDocURL
defer func() {
GCPIdentityDocURL = originalGCPIdentityDocURL
GCPIdentityNow = time.Now
evaluator.GCPIdentityDocURL = originalGCPIdentityDocURL
evaluator.GCPIdentityNow = time.Now
}()
now := time.Date(2020, 1, 1, 1, 0, 0, 0, time.UTC)
GCPIdentityNow = func() time.Time {
evaluator.GCPIdentityNow = func() time.Time {
return now
}
@ -28,13 +30,13 @@ func withMockGCP(t *testing.T, f func()) {
}))
defer srv.Close()
GCPIdentityDocURL = srv.URL
evaluator.GCPIdentityDocURL = srv.URL
f()
}
func TestGCPIdentityTokenSource(t *testing.T) {
withMockGCP(t, func() {
src, err := getGoogleCloudServerlessTokenSource("", "example")
src, err := evaluator.XGetGoogleCloudServerlessTokenSource("", "example")
assert.NoError(t, err)
token, err := src.Token()
@ -62,7 +64,7 @@ func Test_normalizeServiceAccount(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
gotServiceAccount, err := normalizeServiceAccount(tc.serviceAccount)
gotServiceAccount, err := evaluator.XNormalizeServiceAccount(tc.serviceAccount)
assert.True(t, (err != nil) == tc.wantError)
assert.Equal(t, tc.expectedServiceAccount, gotServiceAccount)
})

View file

@ -1,4 +1,4 @@
package evaluator
package evaluator_test
import (
"bytes"
@ -20,6 +20,7 @@ import (
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/cryptutil"
@ -29,7 +30,7 @@ import (
)
func TestNewHeadersRequestFromPolicy(t *testing.T) {
req, _ := NewHeadersRequestFromPolicy(&config.Policy{
req, _ := evaluator.NewHeadersRequestFromPolicy(&config.Policy{
EnableGoogleCloudServerlessAuthentication: true,
From: "https://*.example.com",
To: config.WeightedURLs{
@ -37,18 +38,18 @@ func TestNewHeadersRequestFromPolicy(t *testing.T) {
URL: *mustParseURL("http://to.example.com"),
},
},
}, RequestHTTP{
}, evaluator.RequestHTTP{
Hostname: "from.example.com",
ClientCertificate: ClientCertificateInfo{
ClientCertificate: evaluator.ClientCertificateInfo{
Leaf: "--- FAKE CERTIFICATE ---",
},
})
assert.Equal(t, &HeadersRequest{
assert.Equal(t, &evaluator.HeadersRequest{
EnableGoogleCloudServerlessAuthentication: true,
Issuer: "from.example.com",
Audience: "from.example.com",
ToAudience: "https://to.example.com",
ClientCertificate: ClientCertificateInfo{
ClientCertificate: evaluator.ClientCertificateInfo{
Leaf: "--- FAKE CERTIFICATE ---",
},
}, req)
@ -91,21 +92,21 @@ func TestNewHeadersRequestFromPolicy_IssuerFormat(t *testing.T) {
},
} {
policy.JWTIssuerFormat = tc.format
req, err := NewHeadersRequestFromPolicy(policy, RequestHTTP{
req, err := evaluator.NewHeadersRequestFromPolicy(policy, evaluator.RequestHTTP{
Hostname: "from.example.com",
ClientCertificate: ClientCertificateInfo{
ClientCertificate: evaluator.ClientCertificateInfo{
Leaf: "--- FAKE CERTIFICATE ---",
},
})
if tc.err != "" {
assert.ErrorContains(t, err, tc.err)
} else {
assert.Equal(t, &HeadersRequest{
assert.Equal(t, &evaluator.HeadersRequest{
EnableGoogleCloudServerlessAuthentication: true,
Issuer: tc.expectedIssuer,
Audience: tc.expectedAudience,
ToAudience: "https://to.example.com",
ClientCertificate: ClientCertificateInfo{
ClientCertificate: evaluator.ClientCertificateInfo{
Leaf: "--- FAKE CERTIFICATE ---",
},
}, req)
@ -114,8 +115,8 @@ func TestNewHeadersRequestFromPolicy_IssuerFormat(t *testing.T) {
}
func TestNewHeadersRequestFromPolicy_nil(t *testing.T) {
req, _ := NewHeadersRequestFromPolicy(nil, RequestHTTP{Hostname: "from.example.com"})
assert.Equal(t, &HeadersRequest{
req, _ := evaluator.NewHeadersRequestFromPolicy(nil, evaluator.RequestHTTP{Hostname: "from.example.com"})
assert.Equal(t, &evaluator.HeadersRequest{
Issuer: "from.example.com",
Audience: "from.example.com",
}, req)
@ -138,13 +139,13 @@ func TestHeadersEvaluator(t *testing.T) {
iat := time.Unix(1686870680, 0)
eval := func(t *testing.T, data []proto.Message, input *HeadersRequest) (*HeadersResponse, error) {
eval := func(t *testing.T, data []proto.Message, input *evaluator.HeadersRequest) (*evaluator.HeadersResponse, error) {
ctx := context.Background()
ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...))
store := store.New()
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
store.UpdateSigningKey(privateJWK)
e, err := NewHeadersEvaluator(ctx, store, rego.Time(iat))
e, err := evaluator.NewHeadersEvaluator(ctx, store, rego.Time(iat))
require.NoError(t, err)
return e.Evaluate(ctx, input, rego.EvalTime(iat))
}
@ -159,11 +160,11 @@ func TestHeadersEvaluator(t *testing.T) {
}},
}, IssuedAt: timestamppb.New(iat)},
},
&HeadersRequest{
&evaluator.HeadersRequest{
Issuer: "from.example.com",
Audience: "from.example.com",
ToAudience: "to.example.com",
Session: RequestSession{
Session: evaluator.RequestSession{
ID: "s1",
},
})
@ -215,11 +216,11 @@ func TestHeadersEvaluator(t *testing.T) {
AccessToken: "ACCESS_TOKEN",
}},
},
&HeadersRequest{
&evaluator.HeadersRequest{
Issuer: "from.example.com",
Audience: "from.example.com",
ToAudience: "to.example.com",
Session: RequestSession{ID: "s1"},
Session: evaluator.RequestSession{ID: "s1"},
SetRequestHeaders: map[string]string{
"X-Custom-Header": "CUSTOM_VALUE",
"X-ID-Token": "${pomerium.id_token}",
@ -228,7 +229,7 @@ func TestHeadersEvaluator(t *testing.T) {
"Authorization": "Bearer ${pomerium.jwt}",
"Foo": "escaped $$dollar sign",
},
ClientCertificate: ClientCertificateInfo{Leaf: testValidCert},
ClientCertificate: evaluator.ClientCertificateInfo{Leaf: testValidCert},
})
require.NoError(t, err)
@ -258,11 +259,11 @@ func TestHeadersEvaluator(t *testing.T) {
AccessToken: "ACCESS_TOKEN",
}},
},
&HeadersRequest{
&evaluator.HeadersRequest{
Issuer: "from.example.com",
Audience: "from.example.com",
ToAudience: "to.example.com",
Session: RequestSession{ID: "s1"},
Session: evaluator.RequestSession{ID: "s1"},
SetRequestHeaders: map[string]string{
"X-ID-Token": "${pomerium.id_token}",
},
@ -281,11 +282,11 @@ func TestHeadersEvaluator(t *testing.T) {
AccessToken: "ACCESS_TOKEN",
}},
},
&HeadersRequest{
&evaluator.HeadersRequest{
Issuer: "from.example.com",
Audience: "from.example.com",
ToAudience: "to.example.com",
Session: RequestSession{ID: "s1"},
Session: evaluator.RequestSession{ID: "s1"},
SetRequestHeaders: map[string]string{
"Authorization": "Bearer ${pomerium.id_token}",
},
@ -297,7 +298,7 @@ func TestHeadersEvaluator(t *testing.T) {
t.Run("set_request_headers no client cert", func(t *testing.T) {
output, err := eval(t, nil,
&HeadersRequest{
&evaluator.HeadersRequest{
Issuer: "from.example.com",
Audience: "from.example.com",
ToAudience: "to.example.com",
@ -318,12 +319,12 @@ func TestHeadersEvaluator(t *testing.T) {
&session.Session{Id: "s1", UserId: "u1"},
&user.User{Id: "u1", Email: "u1@example.com"},
},
&HeadersRequest{
&evaluator.HeadersRequest{
Issuer: "from.example.com",
Audience: "from.example.com",
ToAudience: "to.example.com",
KubernetesServiceAccountToken: "TOKEN",
Session: RequestSession{ID: "s1"},
Session: evaluator.RequestSession{ID: "s1"},
})
require.NoError(t, err)
assert.Equal(t, "Bearer TOKEN", output.Headers.Get("Authorization"))

View file

@ -3,11 +3,10 @@ package evaluator
import (
"context"
"fmt"
rttrace "runtime/trace"
"strings"
"github.com/open-policy-agent/opa/rego"
octrace "go.opencensus.io/trace"
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
@ -16,6 +15,7 @@ import (
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/policy"
"github.com/pomerium/pomerium/pkg/policy/criteria"
octrace "go.opencensus.io/trace"
)
// PolicyRequest is the input to policy evaluation.
@ -109,25 +109,36 @@ type PolicyEvaluator struct {
// NewPolicyEvaluator creates a new PolicyEvaluator.
func NewPolicyEvaluator(
ctx context.Context, store *store.Store, configPolicy *config.Policy,
ctx context.Context,
store *store.Store,
configPolicy *config.Policy,
policyChecksum uint64,
addDefaultClientCertificateRule bool,
cache *QueryCache,
) (*PolicyEvaluator, error) {
e := new(PolicyEvaluator)
e.policyChecksum = configPolicy.Checksum()
e.policyChecksum = policyChecksum
var err error
rttrace.WithRegion(ctx, "Generate Rego", func() {
// generate the base rego script for the policy
ppl := configPolicy.ToPPL()
if addDefaultClientCertificateRule {
ppl.AddDefaultClientCertificateRule()
}
base, err := policy.GenerateRegoFromPolicy(ppl)
var base string
base, err = policy.GenerateRegoFromPolicy(ppl)
if err != nil {
return nil, err
return
}
e.queries = []policyQuery{{
script: base,
}}
})
if err != nil {
return nil, err
}
// add any custom rego
for _, sp := range configPolicy.SubPolicies {
@ -145,15 +156,12 @@ func NewPolicyEvaluator(
}
}
// for each script, create a rego and prepare a query.
// for each script, create a rego object and prepare a query.
rttrace.WithRegion(ctx, "Compile Rego", func() {
numCached := 0
for i := range e.queries {
log.Ctx(ctx).
Trace().
Str("script", e.queries[i].script).
Str("from", configPolicy.From).
Interface("to", configPolicy.To).
Msg("authorize: rego script for policy evaluation")
var cached bool
e.queries[i].PreparedEvalQuery, cached, err = cache.LookupOrBuild(&e.queries[i], func() (rego.PreparedEvalQuery, error) {
r := rego.New(
rego.Store(store),
rego.Module("pomerium.policy", e.queries[i].script),
@ -177,10 +185,20 @@ func NewPolicyEvaluator(
q, err = r.PrepareForEval(ctx)
}
if err != nil {
return nil, err
return rego.PreparedEvalQuery{}, err
}
e.queries[i].PreparedEvalQuery = q
return q, nil
})
if err != nil {
return
}
if cached {
numCached++
}
}
})
if err != nil {
return nil, err
}
return e, nil

View file

@ -1,4 +1,4 @@
package evaluator
package evaluator_test
import (
"context"
@ -13,6 +13,7 @@ import (
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/contextutil"
@ -34,13 +35,15 @@ func TestPolicyEvaluator(t *testing.T) {
var addDefaultClientCertificateRule bool
eval := func(t *testing.T, policy *config.Policy, data []proto.Message, input *PolicyRequest) (*PolicyResponse, error) {
eval := func(t *testing.T, policy *config.Policy, data []proto.Message, input *evaluator.PolicyRequest) (*evaluator.PolicyResponse, error) {
ctx := context.Background()
ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...))
store := store.New()
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
store.UpdateSigningKey(privateJWK)
e, err := NewPolicyEvaluator(ctx, store, policy, addDefaultClientCertificateRule)
checksum := policy.Checksum()
queryCache := evaluator.NewQueryCache(1)
e, err := evaluator.NewPolicyEvaluator(ctx, store, policy, checksum, addDefaultClientCertificateRule, queryCache)
require.NoError(t, err)
return e.Evaluate(ctx, input)
}
@ -71,14 +74,14 @@ func TestPolicyEvaluator(t *testing.T) {
output, err := eval(t,
p1,
[]proto.Message{s1, u1, s2, u2},
&PolicyRequest{
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: RequestSession{ID: "s1"},
&evaluator.PolicyRequest{
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: evaluator.RequestSession{ID: "s1"},
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: NewRuleResult(true, criteria.ReasonEmailOK),
Deny: NewRuleResult(false),
assert.Equal(t, &evaluator.PolicyResponse{
Allow: evaluator.NewRuleResult(true, criteria.ReasonEmailOK),
Deny: evaluator.NewRuleResult(false),
Traces: []contextutil.PolicyEvaluationTrace{{Allow: true}},
}, output)
})
@ -86,14 +89,14 @@ func TestPolicyEvaluator(t *testing.T) {
output, err := eval(t,
p1,
[]proto.Message{s1, u1, s2, u2},
&PolicyRequest{
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: RequestSession{ID: "s2"},
&evaluator.PolicyRequest{
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: evaluator.RequestSession{ID: "s2"},
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: NewRuleResult(false, criteria.ReasonEmailUnauthorized, criteria.ReasonUserUnauthorized),
Deny: NewRuleResult(false),
assert.Equal(t, &evaluator.PolicyResponse{
Allow: evaluator.NewRuleResult(false, criteria.ReasonEmailUnauthorized, criteria.ReasonUserUnauthorized),
Deny: evaluator.NewRuleResult(false),
Traces: []contextutil.PolicyEvaluationTrace{{}},
}, output)
})
@ -105,16 +108,16 @@ func TestPolicyEvaluator(t *testing.T) {
output, err := eval(t,
p1,
[]proto.Message{s1, u1, s2, u2},
&PolicyRequest{
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: RequestSession{ID: "s1"},
&evaluator.PolicyRequest{
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: evaluator.RequestSession{ID: "s1"},
IsValidClientCertificate: true,
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: NewRuleResult(true, criteria.ReasonEmailOK),
Deny: NewRuleResult(false, criteria.ReasonValidClientCertificate),
assert.Equal(t, &evaluator.PolicyResponse{
Allow: evaluator.NewRuleResult(true, criteria.ReasonEmailOK),
Deny: evaluator.NewRuleResult(false, criteria.ReasonValidClientCertificate),
Traces: []contextutil.PolicyEvaluationTrace{{Allow: true}},
}, output)
})
@ -122,16 +125,16 @@ func TestPolicyEvaluator(t *testing.T) {
output, err := eval(t,
p1,
[]proto.Message{s1, u1, s2, u2},
&PolicyRequest{
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: RequestSession{ID: "s1"},
&evaluator.PolicyRequest{
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: evaluator.RequestSession{ID: "s1"},
IsValidClientCertificate: false,
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: NewRuleResult(true, criteria.ReasonEmailOK),
Deny: NewRuleResult(true, criteria.ReasonClientCertificateRequired),
assert.Equal(t, &evaluator.PolicyResponse{
Allow: evaluator.NewRuleResult(true, criteria.ReasonEmailOK),
Deny: evaluator.NewRuleResult(true, criteria.ReasonClientCertificateRequired),
Traces: []contextutil.PolicyEvaluationTrace{{Allow: true, Deny: true}},
}, output)
})
@ -139,20 +142,20 @@ func TestPolicyEvaluator(t *testing.T) {
output, err := eval(t,
p1,
[]proto.Message{s1, u1, s2, u2},
&PolicyRequest{
HTTP: RequestHTTP{
&evaluator.PolicyRequest{
HTTP: evaluator.RequestHTTP{
Method: http.MethodGet,
URL: "https://from.example.com/path",
ClientCertificate: ClientCertificateInfo{Presented: true},
ClientCertificate: evaluator.ClientCertificateInfo{Presented: true},
},
Session: RequestSession{ID: "s1"},
Session: evaluator.RequestSession{ID: "s1"},
IsValidClientCertificate: false,
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: NewRuleResult(true, criteria.ReasonEmailOK),
Deny: NewRuleResult(true, criteria.ReasonInvalidClientCertificate),
assert.Equal(t, &evaluator.PolicyResponse{
Allow: evaluator.NewRuleResult(true, criteria.ReasonEmailOK),
Deny: evaluator.NewRuleResult(true, criteria.ReasonInvalidClientCertificate),
Traces: []contextutil.PolicyEvaluationTrace{{Allow: true, Deny: true}},
}, output)
})
@ -160,16 +163,16 @@ func TestPolicyEvaluator(t *testing.T) {
output, err := eval(t,
p1,
[]proto.Message{s1, u1, s2, u2},
&PolicyRequest{
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: RequestSession{ID: "s2"},
&evaluator.PolicyRequest{
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: evaluator.RequestSession{ID: "s2"},
IsValidClientCertificate: true,
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: NewRuleResult(false, criteria.ReasonEmailUnauthorized, criteria.ReasonUserUnauthorized),
Deny: NewRuleResult(false, criteria.ReasonValidClientCertificate),
assert.Equal(t, &evaluator.PolicyResponse{
Allow: evaluator.NewRuleResult(false, criteria.ReasonEmailUnauthorized, criteria.ReasonUserUnauthorized),
Deny: evaluator.NewRuleResult(false, criteria.ReasonValidClientCertificate),
Traces: []contextutil.PolicyEvaluationTrace{{}},
}, output)
})
@ -192,16 +195,16 @@ func TestPolicyEvaluator(t *testing.T) {
output, err := eval(t,
p,
[]proto.Message{s1, u1, s2, u2},
&PolicyRequest{
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: RequestSession{ID: "s1"},
&evaluator.PolicyRequest{
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: evaluator.RequestSession{ID: "s1"},
IsValidClientCertificate: true,
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: NewRuleResult(true, criteria.ReasonAccept),
Deny: NewRuleResult(false, criteria.ReasonValidClientCertificate),
assert.Equal(t, &evaluator.PolicyResponse{
Allow: evaluator.NewRuleResult(true, criteria.ReasonAccept),
Deny: evaluator.NewRuleResult(false, criteria.ReasonValidClientCertificate),
Traces: []contextutil.PolicyEvaluationTrace{{}, {ID: "p1", Allow: true}},
}, output)
})
@ -222,16 +225,16 @@ func TestPolicyEvaluator(t *testing.T) {
output, err := eval(t,
p,
[]proto.Message{s1, u1, s2, u2},
&PolicyRequest{
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: RequestSession{ID: "s1"},
&evaluator.PolicyRequest{
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: evaluator.RequestSession{ID: "s1"},
IsValidClientCertificate: true,
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: NewRuleResult(false),
Deny: NewRuleResult(true, criteria.ReasonAccept),
assert.Equal(t, &evaluator.PolicyResponse{
Allow: evaluator.NewRuleResult(false),
Deny: evaluator.NewRuleResult(true, criteria.ReasonAccept),
Traces: []contextutil.PolicyEvaluationTrace{{}, {ID: "p1", Deny: true}},
}, output)
})
@ -253,16 +256,16 @@ func TestPolicyEvaluator(t *testing.T) {
output, err := eval(t,
p,
[]proto.Message{s1, u1, s2, u2},
&PolicyRequest{
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: RequestSession{ID: "s1"},
&evaluator.PolicyRequest{
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: evaluator.RequestSession{ID: "s1"},
IsValidClientCertificate: false,
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: NewRuleResult(false),
Deny: NewRuleResult(true, criteria.ReasonAccept, criteria.ReasonClientCertificateRequired),
assert.Equal(t, &evaluator.PolicyResponse{
Allow: evaluator.NewRuleResult(false),
Deny: evaluator.NewRuleResult(true, criteria.ReasonAccept, criteria.ReasonClientCertificateRequired),
Traces: []contextutil.PolicyEvaluationTrace{{Deny: true}, {ID: "p1", Deny: true}},
}, output)
})
@ -292,16 +295,16 @@ func TestPolicyEvaluator(t *testing.T) {
output, err := eval(t,
p,
[]proto.Message{s1, u1, s2, u2, r1},
&PolicyRequest{
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: RequestSession{ID: "s1"},
&evaluator.PolicyRequest{
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: evaluator.RequestSession{ID: "s1"},
IsValidClientCertificate: true,
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: NewRuleResult(true),
Deny: NewRuleResult(false, criteria.ReasonValidClientCertificate),
assert.Equal(t, &evaluator.PolicyResponse{
Allow: evaluator.NewRuleResult(true),
Deny: evaluator.NewRuleResult(false, criteria.ReasonValidClientCertificate),
Traces: []contextutil.PolicyEvaluationTrace{{}, {ID: "p1", Allow: true}},
}, output)
})
@ -315,16 +318,16 @@ func TestPolicyEvaluator(t *testing.T) {
UserId: "u1",
},
},
&PolicyRequest{
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: RequestSession{ID: "sa1"},
&evaluator.PolicyRequest{
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: evaluator.RequestSession{ID: "sa1"},
IsValidClientCertificate: true,
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: NewRuleResult(true, criteria.ReasonEmailOK),
Deny: NewRuleResult(false, criteria.ReasonValidClientCertificate),
assert.Equal(t, &evaluator.PolicyResponse{
Allow: evaluator.NewRuleResult(true, criteria.ReasonEmailOK),
Deny: evaluator.NewRuleResult(false, criteria.ReasonValidClientCertificate),
Traces: []contextutil.PolicyEvaluationTrace{{Allow: true}},
}, output)
})
@ -339,16 +342,16 @@ func TestPolicyEvaluator(t *testing.T) {
ExpiresAt: timestamppb.New(time.Now().Add(-time.Second)),
},
},
&PolicyRequest{
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: RequestSession{ID: "sa1"},
&evaluator.PolicyRequest{
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
Session: evaluator.RequestSession{ID: "sa1"},
IsValidClientCertificate: true,
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: NewRuleResult(false, criteria.ReasonUserUnauthenticated),
Deny: NewRuleResult(false, criteria.ReasonValidClientCertificate),
assert.Equal(t, &evaluator.PolicyResponse{
Allow: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),
Deny: evaluator.NewRuleResult(false, criteria.ReasonValidClientCertificate),
Traces: []contextutil.PolicyEvaluationTrace{{Allow: false}},
}, output)
})

View file

@ -17,9 +17,10 @@ import (
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/config/envoyconfig"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
@ -49,15 +50,27 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
-----END CERTIFICATE-----`
func Test_getEvaluatorRequest(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(&config.Options{
Policies: []config.Policy{{
policies := []config.Policy{{
From: "https://example.com",
To: mustParseWeightedURLs(t, "https://foo.bar"),
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
}},
}},
})
}}
policy0RouteID, err := policies[0].RouteID()
require.NoError(t, err)
cfg := &config.Config{
Options: &config.Options{
SharedKey: cryptutil.NewBase64Key(),
CookieSecret: cryptutil.NewBase64Key(),
Policies: policies,
},
}
a := New()
a.OnConfigChange(context.Background(), cfg)
require.True(t, a.HasValidState())
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
&envoy_service_auth_v3.CheckRequest{
@ -76,6 +89,7 @@ func Test_getEvaluatorRequest(t *testing.T) {
Body: "BODY",
},
},
ContextExtensions: envoyconfig.MakeExtAuthzContextExtensions(false, policy0RouteID),
MetadataContext: &envoy_config_core_v3.Metadata{
FilterMetadata: map[string]*structpb.Struct{
"com.pomerium.client-certificate-info": {
@ -117,15 +131,27 @@ func Test_getEvaluatorRequest(t *testing.T) {
}
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(&config.Options{
Policies: []config.Policy{{
policies := []config.Policy{{
From: "https://example.com",
To: mustParseWeightedURLs(t, "https://foo.bar"),
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
}},
}},
})
}}
policy0RouteID, err := policies[0].RouteID()
require.NoError(t, err)
cfg := &config.Config{
Options: &config.Options{
SharedKey: cryptutil.NewBase64Key(),
CookieSecret: cryptutil.NewBase64Key(),
Policies: policies,
},
}
a := New()
a.OnConfigChange(context.Background(), cfg)
require.True(t, a.HasValidState())
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
&envoy_service_auth_v3.CheckRequest{
@ -144,11 +170,12 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
Body: "BODY",
},
},
ContextExtensions: envoyconfig.MakeExtAuthzContextExtensions(false, policy0RouteID),
},
}, nil)
require.NoError(t, err)
expect := &evaluator.Request{
Policy: &a.currentOptions.Load().Policies[0],
Policy: &policies[0],
Session: evaluator.RequestSession{},
HTTP: evaluator.NewRequestHTTP(
http.MethodGet,

2
go.mod
View file

@ -80,6 +80,7 @@ require (
go.uber.org/mock v0.5.0
go.uber.org/zap v1.27.0
golang.org/x/crypto v0.28.0
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa
golang.org/x/net v0.30.0
golang.org/x/oauth2 v0.23.0
golang.org/x/sync v0.8.0
@ -227,7 +228,6 @@ require (
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.28.0 // indirect
go.opentelemetry.io/proto/otlp v1.3.1 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect
golang.org/x/mod v0.20.0 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/tools v0.24.0 // indirect

View file

@ -210,16 +210,13 @@ func setupAuthenticate(ctx context.Context, src config.Source, controlPlane *con
}
func setupAuthorize(ctx context.Context, src config.Source, controlPlane *controlplane.Server) (*authorize.Authorize, error) {
svc, err := authorize.New(ctx, src.GetConfig())
if err != nil {
return nil, fmt.Errorf("error creating authorize service: %w", err)
}
envoy_service_auth_v3.RegisterAuthorizationServer(controlPlane.GRPCServer, svc)
a := authorize.New()
envoy_service_auth_v3.RegisterAuthorizationServer(controlPlane.GRPCServer, a)
log.Ctx(ctx).Info().Msg("enabled authorize service")
src.OnConfigChange(ctx, svc.OnConfigChange)
svc.OnConfigChange(ctx, src.GetConfig())
return svc, nil
src.OnConfigChange(ctx, a.OnConfigChange)
a.OnConfigChange(ctx, src.GetConfig())
return a, nil
}
func setupDataBroker(ctx context.Context,