mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
Optimize evaluator
This optimizes the Evaluator in the Authorize service to scale to very large numbers of routes. Additional caching was also added when building rego policy query evaluators in parallel to allow sharing work and to avoid building evaluators for scripts with the same contents.
This commit is contained in:
parent
526e2a58d6
commit
a396c2eab3
16 changed files with 1539 additions and 483 deletions
|
@ -4,6 +4,7 @@ package authorize
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
|
@ -40,26 +41,29 @@ type Authorize struct {
|
|||
}
|
||||
|
||||
// New validates and creates a new Authorize service from a set of config options.
|
||||
func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
||||
func New() *Authorize {
|
||||
a := &Authorize{
|
||||
currentOptions: config.NewAtomicOptions(),
|
||||
state: atomicutil.NewValue[*authorizeState](nil),
|
||||
store: store.New(),
|
||||
globalCache: storage.NewGlobalCache(time.Minute),
|
||||
}
|
||||
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
|
||||
|
||||
state, err := newAuthorizeStateFromConfig(ctx, cfg, a.store, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return a
|
||||
}
|
||||
a.state = atomicutil.NewValue(state)
|
||||
|
||||
return a, nil
|
||||
func (a *Authorize) HasValidState() bool {
|
||||
return a.state.Load() != nil
|
||||
}
|
||||
|
||||
// GetDataBrokerServiceClient returns the current DataBrokerServiceClient.
|
||||
func (a *Authorize) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
||||
return a.state.Load().dataBrokerClient
|
||||
state := a.state.Load()
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
return state.dataBrokerClient
|
||||
}
|
||||
|
||||
// Run runs the authorize service.
|
||||
|
@ -70,7 +74,11 @@ func (a *Authorize) Run(ctx context.Context) error {
|
|||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
_ = grpc.WaitForReady(ctx, a.state.Load().dataBrokerClientConnection, time.Second*10)
|
||||
state := a.state.Load()
|
||||
if state == nil {
|
||||
return errors.New("authorize: invalid configuration")
|
||||
}
|
||||
_ = grpc.WaitForReady(ctx, state.dataBrokerClientConnection, time.Second*10)
|
||||
return nil
|
||||
})
|
||||
return eg.Wait()
|
||||
|
@ -151,7 +159,11 @@ func newPolicyEvaluator(
|
|||
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||
currentState := a.state.Load()
|
||||
a.currentOptions.Store(cfg.Options)
|
||||
if state, err := newAuthorizeStateFromConfig(ctx, cfg, a.store, currentState.evaluator); err != nil {
|
||||
var prev *evaluator.Evaluator
|
||||
if currentState != nil {
|
||||
prev = currentState.evaluator
|
||||
}
|
||||
if state, err := newAuthorizeStateFromConfig(ctx, cfg, a.store, prev); err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
|
||||
} else {
|
||||
a.state.Store(state)
|
||||
|
|
|
@ -82,7 +82,7 @@ func TestNew(t *testing.T) {
|
|||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := New(context.Background(), &config.Config{Options: &tt.config})
|
||||
_, err := newAuthorizeStateFromConfig(context.Background(), &config.Config{Options: &tt.config}, store.New(), nil)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -114,12 +114,12 @@ func TestAuthorize_OnConfigChange(t *testing.T) {
|
|||
SharedKey: tc.SharedKey,
|
||||
Policies: tc.Policies,
|
||||
}
|
||||
a, err := New(context.Background(), &config.Config{Options: o})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, a)
|
||||
cfg := &config.Config{Options: o}
|
||||
a := New()
|
||||
a.OnConfigChange(context.Background(), cfg)
|
||||
require.True(t, a.HasValidState())
|
||||
|
||||
oldPe := a.state.Load().evaluator
|
||||
cfg := &config.Config{Options: o}
|
||||
assertFunc := assert.True
|
||||
o.SigningKey = "bad-share-key"
|
||||
if tc.expectedChange {
|
||||
|
|
|
@ -34,8 +34,10 @@ func TestAuthorize_handleResult(t *testing.T) {
|
|||
t.Cleanup(authnSrv.Close)
|
||||
opt.AuthenticateURLString = authnSrv.URL
|
||||
|
||||
a, err := New(context.Background(), &config.Config{Options: opt})
|
||||
require.NoError(t, err)
|
||||
cfg := &config.Config{Options: opt}
|
||||
a := New()
|
||||
a.OnConfigChange(context.Background(), cfg)
|
||||
require.True(t, a.HasValidState())
|
||||
|
||||
t.Run("user-unauthenticated", func(t *testing.T) {
|
||||
res, err := a.handleResult(context.Background(),
|
||||
|
@ -327,8 +329,10 @@ func TestRequireLogin(t *testing.T) {
|
|||
t.Cleanup(authnSrv.Close)
|
||||
opt.AuthenticateURLString = authnSrv.URL
|
||||
|
||||
a, err := New(context.Background(), &config.Config{Options: opt})
|
||||
require.NoError(t, err)
|
||||
cfg := &config.Config{Options: opt}
|
||||
a := New()
|
||||
a.OnConfigChange(context.Background(), cfg)
|
||||
require.True(t, a.HasValidState())
|
||||
|
||||
t.Run("accept empty", func(t *testing.T) {
|
||||
res, err := a.requireLoginResponse(context.Background(),
|
||||
|
|
|
@ -65,12 +65,13 @@ func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) {
|
|||
t.Cleanup(clearTimeout)
|
||||
|
||||
opt := config.NewDefaultOptions()
|
||||
a, err := New(context.Background(), &config.Config{Options: opt})
|
||||
require.NoError(t, err)
|
||||
a := New()
|
||||
a.OnConfigChange(context.Background(), &config.Config{Options: opt})
|
||||
require.True(t, a.HasValidState())
|
||||
|
||||
s1 := &session.Session{Id: "s1", ExpiresAt: timestamppb.New(time.Now().Add(-time.Second))}
|
||||
sq := storage.NewStaticQuerier(s1)
|
||||
qctx := storage.WithQuerier(ctx, sq)
|
||||
_, err = a.getDataBrokerSessionOrServiceAccount(qctx, "s1", 0)
|
||||
_, err := a.getDataBrokerSessionOrServiceAccount(qctx, "s1", 0)
|
||||
assert.ErrorIs(t, err, session.ErrSessionExpired)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package evaluator
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/hashutil"
|
||||
)
|
||||
|
@ -15,22 +17,26 @@ type evaluatorConfig struct {
|
|||
AuthenticateURL string
|
||||
GoogleCloudServerlessAuthenticationServiceAccount string
|
||||
JWTClaimsHeaders config.JWTClaimHeaders
|
||||
|
||||
cacheKeyOnce sync.Once
|
||||
computedCacheKey uint64
|
||||
}
|
||||
|
||||
// cacheKey() returns a hash over the configuration, except for the policies.
|
||||
func (e *evaluatorConfig) cacheKey() uint64 {
|
||||
return hashutil.MustHash(e)
|
||||
e.cacheKeyOnce.Do(func() {
|
||||
e.computedCacheKey = hashutil.MustHash(e)
|
||||
})
|
||||
return e.computedCacheKey
|
||||
}
|
||||
|
||||
// An Option customizes the evaluator config.
|
||||
type Option func(*evaluatorConfig)
|
||||
|
||||
func getConfig(options ...Option) *evaluatorConfig {
|
||||
cfg := new(evaluatorConfig)
|
||||
for _, o := range options {
|
||||
o(cfg)
|
||||
func (e *evaluatorConfig) apply(options ...Option) {
|
||||
for _, opt := range options {
|
||||
opt(e)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// WithPolicies sets the policies in the config.
|
||||
|
|
|
@ -2,21 +2,29 @@
|
|||
package evaluator
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"maps"
|
||||
"math/bits"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
rttrace "runtime/trace"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/hashicorp/go-set/v3"
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/errgrouputil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
|
@ -90,51 +98,190 @@ type Result struct {
|
|||
Traces []contextutil.PolicyEvaluationTrace
|
||||
}
|
||||
|
||||
type PolicyEvaluatorCacheStats struct {
|
||||
CacheHits int64
|
||||
CacheMisses int64
|
||||
}
|
||||
|
||||
type PolicyEvaluatorCache struct {
|
||||
evalsMu sync.RWMutex
|
||||
evaluatorsByRouteID map[uint64]*PolicyEvaluator
|
||||
|
||||
cacheHits atomic.Int64
|
||||
cacheMisses atomic.Int64
|
||||
}
|
||||
|
||||
type QueryCacheStats struct {
|
||||
CacheHits int64
|
||||
CacheMisses int64
|
||||
BuildsSucceeded int64
|
||||
BuildsFailed int64
|
||||
BuildsShared int64
|
||||
}
|
||||
|
||||
type QueryCache struct {
|
||||
queriesMu sync.RWMutex
|
||||
queriesByScriptChecksum map[string]rego.PreparedEvalQuery
|
||||
sf singleflight.Group
|
||||
|
||||
cacheHits atomic.Int64
|
||||
cacheMisses atomic.Int64
|
||||
buildsSucceeded atomic.Int64
|
||||
buildsFailed atomic.Int64
|
||||
buildsShared atomic.Int64
|
||||
}
|
||||
|
||||
func NewPolicyEvaluatorCache(initialSize int) *PolicyEvaluatorCache {
|
||||
return &PolicyEvaluatorCache{
|
||||
evaluatorsByRouteID: make(map[uint64]*PolicyEvaluator, initialSize),
|
||||
}
|
||||
}
|
||||
|
||||
func NewQueryCache(initialSize int) *QueryCache {
|
||||
return &QueryCache{
|
||||
queriesByScriptChecksum: make(map[string]rego.PreparedEvalQuery, initialSize),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PolicyEvaluatorCache) NumCachedEvaluators() int {
|
||||
c.evalsMu.RLock()
|
||||
defer c.evalsMu.RUnlock()
|
||||
return len(c.evaluatorsByRouteID)
|
||||
}
|
||||
|
||||
func (c *PolicyEvaluatorCache) StoreEvaluator(routeID uint64, eval *PolicyEvaluator) {
|
||||
c.evalsMu.Lock()
|
||||
defer c.evalsMu.Unlock()
|
||||
c.evaluatorsByRouteID[routeID] = eval
|
||||
}
|
||||
|
||||
func (c *PolicyEvaluatorCache) LookupEvaluator(routeID uint64) (*PolicyEvaluator, bool) {
|
||||
c.evalsMu.RLock()
|
||||
defer c.evalsMu.RUnlock()
|
||||
eval, ok := c.evaluatorsByRouteID[routeID]
|
||||
if ok {
|
||||
c.cacheHits.Add(1)
|
||||
} else {
|
||||
c.cacheMisses.Add(1)
|
||||
}
|
||||
return eval, ok
|
||||
}
|
||||
|
||||
func (c *PolicyEvaluatorCache) Stats() PolicyEvaluatorCacheStats {
|
||||
return PolicyEvaluatorCacheStats{
|
||||
CacheHits: c.cacheHits.Load(),
|
||||
CacheMisses: c.cacheMisses.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *QueryCache) NumCachedQueries() int {
|
||||
c.queriesMu.RLock()
|
||||
defer c.queriesMu.RUnlock()
|
||||
return len(c.queriesByScriptChecksum)
|
||||
}
|
||||
|
||||
func (c *QueryCache) Stats() QueryCacheStats {
|
||||
return QueryCacheStats{
|
||||
CacheHits: c.cacheHits.Load(),
|
||||
CacheMisses: c.cacheMisses.Load(),
|
||||
BuildsSucceeded: c.buildsSucceeded.Load(),
|
||||
BuildsFailed: c.buildsFailed.Load(),
|
||||
BuildsShared: c.buildsShared.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *QueryCache) LookupOrBuild(q *policyQuery, builder func() (rego.PreparedEvalQuery, error)) (rego.PreparedEvalQuery, bool, error) {
|
||||
checksum := q.checksum()
|
||||
c.queriesMu.RLock()
|
||||
cached, ok := c.queriesByScriptChecksum[checksum]
|
||||
c.queriesMu.RUnlock()
|
||||
if ok {
|
||||
c.cacheHits.Add(1)
|
||||
return cached, true, nil
|
||||
}
|
||||
c.cacheMisses.Add(1)
|
||||
var ours bool
|
||||
pq, err, shared := c.sf.Do(checksum, func() (any, error) {
|
||||
ours = true
|
||||
res, err := builder()
|
||||
if err == nil {
|
||||
c.queriesMu.Lock()
|
||||
c.queriesByScriptChecksum[checksum] = res
|
||||
c.queriesMu.Unlock()
|
||||
c.buildsSucceeded.Add(1)
|
||||
} else {
|
||||
c.buildsFailed.Add(1)
|
||||
}
|
||||
return res, err
|
||||
})
|
||||
if err != nil {
|
||||
return rego.PreparedEvalQuery{}, false, err
|
||||
}
|
||||
if shared && !ours {
|
||||
c.buildsShared.Add(1)
|
||||
}
|
||||
return pq.(rego.PreparedEvalQuery), false, nil
|
||||
}
|
||||
|
||||
// An Evaluator evaluates policies.
|
||||
type Evaluator struct {
|
||||
opts *evaluatorConfig
|
||||
store *store.Store
|
||||
policyEvaluators map[uint64]*PolicyEvaluator
|
||||
headersEvaluators *HeadersEvaluator
|
||||
clientCA []byte
|
||||
clientCRL []byte
|
||||
clientCertConstraints ClientCertConstraints
|
||||
|
||||
cfgCacheKey uint64
|
||||
evalCache *PolicyEvaluatorCache
|
||||
queryCache *QueryCache
|
||||
headersEvaluator *HeadersEvaluator
|
||||
}
|
||||
|
||||
// New creates a new Evaluator.
|
||||
func New(
|
||||
ctx context.Context, store *store.Store, previous *Evaluator, options ...Option,
|
||||
ctx context.Context,
|
||||
store *store.Store,
|
||||
previous *Evaluator,
|
||||
options ...Option,
|
||||
) (*Evaluator, error) {
|
||||
cfg := getConfig(options...)
|
||||
ctx, task := rttrace.NewTask(ctx, "evaluator.New")
|
||||
defer task.End()
|
||||
defer rttrace.StartRegion(ctx, "evaluator.New").End()
|
||||
|
||||
err := updateStore(ctx, store, cfg)
|
||||
var opts evaluatorConfig
|
||||
opts.apply(options...)
|
||||
|
||||
e := &Evaluator{
|
||||
opts: &opts,
|
||||
store: store,
|
||||
}
|
||||
|
||||
if previous == nil || opts.cacheKey() != previous.opts.cacheKey() || store != previous.store {
|
||||
var err error
|
||||
rttrace.WithRegion(ctx, "update store", func() {
|
||||
err = updateStore(ctx, store, &opts, previous)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
e := &Evaluator{
|
||||
store: store,
|
||||
clientCA: cfg.ClientCA,
|
||||
clientCRL: cfg.ClientCRL,
|
||||
clientCertConstraints: cfg.ClientCertConstraints,
|
||||
cfgCacheKey: cfg.cacheKey(),
|
||||
rttrace.WithRegion(ctx, "create headers evaluator", func() {
|
||||
e.headersEvaluator, err = NewHeadersEvaluator(ctx, store)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
e.evalCache = NewPolicyEvaluatorCache(len(opts.Policies))
|
||||
e.queryCache = NewQueryCache(len(opts.Policies))
|
||||
} else {
|
||||
// If there is a previous Evaluator constructed from the same settings, we
|
||||
// can reuse the HeadersEvaluator along with any PolicyEvaluators for
|
||||
// unchanged policies.
|
||||
var cachedPolicyEvaluators map[uint64]*PolicyEvaluator
|
||||
if previous != nil && previous.cfgCacheKey == e.cfgCacheKey {
|
||||
e.headersEvaluators = previous.headersEvaluators
|
||||
cachedPolicyEvaluators = previous.policyEvaluators
|
||||
} else {
|
||||
e.headersEvaluators, err = NewHeadersEvaluator(ctx, store)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
e.headersEvaluator = previous.headersEvaluator
|
||||
e.evalCache = previous.evalCache
|
||||
e.queryCache = previous.queryCache
|
||||
}
|
||||
}
|
||||
e.policyEvaluators, err = getOrCreatePolicyEvaluators(ctx, cfg, store, cachedPolicyEvaluators)
|
||||
|
||||
var err error
|
||||
rttrace.WithRegion(ctx, "update policy evaluators", func() {
|
||||
err = getOrCreatePolicyEvaluators(ctx, &opts, store, e.evalCache, e.queryCache)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -142,62 +289,321 @@ func New(
|
|||
return e, nil
|
||||
}
|
||||
|
||||
var (
|
||||
workerPoolSize = runtime.NumCPU() - 1
|
||||
workerPoolMu sync.Mutex
|
||||
workerPoolTaskQueue = make(chan func(), (workerPoolSize+1)*2)
|
||||
)
|
||||
|
||||
func init() {
|
||||
for i := 0; i < workerPoolSize; i++ {
|
||||
// the worker function is separate so that it shows up in stack traces as
|
||||
// 'worker' instead of an anonymous function in init()
|
||||
go worker()
|
||||
}
|
||||
}
|
||||
|
||||
func worker() {
|
||||
workerPoolMu.Lock()
|
||||
queue := workerPoolTaskQueue
|
||||
workerPoolMu.Unlock()
|
||||
for fn := range queue {
|
||||
fn()
|
||||
}
|
||||
}
|
||||
|
||||
type chunkSizes interface {
|
||||
uint8 | uint16 | uint32 | uint64
|
||||
}
|
||||
|
||||
// returns the size in bits for any allowed type in chunkSizes
|
||||
func chunkSize[T chunkSizes]() int {
|
||||
return int(unsafe.Sizeof(T(0))) * 8
|
||||
}
|
||||
|
||||
type workerContext[T chunkSizes] struct {
|
||||
context.Context
|
||||
Cfg *evaluatorConfig
|
||||
Store *store.Store
|
||||
StatusBits []T
|
||||
Evaluators []routeEvaluator
|
||||
EvalCache *PolicyEvaluatorCache
|
||||
QueryCache *QueryCache
|
||||
Errs *sync.Map
|
||||
}
|
||||
|
||||
type routeEvaluator struct {
|
||||
id uint64
|
||||
evaluator *PolicyEvaluator
|
||||
ID uint64 // route id
|
||||
Evaluator *PolicyEvaluator // the compiled evaluator
|
||||
ComputedChecksum uint64 // cached evaluator checksum
|
||||
}
|
||||
|
||||
// partition represents a range of chunks (fixed-size blocks of policies)
|
||||
// corresponding to the Cfg.Policies, StatusBits, and Evaluators fields in
|
||||
// the workerContext. It is the slice [Begin*chunkSize:End*chunkSize] w.r.t.
|
||||
// those fields, but each index represents a unit of work that can be done in
|
||||
// parallel with work on other chunks.
|
||||
type partition struct{ Begin, End int }
|
||||
|
||||
// computeChecksums is a worker task that computes policy checksums and updates
|
||||
// StatusBits to flag policies that need to be rebuilt. It operates on entire
|
||||
// chunks (fixed-size blocks of policies), given by the start and end indexes
|
||||
// in the partition argument, and updates the corresponding indexes of the
|
||||
// StatusBits field of the worker context for those chunks.
|
||||
func computeChecksums[T chunkSizes](wctx *workerContext[T], part partition) {
|
||||
defer rttrace.StartRegion(wctx, "worker-checksum").End()
|
||||
for chunkIdx := part.Begin; chunkIdx < part.End; chunkIdx++ {
|
||||
var chunkStatus T
|
||||
chunkSize := chunkSize[T]()
|
||||
off := chunkIdx * chunkSize // chunk offset
|
||||
// If there are fewer than chunkSize policies remaining in the actual list,
|
||||
// don't go beyond the end
|
||||
limit := min(chunkSize, len(wctx.Cfg.Policies)-off)
|
||||
popcount := 0
|
||||
for i := range limit {
|
||||
p := wctx.Cfg.Policies[off+i]
|
||||
// Compute the route id; this value is reused later as the route name
|
||||
// when computing the checksum
|
||||
id, err := p.RouteID()
|
||||
if err != nil {
|
||||
wctx.Errs.Store(off+i, fmt.Errorf("authorize: error computing policy route id: %w", err))
|
||||
continue
|
||||
}
|
||||
eval := &wctx.Evaluators[off+i]
|
||||
eval.ID = id
|
||||
// Compute the policy checksum and cache it in the evaluator, reusing
|
||||
// the route ID from before (to avoid needing to compute it again)
|
||||
eval.ComputedChecksum = p.Checksum()
|
||||
// eval.ComputedChecksum = p.ChecksumWithID(id) // TODO: update this when merged
|
||||
|
||||
// Check if there is an existing evaluator cached for the route ID
|
||||
// NB: the route ID is composed of a subset of fields of the Policy; this
|
||||
// means the cache will hit if the route ID fields are the same, even if
|
||||
// other fields in the policy differ.
|
||||
cached, ok := wctx.EvalCache.LookupEvaluator(id)
|
||||
if !ok {
|
||||
rttrace.Logf(wctx, "", "policy for route ID %d not found in cache", id)
|
||||
chunkStatus |= T(1 << i)
|
||||
popcount++
|
||||
} else if cached.policyChecksum != eval.ComputedChecksum {
|
||||
// Route ID is the same, but the full checksum differs
|
||||
rttrace.Logf(wctx, "", "policy for route ID %d changed", id)
|
||||
chunkStatus |= T(1 << i)
|
||||
popcount++
|
||||
}
|
||||
// On a cache hit, chunkStatus for the ith bit stays at 0
|
||||
}
|
||||
// Set chunkStatus bitmask all at once (for better locality)
|
||||
wctx.StatusBits[chunkIdx] = chunkStatus
|
||||
rttrace.Logf(wctx, "", "chunk %d: %d/%d changed", chunkIdx, popcount, limit)
|
||||
}
|
||||
}
|
||||
|
||||
// buildEvaluators is a worker task that creates new policy evaluators. It
|
||||
// operates on entire chunks (fixed-size blocks of policies), given by the start
|
||||
// and end indexes in the partition argument, and updates the corresponding
|
||||
// indexes of the Evaluators field of the worker context for those chunks.
|
||||
func buildEvaluators[T chunkSizes](wctx *workerContext[T], part partition) {
|
||||
chunkSize := chunkSize[T]()
|
||||
defer rttrace.StartRegion(wctx, "worker-build").End()
|
||||
addDefaultCert := wctx.Cfg.AddDefaultClientCertificateRule
|
||||
var err error
|
||||
for chunkIdx := part.Begin; chunkIdx < part.End; chunkIdx++ {
|
||||
// Obtain the bitmask computed by computeChecksums for this chunk
|
||||
stat := wctx.StatusBits[chunkIdx]
|
||||
rttrace.Logf(wctx, "", "chunk %d: status: %0*b", chunkIdx, chunkSize, stat)
|
||||
|
||||
// Iterate over all the set bits in stat. This works by finding the
|
||||
// lowest set bit, zeroing it, and repeating. The go compiler will
|
||||
// replace [bits.TrailingZeros64] with intrinsics on most platforms.
|
||||
for stat != 0 {
|
||||
bit := bits.TrailingZeros64(uint64(stat)) // find the lowest set bit
|
||||
stat &= (stat - 1) // clear the lowest set bit
|
||||
idx := (chunkSize * chunkIdx) + bit
|
||||
p := wctx.Cfg.Policies[idx]
|
||||
eval := &wctx.Evaluators[idx]
|
||||
eval.Evaluator, err = NewPolicyEvaluator(wctx, wctx.Store, p, eval.ComputedChecksum, addDefaultCert, wctx.QueryCache)
|
||||
if err != nil {
|
||||
wctx.Errs.Store(idx, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// bestChunkSize determines the chunk size (8, 16, 32, or 64) to use for the
|
||||
// given number of policies and workers.
|
||||
func bestChunkSize(numPolicies, numWorkers int) int {
|
||||
// use the chunk size that results in the largest number of chunks without
|
||||
// going past the number of workers. this results in the following behavior:
|
||||
// - as the number of policies increases, chunk size tends to increase
|
||||
// - as the number of workers increases, chunk size tends to decrease
|
||||
sizes := []int{64, 32, 16, 8}
|
||||
sizeIdx := 0
|
||||
for i, size := range sizes {
|
||||
if float64(numPolicies)/float64(size) > float64(numWorkers) {
|
||||
break
|
||||
}
|
||||
sizeIdx = i
|
||||
}
|
||||
return sizes[sizeIdx]
|
||||
}
|
||||
|
||||
func getOrCreatePolicyEvaluators(
|
||||
ctx context.Context, cfg *evaluatorConfig, store *store.Store,
|
||||
cachedPolicyEvaluators map[uint64]*PolicyEvaluator,
|
||||
) (map[uint64]*PolicyEvaluator, error) {
|
||||
ctx context.Context,
|
||||
cfg *evaluatorConfig,
|
||||
store *store.Store,
|
||||
evalCache *PolicyEvaluatorCache,
|
||||
queryCache *QueryCache,
|
||||
) error {
|
||||
chunkSize := bestChunkSize(len(cfg.Policies), workerPoolSize)
|
||||
switch chunkSize {
|
||||
case 8:
|
||||
return getOrCreatePolicyEvaluatorsT[uint8](ctx, cfg, store, evalCache, queryCache)
|
||||
case 16:
|
||||
return getOrCreatePolicyEvaluatorsT[uint16](ctx, cfg, store, evalCache, queryCache)
|
||||
case 32:
|
||||
return getOrCreatePolicyEvaluatorsT[uint32](ctx, cfg, store, evalCache, queryCache)
|
||||
case 64:
|
||||
return getOrCreatePolicyEvaluatorsT[uint64](ctx, cfg, store, evalCache, queryCache)
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func getOrCreatePolicyEvaluatorsT[T chunkSizes](
|
||||
ctx context.Context,
|
||||
cfg *evaluatorConfig,
|
||||
store *store.Store,
|
||||
evalCache *PolicyEvaluatorCache,
|
||||
queryCache *QueryCache,
|
||||
) error {
|
||||
chunkSize := bestChunkSize(len(cfg.Policies), workerPoolSize)
|
||||
rttrace.Logf(ctx, "", "eval cache size: %d; query cache size: %d; chunk size: %d",
|
||||
evalCache.NumCachedEvaluators(), queryCache.NumCachedQueries(), chunkSize)
|
||||
now := time.Now()
|
||||
|
||||
var reusedCount int
|
||||
m := make(map[uint64]*PolicyEvaluator)
|
||||
var builders []errgrouputil.BuilderFunc[routeEvaluator]
|
||||
for i := range cfg.Policies {
|
||||
configPolicy := cfg.Policies[i]
|
||||
id, err := configPolicy.RouteID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: error computing policy route id: %w", err)
|
||||
// Split the policy list into chunks which can individually be operated on in
|
||||
// parallel with other chunks. Each chunk has a corresponding bitmask in the
|
||||
// statusBits list which is used to indicate to workers which policy
|
||||
// evaluators (at indexes in the chunk corresponding to set bits) need to be
|
||||
// built, or rebuilt due to changes.
|
||||
numChunks := len(cfg.Policies) / chunkSize
|
||||
if len(cfg.Policies)%chunkSize != 0 {
|
||||
numChunks++
|
||||
}
|
||||
p := cachedPolicyEvaluators[id]
|
||||
if p != nil && p.policyChecksum == configPolicy.Checksum() {
|
||||
m[id] = p
|
||||
reusedCount++
|
||||
statusBits := make([]T, numChunks) // bits map directly to policy indexes
|
||||
evaluators := make([]routeEvaluator, len(cfg.Policies))
|
||||
if len(evaluators) == 0 {
|
||||
return nil // nothing to do
|
||||
}
|
||||
// Limit the number of workers to the size of the worker pool; since we are
|
||||
// manually distributing chunks between workers, we can avoid spawning more
|
||||
// goroutines than we need, and instead giving each worker additional chunks.
|
||||
numWorkers := min(workerPoolSize, numChunks)
|
||||
// Each worker is given a minimum number of chunks, then the remainder are
|
||||
// spread evenly between workers.
|
||||
minChunksPerWorker := numChunks / numWorkers
|
||||
overflow := numChunks % numWorkers // number of workers which get an additional chunk
|
||||
|
||||
wctx := &workerContext[T]{
|
||||
Context: ctx,
|
||||
Cfg: cfg,
|
||||
Store: store,
|
||||
StatusBits: statusBits,
|
||||
Evaluators: evaluators,
|
||||
EvalCache: evalCache,
|
||||
QueryCache: queryCache,
|
||||
Errs: &sync.Map{}, // policy index->error
|
||||
}
|
||||
|
||||
// First, build a list of partitions (start/end chunk indexes) to send to
|
||||
// each worker.
|
||||
partitions := make([]partition, numWorkers)
|
||||
rttrace.WithRegion(ctx, "partitioning", func() {
|
||||
chunkIdx := 0
|
||||
for workerIdx := range numWorkers {
|
||||
chunkStart := chunkIdx
|
||||
chunkEnd := chunkStart + minChunksPerWorker
|
||||
if workerIdx < overflow {
|
||||
chunkEnd++
|
||||
}
|
||||
chunkIdx = chunkEnd
|
||||
partitions[workerIdx] = partition{chunkStart, chunkEnd}
|
||||
}
|
||||
})
|
||||
|
||||
// Compute all route checksums in parallel to determine which routes need to
|
||||
// be rebuilt.
|
||||
rttrace.WithRegion(ctx, "computing checksums", func() {
|
||||
var wg sync.WaitGroup
|
||||
for _, part := range partitions {
|
||||
wg.Add(1)
|
||||
workerPoolTaskQueue <- func() {
|
||||
defer wg.Done()
|
||||
computeChecksums(wctx, part)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
hasErrs := false
|
||||
wctx.Errs.Range(func(key, value any) bool {
|
||||
log.Ctx(ctx).Error().Int("policy-index", key.(int)).Msg(value.(error).Error())
|
||||
hasErrs = true
|
||||
return true
|
||||
})
|
||||
if hasErrs {
|
||||
return fmt.Errorf("authorize: error computing one or more policy route IDs")
|
||||
}
|
||||
|
||||
// After all checksums are computed and status bits populated, build the
|
||||
// required evaluators.
|
||||
rttrace.WithRegion(ctx, "building evaluators", func() {
|
||||
var wg sync.WaitGroup
|
||||
for _, part := range partitions {
|
||||
// Adjust the partition to skip over chunks with 0 bits set
|
||||
for part.Begin < part.End && statusBits[part.Begin] == 0 {
|
||||
part.Begin++
|
||||
}
|
||||
for part.Begin < (part.End)-1 && statusBits[part.End-1] == 0 {
|
||||
part.End--
|
||||
}
|
||||
if part.Begin == part.End {
|
||||
continue
|
||||
}
|
||||
builders = append(builders, func(ctx context.Context) (*routeEvaluator, error) {
|
||||
evaluator, err := NewPolicyEvaluator(ctx, store, configPolicy, cfg.AddDefaultClientCertificateRule)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: error building evaluator for route id=%s: %w", configPolicy.ID, err)
|
||||
wg.Add(1)
|
||||
workerPoolTaskQueue <- func() {
|
||||
defer wg.Done()
|
||||
buildEvaluators(wctx, part)
|
||||
}
|
||||
return &routeEvaluator{
|
||||
id: id,
|
||||
evaluator: evaluator,
|
||||
}, nil
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
hasErrs = false
|
||||
wctx.Errs.Range(func(key, value any) bool {
|
||||
log.Ctx(ctx).Error().Int("policy-index", key.(int)).Msg(value.(error).Error())
|
||||
hasErrs = true
|
||||
return true
|
||||
})
|
||||
if hasErrs {
|
||||
return fmt.Errorf("authorize: error building policy evaluators")
|
||||
}
|
||||
|
||||
evals, errs := errgrouputil.Build(ctx, builders...)
|
||||
if len(errs) > 0 {
|
||||
for _, err := range errs {
|
||||
log.Ctx(ctx).Error().Msg(err.Error())
|
||||
// Store updated evaluators in the cache
|
||||
updatedCount := 0
|
||||
for _, p := range evaluators {
|
||||
if p.Evaluator != nil { // these are only set when modified
|
||||
updatedCount++
|
||||
evalCache.StoreEvaluator(p.ID, p.Evaluator)
|
||||
}
|
||||
return nil, fmt.Errorf("authorize: error building policy evaluators")
|
||||
}
|
||||
|
||||
for _, p := range evals {
|
||||
m[p.id] = p.evaluator
|
||||
}
|
||||
|
||||
log.Ctx(ctx).Debug().
|
||||
Dur("duration", time.Since(now)).
|
||||
Int("reused-policies", reusedCount).
|
||||
Int("created-policies", len(cfg.Policies)-reusedCount).
|
||||
Int("reused-policies", len(cfg.Policies)-updatedCount).
|
||||
Int("created-policies", updatedCount).
|
||||
Msg("updated policy evaluators")
|
||||
return m, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Evaluate evaluates the rego for the given policy and generates the identity headers.
|
||||
|
@ -272,7 +678,7 @@ func (e *Evaluator) evaluatePolicy(ctx context.Context, req *Request) (*PolicyRe
|
|||
return nil, fmt.Errorf("authorize: error computing policy route id: %w", err)
|
||||
}
|
||||
|
||||
policyEvaluator, ok := e.policyEvaluators[id]
|
||||
policyEvaluator, ok := e.evalCache.LookupEvaluator(id)
|
||||
if !ok {
|
||||
return &PolicyResponse{
|
||||
Deny: NewRuleResult(true, criteria.ReasonRouteNotFound),
|
||||
|
@ -285,7 +691,7 @@ func (e *Evaluator) evaluatePolicy(ctx context.Context, req *Request) (*PolicyRe
|
|||
}
|
||||
|
||||
isValidClientCertificate, err := isValidClientCertificate(
|
||||
clientCA, string(e.clientCRL), req.HTTP.ClientCertificate, e.clientCertConstraints)
|
||||
clientCA, string(e.opts.ClientCRL), req.HTTP.ClientCertificate, e.opts.ClientCertConstraints)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: error validating client certificate: %w", err)
|
||||
}
|
||||
|
@ -303,7 +709,7 @@ func (e *Evaluator) evaluateHeaders(ctx context.Context, req *Request) (*Headers
|
|||
return nil, err
|
||||
}
|
||||
headersReq.Session = req.Session
|
||||
res, err := e.headersEvaluators.Evaluate(ctx, headersReq)
|
||||
res, err := e.headersEvaluator.Evaluate(ctx, headersReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -322,29 +728,36 @@ func (e *Evaluator) getClientCA(policy *config.Policy) (string, error) {
|
|||
return string(bs), nil
|
||||
}
|
||||
|
||||
return string(e.clientCA), nil
|
||||
return string(e.opts.ClientCA), nil
|
||||
}
|
||||
|
||||
func updateStore(ctx context.Context, store *store.Store, cfg *evaluatorConfig) error {
|
||||
jwk, err := getJWK(ctx, cfg)
|
||||
func updateStore(ctx context.Context, store *store.Store, cfg *evaluatorConfig, previous *Evaluator) error {
|
||||
if previous == nil || !bytes.Equal(cfg.SigningKey, previous.opts.SigningKey) {
|
||||
jwk, err := getJWK(ctx, cfg.SigningKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authorize: couldn't create signer: %w", err)
|
||||
}
|
||||
store.UpdateSigningKey(jwk)
|
||||
}
|
||||
|
||||
if previous == nil || cfg.GoogleCloudServerlessAuthenticationServiceAccount != previous.opts.GoogleCloudServerlessAuthenticationServiceAccount {
|
||||
store.UpdateGoogleCloudServerlessAuthenticationServiceAccount(
|
||||
cfg.GoogleCloudServerlessAuthenticationServiceAccount,
|
||||
)
|
||||
store.UpdateJWTClaimHeaders(cfg.JWTClaimsHeaders)
|
||||
store.UpdateRoutePolicies(cfg.Policies)
|
||||
store.UpdateSigningKey(jwk)
|
||||
}
|
||||
|
||||
if previous == nil || !maps.Equal(cfg.JWTClaimsHeaders, previous.opts.JWTClaimsHeaders) {
|
||||
store.UpdateJWTClaimHeaders(cfg.JWTClaimsHeaders)
|
||||
}
|
||||
|
||||
store.UpdateRoutePolicies(cfg.Policies)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getJWK(ctx context.Context, cfg *evaluatorConfig) (*jose.JSONWebKey, error) {
|
||||
func getJWK(ctx context.Context, signingKey []byte) (*jose.JSONWebKey, error) {
|
||||
var decodedCert []byte
|
||||
// if we don't have a signing key, generate one
|
||||
if len(cfg.SigningKey) == 0 {
|
||||
if len(signingKey) == 0 {
|
||||
key, err := cryptutil.NewSigningKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't generate signing key: %w", err)
|
||||
|
@ -354,7 +767,7 @@ func getJWK(ctx context.Context, cfg *evaluatorConfig) (*jose.JSONWebKey, error)
|
|||
return nil, fmt.Errorf("bad signing key: %w", err)
|
||||
}
|
||||
} else {
|
||||
decodedCert = cfg.SigningKey
|
||||
decodedCert = signingKey
|
||||
}
|
||||
|
||||
jwk, err := cryptutil.PrivateJWKFromBytes(decodedCert)
|
||||
|
|
49
authorize/evaluator/evaluator_export_test.go
Normal file
49
authorize/evaluator/evaluator_export_test.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package evaluator
|
||||
|
||||
import (
|
||||
"maps"
|
||||
)
|
||||
|
||||
func (c *PolicyEvaluatorCache) XClone() *PolicyEvaluatorCache {
|
||||
c.evalsMu.Lock()
|
||||
defer c.evalsMu.Unlock()
|
||||
return &PolicyEvaluatorCache{
|
||||
evaluatorsByRouteID: maps.Clone(c.evaluatorsByRouteID),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Evaluator) XEvaluatorCache() *PolicyEvaluatorCache {
|
||||
return e.evalCache
|
||||
}
|
||||
|
||||
func (e *Evaluator) XQueryCache() *QueryCache {
|
||||
return e.queryCache
|
||||
}
|
||||
|
||||
var (
|
||||
XGetGoogleCloudServerlessTokenSource = getGoogleCloudServerlessTokenSource
|
||||
XIsValidClientCertificate = isValidClientCertificate
|
||||
XNormalizeServiceAccount = normalizeServiceAccount
|
||||
XBestChunkSize = bestChunkSize
|
||||
XGetUserPrincipalNamesFromSAN = getUserPrincipalNamesFromSAN
|
||||
)
|
||||
|
||||
var OIDUserPrincipalName = oidUserPrincipalName
|
||||
|
||||
func XWorkerPoolSize() int {
|
||||
return workerPoolSize
|
||||
}
|
||||
|
||||
func OverrideWorkerPoolSizeForTesting(newSize int) {
|
||||
if newSize == workerPoolSize {
|
||||
return
|
||||
}
|
||||
workerPoolMu.Lock()
|
||||
workerPoolSize = newSize
|
||||
close(workerPoolTaskQueue) // this will stop existing workers
|
||||
workerPoolTaskQueue = make(chan func(), (workerPoolSize+1)*2)
|
||||
workerPoolMu.Unlock()
|
||||
for i := 0; i < workerPoolSize; i++ {
|
||||
go worker()
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -1,4 +1,4 @@
|
|||
package evaluator
|
||||
package evaluator_test
|
||||
|
||||
import (
|
||||
"encoding/asn1"
|
||||
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
)
|
||||
|
||||
|
@ -167,20 +168,20 @@ dgwikvJkMOfcuexx
|
|||
func Test_isValidClientCertificate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var noConstraints ClientCertConstraints
|
||||
var noConstraints evaluator.ClientCertConstraints
|
||||
t.Run("no ca", func(t *testing.T) {
|
||||
valid, err := isValidClientCertificate(
|
||||
"", "", ClientCertificateInfo{Leaf: "WHATEVER!"}, noConstraints)
|
||||
valid, err := evaluator.XIsValidClientCertificate(
|
||||
"", "", evaluator.ClientCertificateInfo{Leaf: "WHATEVER!"}, noConstraints)
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.True(t, valid, "should return true")
|
||||
})
|
||||
t.Run("no cert", func(t *testing.T) {
|
||||
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{}, noConstraints)
|
||||
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{}, noConstraints)
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.False(t, valid, "should return false")
|
||||
})
|
||||
t.Run("valid cert", func(t *testing.T) {
|
||||
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCert,
|
||||
}, noConstraints)
|
||||
|
@ -188,7 +189,7 @@ func Test_isValidClientCertificate(t *testing.T) {
|
|||
assert.True(t, valid, "should return true")
|
||||
})
|
||||
t.Run("valid cert with intermediate", func(t *testing.T) {
|
||||
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidIntermediateCert,
|
||||
Intermediates: testIntermediateCA,
|
||||
|
@ -197,7 +198,7 @@ func Test_isValidClientCertificate(t *testing.T) {
|
|||
assert.True(t, valid, "should return true")
|
||||
})
|
||||
t.Run("valid cert missing intermediate", func(t *testing.T) {
|
||||
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidIntermediateCert,
|
||||
Intermediates: "",
|
||||
|
@ -206,7 +207,7 @@ func Test_isValidClientCertificate(t *testing.T) {
|
|||
assert.False(t, valid, "should return false")
|
||||
})
|
||||
t.Run("intermediate CA as root", func(t *testing.T) {
|
||||
valid, err := isValidClientCertificate(testIntermediateCA, "", ClientCertificateInfo{
|
||||
valid, err := evaluator.XIsValidClientCertificate(testIntermediateCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidIntermediateCert,
|
||||
}, noConstraints)
|
||||
|
@ -214,7 +215,7 @@ func Test_isValidClientCertificate(t *testing.T) {
|
|||
assert.True(t, valid, "should return true")
|
||||
})
|
||||
t.Run("unsigned cert", func(t *testing.T) {
|
||||
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testUntrustedCert,
|
||||
}, noConstraints)
|
||||
|
@ -222,7 +223,7 @@ func Test_isValidClientCertificate(t *testing.T) {
|
|||
assert.False(t, valid, "should return false")
|
||||
})
|
||||
t.Run("not a cert", func(t *testing.T) {
|
||||
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: "WHATEVER!",
|
||||
}, noConstraints)
|
||||
|
@ -230,22 +231,22 @@ func Test_isValidClientCertificate(t *testing.T) {
|
|||
assert.False(t, valid, "should return false")
|
||||
})
|
||||
t.Run("revoked cert", func(t *testing.T) {
|
||||
revokedCertInfo := ClientCertificateInfo{
|
||||
revokedCertInfo := evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testRevokedCert,
|
||||
}
|
||||
|
||||
// The "revoked cert" should otherwise be valid (when no CRL is specified).
|
||||
valid, err := isValidClientCertificate(testCA, "", revokedCertInfo, noConstraints)
|
||||
valid, err := evaluator.XIsValidClientCertificate(testCA, "", revokedCertInfo, noConstraints)
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.True(t, valid, "should return true")
|
||||
|
||||
valid, err = isValidClientCertificate(testCA, testCRL, revokedCertInfo, noConstraints)
|
||||
valid, err = evaluator.XIsValidClientCertificate(testCA, testCRL, revokedCertInfo, noConstraints)
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.False(t, valid, "should return false")
|
||||
|
||||
// Specifying a CRL containing the revoked cert should not affect other certs.
|
||||
valid, err = isValidClientCertificate(testCA, testCRL, ClientCertificateInfo{
|
||||
valid, err = evaluator.XIsValidClientCertificate(testCA, testCRL, evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCert,
|
||||
}, noConstraints)
|
||||
|
@ -253,51 +254,51 @@ func Test_isValidClientCertificate(t *testing.T) {
|
|||
assert.True(t, valid, "should return true")
|
||||
})
|
||||
t.Run("chain too deep", func(t *testing.T) {
|
||||
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidIntermediateCert,
|
||||
Intermediates: testIntermediateCA,
|
||||
}, ClientCertConstraints{MaxVerifyDepth: 1})
|
||||
}, evaluator.ClientCertConstraints{MaxVerifyDepth: 1})
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.False(t, valid, "should return false")
|
||||
})
|
||||
t.Run("any SAN", func(t *testing.T) {
|
||||
matchAnySAN := ClientCertConstraints{SANMatchers: SANMatchers{
|
||||
matchAnySAN := evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
|
||||
config.SANTypeDNS: regexp.MustCompile("^.*$"),
|
||||
config.SANTypeEmail: regexp.MustCompile("^.*$"),
|
||||
config.SANTypeIPAddress: regexp.MustCompile("^.*$"),
|
||||
config.SANTypeURI: regexp.MustCompile("^.*$"),
|
||||
}}
|
||||
|
||||
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCert, // no SANs
|
||||
}, matchAnySAN)
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.False(t, valid, "should return false")
|
||||
|
||||
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCertWithDNSSANs,
|
||||
}, matchAnySAN)
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.True(t, valid, "should return true")
|
||||
|
||||
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCertWithEmailSAN,
|
||||
}, matchAnySAN)
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.True(t, valid, "should return true")
|
||||
|
||||
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCertWithIPSAN,
|
||||
}, matchAnySAN)
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.True(t, valid, "should return true")
|
||||
|
||||
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCertWithURISAN,
|
||||
}, matchAnySAN)
|
||||
|
@ -305,95 +306,95 @@ func Test_isValidClientCertificate(t *testing.T) {
|
|||
assert.True(t, valid, "should return true")
|
||||
})
|
||||
t.Run("DNS SAN", func(t *testing.T) {
|
||||
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCertWithDNSSANs,
|
||||
}, ClientCertConstraints{SANMatchers: SANMatchers{
|
||||
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
|
||||
config.SANTypeDNS: regexp.MustCompile(`^a\..*\.example\.com$`),
|
||||
}})
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.True(t, valid, "should return true")
|
||||
|
||||
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCertWithDNSSANs,
|
||||
}, ClientCertConstraints{SANMatchers: SANMatchers{
|
||||
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
|
||||
config.SANTypeEmail: regexp.MustCompile(`^a\..*\.example\.com$`), // mismatched type
|
||||
}})
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.False(t, valid, "should return false")
|
||||
})
|
||||
t.Run("email SAN", func(t *testing.T) {
|
||||
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCertWithEmailSAN,
|
||||
}, ClientCertConstraints{SANMatchers: SANMatchers{
|
||||
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
|
||||
config.SANTypeEmail: regexp.MustCompile(`^.*@example\.com$`),
|
||||
}})
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.True(t, valid, "should return true")
|
||||
|
||||
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCertWithEmailSAN,
|
||||
}, ClientCertConstraints{SANMatchers: SANMatchers{
|
||||
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
|
||||
config.SANTypeIPAddress: regexp.MustCompile(`^.*@example\.com$`), // mismatched type
|
||||
}})
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.False(t, valid, "should return false")
|
||||
})
|
||||
t.Run("IP address SAN", func(t *testing.T) {
|
||||
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCertWithIPSAN,
|
||||
}, ClientCertConstraints{SANMatchers: SANMatchers{
|
||||
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
|
||||
config.SANTypeIPAddress: regexp.MustCompile(`^192\.168\.10\..*$`),
|
||||
}})
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.True(t, valid, "should return true")
|
||||
|
||||
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCertWithIPSAN,
|
||||
}, ClientCertConstraints{SANMatchers: SANMatchers{
|
||||
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
|
||||
config.SANTypeURI: regexp.MustCompile(`^192\.168\.10\..*$`), // mismatched type
|
||||
}})
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.False(t, valid, "should return false")
|
||||
})
|
||||
t.Run("URI SAN", func(t *testing.T) {
|
||||
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCertWithURISAN,
|
||||
}, ClientCertConstraints{SANMatchers: SANMatchers{
|
||||
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
|
||||
config.SANTypeURI: regexp.MustCompile(`^spiffe://example\.com/.*$`),
|
||||
}})
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.True(t, valid, "should return true")
|
||||
|
||||
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCertWithURISAN,
|
||||
}, ClientCertConstraints{SANMatchers: SANMatchers{
|
||||
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
|
||||
config.SANTypeDNS: regexp.MustCompile(`^spiffe://example\.com/.*$`), // mismatched type
|
||||
}})
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.False(t, valid, "should return false")
|
||||
})
|
||||
t.Run("UserPrincipalName SAN", func(t *testing.T) {
|
||||
valid, err := isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err := evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCertWithUPNSAN,
|
||||
}, ClientCertConstraints{SANMatchers: SANMatchers{
|
||||
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
|
||||
config.SANTypeUserPrincipalName: regexp.MustCompile(`^test_device$`),
|
||||
}})
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
assert.True(t, valid, "should return true")
|
||||
|
||||
valid, err = isValidClientCertificate(testCA, "", ClientCertificateInfo{
|
||||
valid, err = evaluator.XIsValidClientCertificate(testCA, "", evaluator.ClientCertificateInfo{
|
||||
Presented: true,
|
||||
Leaf: testValidCertWithURISAN,
|
||||
}, ClientCertConstraints{SANMatchers: SANMatchers{
|
||||
}, evaluator.ClientCertConstraints{SANMatchers: evaluator.SANMatchers{
|
||||
config.SANTypeDNS: regexp.MustCompile(`^test-device$`), // mismatched type
|
||||
}})
|
||||
assert.NoError(t, err, "should not return an error")
|
||||
|
@ -406,26 +407,26 @@ func TestClientCertConstraintsFromConfig(t *testing.T) {
|
|||
|
||||
t.Run("default constraints", func(t *testing.T) {
|
||||
var s config.DownstreamMTLSSettings
|
||||
c, err := ClientCertConstraintsFromConfig(&s)
|
||||
c, err := evaluator.ClientCertConstraintsFromConfig(&s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &ClientCertConstraints{MaxVerifyDepth: 1}, c)
|
||||
assert.Equal(t, &evaluator.ClientCertConstraints{MaxVerifyDepth: 1}, c)
|
||||
})
|
||||
t.Run("no constraints", func(t *testing.T) {
|
||||
s := config.DownstreamMTLSSettings{
|
||||
MaxVerifyDepth: new(uint32),
|
||||
}
|
||||
c, err := ClientCertConstraintsFromConfig(&s)
|
||||
c, err := evaluator.ClientCertConstraintsFromConfig(&s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &ClientCertConstraints{}, c)
|
||||
assert.Equal(t, &evaluator.ClientCertConstraints{}, c)
|
||||
})
|
||||
t.Run("larger max depth", func(t *testing.T) {
|
||||
depth := uint32(5)
|
||||
s := config.DownstreamMTLSSettings{
|
||||
MaxVerifyDepth: &depth,
|
||||
}
|
||||
c, err := ClientCertConstraintsFromConfig(&s)
|
||||
c, err := evaluator.ClientCertConstraintsFromConfig(&s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &ClientCertConstraints{MaxVerifyDepth: 5}, c)
|
||||
assert.Equal(t, &evaluator.ClientCertConstraints{MaxVerifyDepth: 5}, c)
|
||||
})
|
||||
t.Run("one SAN match", func(t *testing.T) {
|
||||
s := config.DownstreamMTLSSettings{
|
||||
|
@ -433,11 +434,11 @@ func TestClientCertConstraintsFromConfig(t *testing.T) {
|
|||
{Type: config.SANTypeDNS, Pattern: `.*\.corp\.example\.com`},
|
||||
},
|
||||
}
|
||||
c, err := ClientCertConstraintsFromConfig(&s)
|
||||
c, err := evaluator.ClientCertConstraintsFromConfig(&s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &ClientCertConstraints{
|
||||
assert.Equal(t, &evaluator.ClientCertConstraints{
|
||||
MaxVerifyDepth: 1,
|
||||
SANMatchers: SANMatchers{
|
||||
SANMatchers: evaluator.SANMatchers{
|
||||
config.SANTypeDNS: regexp.MustCompile(`^(.*\.corp\.example\.com)$`),
|
||||
},
|
||||
}, c)
|
||||
|
@ -449,11 +450,11 @@ func TestClientCertConstraintsFromConfig(t *testing.T) {
|
|||
{Type: config.SANTypeDNS, Pattern: `.*\.bar\.example\.com`},
|
||||
},
|
||||
}
|
||||
c, err := ClientCertConstraintsFromConfig(&s)
|
||||
c, err := evaluator.ClientCertConstraintsFromConfig(&s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &ClientCertConstraints{
|
||||
assert.Equal(t, &evaluator.ClientCertConstraints{
|
||||
MaxVerifyDepth: 1,
|
||||
SANMatchers: SANMatchers{
|
||||
SANMatchers: evaluator.SANMatchers{
|
||||
config.SANTypeDNS: regexp.MustCompile(`^(.*\.foo\.example\.com)|(.*\.bar\.example\.com)$`),
|
||||
},
|
||||
}, c)
|
||||
|
@ -465,11 +466,11 @@ func TestClientCertConstraintsFromConfig(t *testing.T) {
|
|||
{Type: config.SANTypeEmail, Pattern: `.*@example\.com`},
|
||||
},
|
||||
}
|
||||
c, err := ClientCertConstraintsFromConfig(&s)
|
||||
c, err := evaluator.ClientCertConstraintsFromConfig(&s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &ClientCertConstraints{
|
||||
assert.Equal(t, &evaluator.ClientCertConstraints{
|
||||
MaxVerifyDepth: 1,
|
||||
SANMatchers: SANMatchers{
|
||||
SANMatchers: evaluator.SANMatchers{
|
||||
config.SANTypeDNS: regexp.MustCompile(`^(.*\.foo\.example\.com)$`),
|
||||
config.SANTypeEmail: regexp.MustCompile(`^(.*@example\.com)$`),
|
||||
},
|
||||
|
@ -481,7 +482,7 @@ func TestClientCertConstraintsFromConfig(t *testing.T) {
|
|||
{Type: config.SANTypeDNS, Pattern: "["},
|
||||
},
|
||||
}
|
||||
_, err := ClientCertConstraintsFromConfig(&s)
|
||||
_, err := evaluator.ClientCertConstraintsFromConfig(&s)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
@ -514,7 +515,7 @@ func TestGetUserPrincipalNamesFromSAN(t *testing.T) {
|
|||
san, err := asn1.Marshal(SAN{upn("hello")})
|
||||
require.NoError(t, err)
|
||||
|
||||
names, err := getUserPrincipalNamesFromSAN(san)
|
||||
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"hello"}, names)
|
||||
})
|
||||
|
@ -527,7 +528,7 @@ func TestGetUserPrincipalNamesFromSAN(t *testing.T) {
|
|||
san, err := asn1.Marshal(SAN{upn("foo"), upn("bar"), upn("baz")})
|
||||
require.NoError(t, err)
|
||||
|
||||
names, err := getUserPrincipalNamesFromSAN(san)
|
||||
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"foo", "bar", "baz"}, names)
|
||||
})
|
||||
|
@ -544,7 +545,7 @@ func TestGetUserPrincipalNamesFromSAN(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
names, err := getUserPrincipalNamesFromSAN(san)
|
||||
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, names)
|
||||
})
|
||||
|
@ -571,7 +572,7 @@ func TestGetUserPrincipalNamesFromSAN(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
names, err := getUserPrincipalNamesFromSAN(san)
|
||||
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"UserPrincipalName"}, names)
|
||||
})
|
||||
|
@ -587,24 +588,24 @@ func TestGetUserPrincipalNamesFromSAN(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
names, err := getUserPrincipalNamesFromSAN(san)
|
||||
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
|
||||
assert.ErrorContains(t, err, "expected UTF8String")
|
||||
assert.Empty(t, names)
|
||||
})
|
||||
t.Run("EmptySAN", func(t *testing.T) {
|
||||
names, err := getUserPrincipalNamesFromSAN(nil)
|
||||
names, err := evaluator.XGetUserPrincipalNamesFromSAN(nil)
|
||||
assert.ErrorContains(t, err, "error reading GeneralNames sequence")
|
||||
assert.Empty(t, names)
|
||||
})
|
||||
t.Run("TruncatedGeneralName", func(t *testing.T) {
|
||||
san := []byte{0x30, 0x02, 0x82, 0x05 /* 5 more bytes expected */}
|
||||
names, err := getUserPrincipalNamesFromSAN(san)
|
||||
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
|
||||
assert.ErrorContains(t, err, "error reading GeneralName")
|
||||
assert.Empty(t, names)
|
||||
})
|
||||
t.Run("OtherNameWrongTypeIDType", func(t *testing.T) {
|
||||
san := []byte{0x30, 0x06, 0xa0, 0x04, 0x02 /* type Integer, not OID */, 0x02, 0x46, 0x01}
|
||||
names, err := getUserPrincipalNamesFromSAN(san)
|
||||
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
|
||||
assert.ErrorContains(t, err, "error reading OtherName type ID")
|
||||
assert.Empty(t, names)
|
||||
})
|
||||
|
@ -618,13 +619,13 @@ func TestGetUserPrincipalNamesFromSAN(t *testing.T) {
|
|||
}
|
||||
san, err := asn1.Marshal(SAN{
|
||||
UPN: BadOtherName{
|
||||
TypeID: oidUserPrincipalName,
|
||||
TypeID: evaluator.OIDUserPrincipalName,
|
||||
Value: UTF8String{"hello"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
names, err := getUserPrincipalNamesFromSAN(san)
|
||||
names, err := evaluator.XGetUserPrincipalNamesFromSAN(san)
|
||||
assert.ErrorContains(t, err, "error reading UserPrincipalName value")
|
||||
assert.Empty(t, names)
|
||||
})
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package evaluator
|
||||
package evaluator_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
@ -7,17 +7,19 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
)
|
||||
|
||||
func withMockGCP(t *testing.T, f func()) {
|
||||
originalGCPIdentityDocURL := GCPIdentityDocURL
|
||||
originalGCPIdentityDocURL := evaluator.GCPIdentityDocURL
|
||||
defer func() {
|
||||
GCPIdentityDocURL = originalGCPIdentityDocURL
|
||||
GCPIdentityNow = time.Now
|
||||
evaluator.GCPIdentityDocURL = originalGCPIdentityDocURL
|
||||
evaluator.GCPIdentityNow = time.Now
|
||||
}()
|
||||
|
||||
now := time.Date(2020, 1, 1, 1, 0, 0, 0, time.UTC)
|
||||
GCPIdentityNow = func() time.Time {
|
||||
evaluator.GCPIdentityNow = func() time.Time {
|
||||
return now
|
||||
}
|
||||
|
||||
|
@ -28,13 +30,13 @@ func withMockGCP(t *testing.T, f func()) {
|
|||
}))
|
||||
defer srv.Close()
|
||||
|
||||
GCPIdentityDocURL = srv.URL
|
||||
evaluator.GCPIdentityDocURL = srv.URL
|
||||
f()
|
||||
}
|
||||
|
||||
func TestGCPIdentityTokenSource(t *testing.T) {
|
||||
withMockGCP(t, func() {
|
||||
src, err := getGoogleCloudServerlessTokenSource("", "example")
|
||||
src, err := evaluator.XGetGoogleCloudServerlessTokenSource("", "example")
|
||||
assert.NoError(t, err)
|
||||
|
||||
token, err := src.Token()
|
||||
|
@ -62,7 +64,7 @@ func Test_normalizeServiceAccount(t *testing.T) {
|
|||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
gotServiceAccount, err := normalizeServiceAccount(tc.serviceAccount)
|
||||
gotServiceAccount, err := evaluator.XNormalizeServiceAccount(tc.serviceAccount)
|
||||
assert.True(t, (err != nil) == tc.wantError)
|
||||
assert.Equal(t, tc.expectedServiceAccount, gotServiceAccount)
|
||||
})
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package evaluator
|
||||
package evaluator_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -20,6 +20,7 @@ import (
|
|||
"google.golang.org/protobuf/types/known/structpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
|
@ -29,7 +30,7 @@ import (
|
|||
)
|
||||
|
||||
func TestNewHeadersRequestFromPolicy(t *testing.T) {
|
||||
req, _ := NewHeadersRequestFromPolicy(&config.Policy{
|
||||
req, _ := evaluator.NewHeadersRequestFromPolicy(&config.Policy{
|
||||
EnableGoogleCloudServerlessAuthentication: true,
|
||||
From: "https://*.example.com",
|
||||
To: config.WeightedURLs{
|
||||
|
@ -37,18 +38,18 @@ func TestNewHeadersRequestFromPolicy(t *testing.T) {
|
|||
URL: *mustParseURL("http://to.example.com"),
|
||||
},
|
||||
},
|
||||
}, RequestHTTP{
|
||||
}, evaluator.RequestHTTP{
|
||||
Hostname: "from.example.com",
|
||||
ClientCertificate: ClientCertificateInfo{
|
||||
ClientCertificate: evaluator.ClientCertificateInfo{
|
||||
Leaf: "--- FAKE CERTIFICATE ---",
|
||||
},
|
||||
})
|
||||
assert.Equal(t, &HeadersRequest{
|
||||
assert.Equal(t, &evaluator.HeadersRequest{
|
||||
EnableGoogleCloudServerlessAuthentication: true,
|
||||
Issuer: "from.example.com",
|
||||
Audience: "from.example.com",
|
||||
ToAudience: "https://to.example.com",
|
||||
ClientCertificate: ClientCertificateInfo{
|
||||
ClientCertificate: evaluator.ClientCertificateInfo{
|
||||
Leaf: "--- FAKE CERTIFICATE ---",
|
||||
},
|
||||
}, req)
|
||||
|
@ -91,21 +92,21 @@ func TestNewHeadersRequestFromPolicy_IssuerFormat(t *testing.T) {
|
|||
},
|
||||
} {
|
||||
policy.JWTIssuerFormat = tc.format
|
||||
req, err := NewHeadersRequestFromPolicy(policy, RequestHTTP{
|
||||
req, err := evaluator.NewHeadersRequestFromPolicy(policy, evaluator.RequestHTTP{
|
||||
Hostname: "from.example.com",
|
||||
ClientCertificate: ClientCertificateInfo{
|
||||
ClientCertificate: evaluator.ClientCertificateInfo{
|
||||
Leaf: "--- FAKE CERTIFICATE ---",
|
||||
},
|
||||
})
|
||||
if tc.err != "" {
|
||||
assert.ErrorContains(t, err, tc.err)
|
||||
} else {
|
||||
assert.Equal(t, &HeadersRequest{
|
||||
assert.Equal(t, &evaluator.HeadersRequest{
|
||||
EnableGoogleCloudServerlessAuthentication: true,
|
||||
Issuer: tc.expectedIssuer,
|
||||
Audience: tc.expectedAudience,
|
||||
ToAudience: "https://to.example.com",
|
||||
ClientCertificate: ClientCertificateInfo{
|
||||
ClientCertificate: evaluator.ClientCertificateInfo{
|
||||
Leaf: "--- FAKE CERTIFICATE ---",
|
||||
},
|
||||
}, req)
|
||||
|
@ -114,8 +115,8 @@ func TestNewHeadersRequestFromPolicy_IssuerFormat(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNewHeadersRequestFromPolicy_nil(t *testing.T) {
|
||||
req, _ := NewHeadersRequestFromPolicy(nil, RequestHTTP{Hostname: "from.example.com"})
|
||||
assert.Equal(t, &HeadersRequest{
|
||||
req, _ := evaluator.NewHeadersRequestFromPolicy(nil, evaluator.RequestHTTP{Hostname: "from.example.com"})
|
||||
assert.Equal(t, &evaluator.HeadersRequest{
|
||||
Issuer: "from.example.com",
|
||||
Audience: "from.example.com",
|
||||
}, req)
|
||||
|
@ -138,13 +139,13 @@ func TestHeadersEvaluator(t *testing.T) {
|
|||
|
||||
iat := time.Unix(1686870680, 0)
|
||||
|
||||
eval := func(t *testing.T, data []proto.Message, input *HeadersRequest) (*HeadersResponse, error) {
|
||||
eval := func(t *testing.T, data []proto.Message, input *evaluator.HeadersRequest) (*evaluator.HeadersResponse, error) {
|
||||
ctx := context.Background()
|
||||
ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...))
|
||||
store := store.New()
|
||||
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
|
||||
store.UpdateSigningKey(privateJWK)
|
||||
e, err := NewHeadersEvaluator(ctx, store, rego.Time(iat))
|
||||
e, err := evaluator.NewHeadersEvaluator(ctx, store, rego.Time(iat))
|
||||
require.NoError(t, err)
|
||||
return e.Evaluate(ctx, input, rego.EvalTime(iat))
|
||||
}
|
||||
|
@ -159,11 +160,11 @@ func TestHeadersEvaluator(t *testing.T) {
|
|||
}},
|
||||
}, IssuedAt: timestamppb.New(iat)},
|
||||
},
|
||||
&HeadersRequest{
|
||||
&evaluator.HeadersRequest{
|
||||
Issuer: "from.example.com",
|
||||
Audience: "from.example.com",
|
||||
ToAudience: "to.example.com",
|
||||
Session: RequestSession{
|
||||
Session: evaluator.RequestSession{
|
||||
ID: "s1",
|
||||
},
|
||||
})
|
||||
|
@ -215,11 +216,11 @@ func TestHeadersEvaluator(t *testing.T) {
|
|||
AccessToken: "ACCESS_TOKEN",
|
||||
}},
|
||||
},
|
||||
&HeadersRequest{
|
||||
&evaluator.HeadersRequest{
|
||||
Issuer: "from.example.com",
|
||||
Audience: "from.example.com",
|
||||
ToAudience: "to.example.com",
|
||||
Session: RequestSession{ID: "s1"},
|
||||
Session: evaluator.RequestSession{ID: "s1"},
|
||||
SetRequestHeaders: map[string]string{
|
||||
"X-Custom-Header": "CUSTOM_VALUE",
|
||||
"X-ID-Token": "${pomerium.id_token}",
|
||||
|
@ -228,7 +229,7 @@ func TestHeadersEvaluator(t *testing.T) {
|
|||
"Authorization": "Bearer ${pomerium.jwt}",
|
||||
"Foo": "escaped $$dollar sign",
|
||||
},
|
||||
ClientCertificate: ClientCertificateInfo{Leaf: testValidCert},
|
||||
ClientCertificate: evaluator.ClientCertificateInfo{Leaf: testValidCert},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -258,11 +259,11 @@ func TestHeadersEvaluator(t *testing.T) {
|
|||
AccessToken: "ACCESS_TOKEN",
|
||||
}},
|
||||
},
|
||||
&HeadersRequest{
|
||||
&evaluator.HeadersRequest{
|
||||
Issuer: "from.example.com",
|
||||
Audience: "from.example.com",
|
||||
ToAudience: "to.example.com",
|
||||
Session: RequestSession{ID: "s1"},
|
||||
Session: evaluator.RequestSession{ID: "s1"},
|
||||
SetRequestHeaders: map[string]string{
|
||||
"X-ID-Token": "${pomerium.id_token}",
|
||||
},
|
||||
|
@ -281,11 +282,11 @@ func TestHeadersEvaluator(t *testing.T) {
|
|||
AccessToken: "ACCESS_TOKEN",
|
||||
}},
|
||||
},
|
||||
&HeadersRequest{
|
||||
&evaluator.HeadersRequest{
|
||||
Issuer: "from.example.com",
|
||||
Audience: "from.example.com",
|
||||
ToAudience: "to.example.com",
|
||||
Session: RequestSession{ID: "s1"},
|
||||
Session: evaluator.RequestSession{ID: "s1"},
|
||||
SetRequestHeaders: map[string]string{
|
||||
"Authorization": "Bearer ${pomerium.id_token}",
|
||||
},
|
||||
|
@ -297,7 +298,7 @@ func TestHeadersEvaluator(t *testing.T) {
|
|||
|
||||
t.Run("set_request_headers no client cert", func(t *testing.T) {
|
||||
output, err := eval(t, nil,
|
||||
&HeadersRequest{
|
||||
&evaluator.HeadersRequest{
|
||||
Issuer: "from.example.com",
|
||||
Audience: "from.example.com",
|
||||
ToAudience: "to.example.com",
|
||||
|
@ -318,12 +319,12 @@ func TestHeadersEvaluator(t *testing.T) {
|
|||
&session.Session{Id: "s1", UserId: "u1"},
|
||||
&user.User{Id: "u1", Email: "u1@example.com"},
|
||||
},
|
||||
&HeadersRequest{
|
||||
&evaluator.HeadersRequest{
|
||||
Issuer: "from.example.com",
|
||||
Audience: "from.example.com",
|
||||
ToAudience: "to.example.com",
|
||||
KubernetesServiceAccountToken: "TOKEN",
|
||||
Session: RequestSession{ID: "s1"},
|
||||
Session: evaluator.RequestSession{ID: "s1"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Bearer TOKEN", output.Headers.Get("Authorization"))
|
||||
|
|
|
@ -3,11 +3,10 @@ package evaluator
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
rttrace "runtime/trace"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
octrace "go.opencensus.io/trace"
|
||||
|
||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
@ -16,6 +15,7 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/policy"
|
||||
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
||||
octrace "go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// PolicyRequest is the input to policy evaluation.
|
||||
|
@ -109,25 +109,36 @@ type PolicyEvaluator struct {
|
|||
|
||||
// NewPolicyEvaluator creates a new PolicyEvaluator.
|
||||
func NewPolicyEvaluator(
|
||||
ctx context.Context, store *store.Store, configPolicy *config.Policy,
|
||||
ctx context.Context,
|
||||
store *store.Store,
|
||||
configPolicy *config.Policy,
|
||||
policyChecksum uint64,
|
||||
addDefaultClientCertificateRule bool,
|
||||
cache *QueryCache,
|
||||
) (*PolicyEvaluator, error) {
|
||||
e := new(PolicyEvaluator)
|
||||
e.policyChecksum = configPolicy.Checksum()
|
||||
e.policyChecksum = policyChecksum
|
||||
|
||||
var err error
|
||||
rttrace.WithRegion(ctx, "Generate Rego", func() {
|
||||
// generate the base rego script for the policy
|
||||
ppl := configPolicy.ToPPL()
|
||||
if addDefaultClientCertificateRule {
|
||||
ppl.AddDefaultClientCertificateRule()
|
||||
}
|
||||
base, err := policy.GenerateRegoFromPolicy(ppl)
|
||||
var base string
|
||||
base, err = policy.GenerateRegoFromPolicy(ppl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return
|
||||
}
|
||||
|
||||
e.queries = []policyQuery{{
|
||||
script: base,
|
||||
}}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// add any custom rego
|
||||
for _, sp := range configPolicy.SubPolicies {
|
||||
|
@ -145,15 +156,12 @@ func NewPolicyEvaluator(
|
|||
}
|
||||
}
|
||||
|
||||
// for each script, create a rego and prepare a query.
|
||||
// for each script, create a rego object and prepare a query.
|
||||
rttrace.WithRegion(ctx, "Compile Rego", func() {
|
||||
numCached := 0
|
||||
for i := range e.queries {
|
||||
log.Ctx(ctx).
|
||||
Trace().
|
||||
Str("script", e.queries[i].script).
|
||||
Str("from", configPolicy.From).
|
||||
Interface("to", configPolicy.To).
|
||||
Msg("authorize: rego script for policy evaluation")
|
||||
|
||||
var cached bool
|
||||
e.queries[i].PreparedEvalQuery, cached, err = cache.LookupOrBuild(&e.queries[i], func() (rego.PreparedEvalQuery, error) {
|
||||
r := rego.New(
|
||||
rego.Store(store),
|
||||
rego.Module("pomerium.policy", e.queries[i].script),
|
||||
|
@ -177,10 +185,20 @@ func NewPolicyEvaluator(
|
|||
q, err = r.PrepareForEval(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return rego.PreparedEvalQuery{}, err
|
||||
}
|
||||
|
||||
e.queries[i].PreparedEvalQuery = q
|
||||
return q, nil
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if cached {
|
||||
numCached++
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return e, nil
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package evaluator
|
||||
package evaluator_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -13,6 +13,7 @@ import (
|
|||
"google.golang.org/protobuf/types/known/structpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/pkg/contextutil"
|
||||
|
@ -34,13 +35,15 @@ func TestPolicyEvaluator(t *testing.T) {
|
|||
|
||||
var addDefaultClientCertificateRule bool
|
||||
|
||||
eval := func(t *testing.T, policy *config.Policy, data []proto.Message, input *PolicyRequest) (*PolicyResponse, error) {
|
||||
eval := func(t *testing.T, policy *config.Policy, data []proto.Message, input *evaluator.PolicyRequest) (*evaluator.PolicyResponse, error) {
|
||||
ctx := context.Background()
|
||||
ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...))
|
||||
store := store.New()
|
||||
store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY"))
|
||||
store.UpdateSigningKey(privateJWK)
|
||||
e, err := NewPolicyEvaluator(ctx, store, policy, addDefaultClientCertificateRule)
|
||||
checksum := policy.Checksum()
|
||||
queryCache := evaluator.NewQueryCache(1)
|
||||
e, err := evaluator.NewPolicyEvaluator(ctx, store, policy, checksum, addDefaultClientCertificateRule, queryCache)
|
||||
require.NoError(t, err)
|
||||
return e.Evaluate(ctx, input)
|
||||
}
|
||||
|
@ -71,14 +74,14 @@ func TestPolicyEvaluator(t *testing.T) {
|
|||
output, err := eval(t,
|
||||
p1,
|
||||
[]proto.Message{s1, u1, s2, u2},
|
||||
&PolicyRequest{
|
||||
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: RequestSession{ID: "s1"},
|
||||
&evaluator.PolicyRequest{
|
||||
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: evaluator.RequestSession{ID: "s1"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &PolicyResponse{
|
||||
Allow: NewRuleResult(true, criteria.ReasonEmailOK),
|
||||
Deny: NewRuleResult(false),
|
||||
assert.Equal(t, &evaluator.PolicyResponse{
|
||||
Allow: evaluator.NewRuleResult(true, criteria.ReasonEmailOK),
|
||||
Deny: evaluator.NewRuleResult(false),
|
||||
Traces: []contextutil.PolicyEvaluationTrace{{Allow: true}},
|
||||
}, output)
|
||||
})
|
||||
|
@ -86,14 +89,14 @@ func TestPolicyEvaluator(t *testing.T) {
|
|||
output, err := eval(t,
|
||||
p1,
|
||||
[]proto.Message{s1, u1, s2, u2},
|
||||
&PolicyRequest{
|
||||
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: RequestSession{ID: "s2"},
|
||||
&evaluator.PolicyRequest{
|
||||
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: evaluator.RequestSession{ID: "s2"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &PolicyResponse{
|
||||
Allow: NewRuleResult(false, criteria.ReasonEmailUnauthorized, criteria.ReasonUserUnauthorized),
|
||||
Deny: NewRuleResult(false),
|
||||
assert.Equal(t, &evaluator.PolicyResponse{
|
||||
Allow: evaluator.NewRuleResult(false, criteria.ReasonEmailUnauthorized, criteria.ReasonUserUnauthorized),
|
||||
Deny: evaluator.NewRuleResult(false),
|
||||
Traces: []contextutil.PolicyEvaluationTrace{{}},
|
||||
}, output)
|
||||
})
|
||||
|
@ -105,16 +108,16 @@ func TestPolicyEvaluator(t *testing.T) {
|
|||
output, err := eval(t,
|
||||
p1,
|
||||
[]proto.Message{s1, u1, s2, u2},
|
||||
&PolicyRequest{
|
||||
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: RequestSession{ID: "s1"},
|
||||
&evaluator.PolicyRequest{
|
||||
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: evaluator.RequestSession{ID: "s1"},
|
||||
|
||||
IsValidClientCertificate: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &PolicyResponse{
|
||||
Allow: NewRuleResult(true, criteria.ReasonEmailOK),
|
||||
Deny: NewRuleResult(false, criteria.ReasonValidClientCertificate),
|
||||
assert.Equal(t, &evaluator.PolicyResponse{
|
||||
Allow: evaluator.NewRuleResult(true, criteria.ReasonEmailOK),
|
||||
Deny: evaluator.NewRuleResult(false, criteria.ReasonValidClientCertificate),
|
||||
Traces: []contextutil.PolicyEvaluationTrace{{Allow: true}},
|
||||
}, output)
|
||||
})
|
||||
|
@ -122,16 +125,16 @@ func TestPolicyEvaluator(t *testing.T) {
|
|||
output, err := eval(t,
|
||||
p1,
|
||||
[]proto.Message{s1, u1, s2, u2},
|
||||
&PolicyRequest{
|
||||
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: RequestSession{ID: "s1"},
|
||||
&evaluator.PolicyRequest{
|
||||
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: evaluator.RequestSession{ID: "s1"},
|
||||
|
||||
IsValidClientCertificate: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &PolicyResponse{
|
||||
Allow: NewRuleResult(true, criteria.ReasonEmailOK),
|
||||
Deny: NewRuleResult(true, criteria.ReasonClientCertificateRequired),
|
||||
assert.Equal(t, &evaluator.PolicyResponse{
|
||||
Allow: evaluator.NewRuleResult(true, criteria.ReasonEmailOK),
|
||||
Deny: evaluator.NewRuleResult(true, criteria.ReasonClientCertificateRequired),
|
||||
Traces: []contextutil.PolicyEvaluationTrace{{Allow: true, Deny: true}},
|
||||
}, output)
|
||||
})
|
||||
|
@ -139,20 +142,20 @@ func TestPolicyEvaluator(t *testing.T) {
|
|||
output, err := eval(t,
|
||||
p1,
|
||||
[]proto.Message{s1, u1, s2, u2},
|
||||
&PolicyRequest{
|
||||
HTTP: RequestHTTP{
|
||||
&evaluator.PolicyRequest{
|
||||
HTTP: evaluator.RequestHTTP{
|
||||
Method: http.MethodGet,
|
||||
URL: "https://from.example.com/path",
|
||||
ClientCertificate: ClientCertificateInfo{Presented: true},
|
||||
ClientCertificate: evaluator.ClientCertificateInfo{Presented: true},
|
||||
},
|
||||
Session: RequestSession{ID: "s1"},
|
||||
Session: evaluator.RequestSession{ID: "s1"},
|
||||
|
||||
IsValidClientCertificate: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &PolicyResponse{
|
||||
Allow: NewRuleResult(true, criteria.ReasonEmailOK),
|
||||
Deny: NewRuleResult(true, criteria.ReasonInvalidClientCertificate),
|
||||
assert.Equal(t, &evaluator.PolicyResponse{
|
||||
Allow: evaluator.NewRuleResult(true, criteria.ReasonEmailOK),
|
||||
Deny: evaluator.NewRuleResult(true, criteria.ReasonInvalidClientCertificate),
|
||||
Traces: []contextutil.PolicyEvaluationTrace{{Allow: true, Deny: true}},
|
||||
}, output)
|
||||
})
|
||||
|
@ -160,16 +163,16 @@ func TestPolicyEvaluator(t *testing.T) {
|
|||
output, err := eval(t,
|
||||
p1,
|
||||
[]proto.Message{s1, u1, s2, u2},
|
||||
&PolicyRequest{
|
||||
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: RequestSession{ID: "s2"},
|
||||
&evaluator.PolicyRequest{
|
||||
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: evaluator.RequestSession{ID: "s2"},
|
||||
|
||||
IsValidClientCertificate: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &PolicyResponse{
|
||||
Allow: NewRuleResult(false, criteria.ReasonEmailUnauthorized, criteria.ReasonUserUnauthorized),
|
||||
Deny: NewRuleResult(false, criteria.ReasonValidClientCertificate),
|
||||
assert.Equal(t, &evaluator.PolicyResponse{
|
||||
Allow: evaluator.NewRuleResult(false, criteria.ReasonEmailUnauthorized, criteria.ReasonUserUnauthorized),
|
||||
Deny: evaluator.NewRuleResult(false, criteria.ReasonValidClientCertificate),
|
||||
Traces: []contextutil.PolicyEvaluationTrace{{}},
|
||||
}, output)
|
||||
})
|
||||
|
@ -192,16 +195,16 @@ func TestPolicyEvaluator(t *testing.T) {
|
|||
output, err := eval(t,
|
||||
p,
|
||||
[]proto.Message{s1, u1, s2, u2},
|
||||
&PolicyRequest{
|
||||
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: RequestSession{ID: "s1"},
|
||||
&evaluator.PolicyRequest{
|
||||
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: evaluator.RequestSession{ID: "s1"},
|
||||
|
||||
IsValidClientCertificate: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &PolicyResponse{
|
||||
Allow: NewRuleResult(true, criteria.ReasonAccept),
|
||||
Deny: NewRuleResult(false, criteria.ReasonValidClientCertificate),
|
||||
assert.Equal(t, &evaluator.PolicyResponse{
|
||||
Allow: evaluator.NewRuleResult(true, criteria.ReasonAccept),
|
||||
Deny: evaluator.NewRuleResult(false, criteria.ReasonValidClientCertificate),
|
||||
Traces: []contextutil.PolicyEvaluationTrace{{}, {ID: "p1", Allow: true}},
|
||||
}, output)
|
||||
})
|
||||
|
@ -222,16 +225,16 @@ func TestPolicyEvaluator(t *testing.T) {
|
|||
output, err := eval(t,
|
||||
p,
|
||||
[]proto.Message{s1, u1, s2, u2},
|
||||
&PolicyRequest{
|
||||
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: RequestSession{ID: "s1"},
|
||||
&evaluator.PolicyRequest{
|
||||
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: evaluator.RequestSession{ID: "s1"},
|
||||
|
||||
IsValidClientCertificate: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &PolicyResponse{
|
||||
Allow: NewRuleResult(false),
|
||||
Deny: NewRuleResult(true, criteria.ReasonAccept),
|
||||
assert.Equal(t, &evaluator.PolicyResponse{
|
||||
Allow: evaluator.NewRuleResult(false),
|
||||
Deny: evaluator.NewRuleResult(true, criteria.ReasonAccept),
|
||||
Traces: []contextutil.PolicyEvaluationTrace{{}, {ID: "p1", Deny: true}},
|
||||
}, output)
|
||||
})
|
||||
|
@ -253,16 +256,16 @@ func TestPolicyEvaluator(t *testing.T) {
|
|||
output, err := eval(t,
|
||||
p,
|
||||
[]proto.Message{s1, u1, s2, u2},
|
||||
&PolicyRequest{
|
||||
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: RequestSession{ID: "s1"},
|
||||
&evaluator.PolicyRequest{
|
||||
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: evaluator.RequestSession{ID: "s1"},
|
||||
|
||||
IsValidClientCertificate: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &PolicyResponse{
|
||||
Allow: NewRuleResult(false),
|
||||
Deny: NewRuleResult(true, criteria.ReasonAccept, criteria.ReasonClientCertificateRequired),
|
||||
assert.Equal(t, &evaluator.PolicyResponse{
|
||||
Allow: evaluator.NewRuleResult(false),
|
||||
Deny: evaluator.NewRuleResult(true, criteria.ReasonAccept, criteria.ReasonClientCertificateRequired),
|
||||
Traces: []contextutil.PolicyEvaluationTrace{{Deny: true}, {ID: "p1", Deny: true}},
|
||||
}, output)
|
||||
})
|
||||
|
@ -292,16 +295,16 @@ func TestPolicyEvaluator(t *testing.T) {
|
|||
output, err := eval(t,
|
||||
p,
|
||||
[]proto.Message{s1, u1, s2, u2, r1},
|
||||
&PolicyRequest{
|
||||
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: RequestSession{ID: "s1"},
|
||||
&evaluator.PolicyRequest{
|
||||
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: evaluator.RequestSession{ID: "s1"},
|
||||
|
||||
IsValidClientCertificate: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &PolicyResponse{
|
||||
Allow: NewRuleResult(true),
|
||||
Deny: NewRuleResult(false, criteria.ReasonValidClientCertificate),
|
||||
assert.Equal(t, &evaluator.PolicyResponse{
|
||||
Allow: evaluator.NewRuleResult(true),
|
||||
Deny: evaluator.NewRuleResult(false, criteria.ReasonValidClientCertificate),
|
||||
Traces: []contextutil.PolicyEvaluationTrace{{}, {ID: "p1", Allow: true}},
|
||||
}, output)
|
||||
})
|
||||
|
@ -315,16 +318,16 @@ func TestPolicyEvaluator(t *testing.T) {
|
|||
UserId: "u1",
|
||||
},
|
||||
},
|
||||
&PolicyRequest{
|
||||
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: RequestSession{ID: "sa1"},
|
||||
&evaluator.PolicyRequest{
|
||||
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: evaluator.RequestSession{ID: "sa1"},
|
||||
|
||||
IsValidClientCertificate: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &PolicyResponse{
|
||||
Allow: NewRuleResult(true, criteria.ReasonEmailOK),
|
||||
Deny: NewRuleResult(false, criteria.ReasonValidClientCertificate),
|
||||
assert.Equal(t, &evaluator.PolicyResponse{
|
||||
Allow: evaluator.NewRuleResult(true, criteria.ReasonEmailOK),
|
||||
Deny: evaluator.NewRuleResult(false, criteria.ReasonValidClientCertificate),
|
||||
Traces: []contextutil.PolicyEvaluationTrace{{Allow: true}},
|
||||
}, output)
|
||||
})
|
||||
|
@ -339,16 +342,16 @@ func TestPolicyEvaluator(t *testing.T) {
|
|||
ExpiresAt: timestamppb.New(time.Now().Add(-time.Second)),
|
||||
},
|
||||
},
|
||||
&PolicyRequest{
|
||||
HTTP: RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: RequestSession{ID: "sa1"},
|
||||
&evaluator.PolicyRequest{
|
||||
HTTP: evaluator.RequestHTTP{Method: http.MethodGet, URL: "https://from.example.com/path"},
|
||||
Session: evaluator.RequestSession{ID: "sa1"},
|
||||
|
||||
IsValidClientCertificate: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &PolicyResponse{
|
||||
Allow: NewRuleResult(false, criteria.ReasonUserUnauthenticated),
|
||||
Deny: NewRuleResult(false, criteria.ReasonValidClientCertificate),
|
||||
assert.Equal(t, &evaluator.PolicyResponse{
|
||||
Allow: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),
|
||||
Deny: evaluator.NewRuleResult(false, criteria.ReasonValidClientCertificate),
|
||||
Traces: []contextutil.PolicyEvaluationTrace{{Allow: false}},
|
||||
}, output)
|
||||
})
|
||||
|
|
|
@ -17,9 +17,10 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/config/envoyconfig"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
@ -49,15 +50,27 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
|
|||
-----END CERTIFICATE-----`
|
||||
|
||||
func Test_getEvaluatorRequest(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.currentOptions.Store(&config.Options{
|
||||
Policies: []config.Policy{{
|
||||
policies := []config.Policy{{
|
||||
From: "https://example.com",
|
||||
To: mustParseWeightedURLs(t, "https://foo.bar"),
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
}}
|
||||
|
||||
policy0RouteID, err := policies[0].RouteID()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{
|
||||
Options: &config.Options{
|
||||
SharedKey: cryptutil.NewBase64Key(),
|
||||
CookieSecret: cryptutil.NewBase64Key(),
|
||||
Policies: policies,
|
||||
},
|
||||
}
|
||||
a := New()
|
||||
a.OnConfigChange(context.Background(), cfg)
|
||||
require.True(t, a.HasValidState())
|
||||
|
||||
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
|
||||
&envoy_service_auth_v3.CheckRequest{
|
||||
|
@ -76,6 +89,7 @@ func Test_getEvaluatorRequest(t *testing.T) {
|
|||
Body: "BODY",
|
||||
},
|
||||
},
|
||||
ContextExtensions: envoyconfig.MakeExtAuthzContextExtensions(false, policy0RouteID),
|
||||
MetadataContext: &envoy_config_core_v3.Metadata{
|
||||
FilterMetadata: map[string]*structpb.Struct{
|
||||
"com.pomerium.client-certificate-info": {
|
||||
|
@ -117,15 +131,27 @@ func Test_getEvaluatorRequest(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.currentOptions.Store(&config.Options{
|
||||
Policies: []config.Policy{{
|
||||
policies := []config.Policy{{
|
||||
From: "https://example.com",
|
||||
To: mustParseWeightedURLs(t, "https://foo.bar"),
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
}}
|
||||
|
||||
policy0RouteID, err := policies[0].RouteID()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{
|
||||
Options: &config.Options{
|
||||
SharedKey: cryptutil.NewBase64Key(),
|
||||
CookieSecret: cryptutil.NewBase64Key(),
|
||||
Policies: policies,
|
||||
},
|
||||
}
|
||||
a := New()
|
||||
a.OnConfigChange(context.Background(), cfg)
|
||||
require.True(t, a.HasValidState())
|
||||
|
||||
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
|
||||
&envoy_service_auth_v3.CheckRequest{
|
||||
|
@ -144,11 +170,12 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
|||
Body: "BODY",
|
||||
},
|
||||
},
|
||||
ContextExtensions: envoyconfig.MakeExtAuthzContextExtensions(false, policy0RouteID),
|
||||
},
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
expect := &evaluator.Request{
|
||||
Policy: &a.currentOptions.Load().Policies[0],
|
||||
Policy: &policies[0],
|
||||
Session: evaluator.RequestSession{},
|
||||
HTTP: evaluator.NewRequestHTTP(
|
||||
http.MethodGet,
|
||||
|
|
2
go.mod
2
go.mod
|
@ -80,6 +80,7 @@ require (
|
|||
go.uber.org/mock v0.5.0
|
||||
go.uber.org/zap v1.27.0
|
||||
golang.org/x/crypto v0.28.0
|
||||
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa
|
||||
golang.org/x/net v0.30.0
|
||||
golang.org/x/oauth2 v0.23.0
|
||||
golang.org/x/sync v0.8.0
|
||||
|
@ -227,7 +228,6 @@ require (
|
|||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.28.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.3.1 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect
|
||||
golang.org/x/mod v0.20.0 // indirect
|
||||
golang.org/x/text v0.19.0 // indirect
|
||||
golang.org/x/tools v0.24.0 // indirect
|
||||
|
|
|
@ -210,16 +210,13 @@ func setupAuthenticate(ctx context.Context, src config.Source, controlPlane *con
|
|||
}
|
||||
|
||||
func setupAuthorize(ctx context.Context, src config.Source, controlPlane *controlplane.Server) (*authorize.Authorize, error) {
|
||||
svc, err := authorize.New(ctx, src.GetConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating authorize service: %w", err)
|
||||
}
|
||||
envoy_service_auth_v3.RegisterAuthorizationServer(controlPlane.GRPCServer, svc)
|
||||
a := authorize.New()
|
||||
envoy_service_auth_v3.RegisterAuthorizationServer(controlPlane.GRPCServer, a)
|
||||
|
||||
log.Ctx(ctx).Info().Msg("enabled authorize service")
|
||||
src.OnConfigChange(ctx, svc.OnConfigChange)
|
||||
svc.OnConfigChange(ctx, src.GetConfig())
|
||||
return svc, nil
|
||||
src.OnConfigChange(ctx, a.OnConfigChange)
|
||||
a.OnConfigChange(ctx, src.GetConfig())
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func setupDataBroker(ctx context.Context,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue