Mirror of https://github.com/pomerium/pomerium.git, synced 2025-05-01 19:36:32 +02:00
This optimizes the Evaluator in the Authorize service to scale to very large numbers of routes. Additional caching was also added when building Rego policy query evaluators in parallel, so that concurrent builds can share work and evaluators are not rebuilt for scripts with identical contents.
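The build-sharing idea, in miniature: prepared queries are keyed by a checksum of their Rego script, and concurrent builds for the same checksum are collapsed into a single compilation via singleflight. The snippet below is only a hedged, self-contained sketch with made-up names (scriptCache, lookupOrBuild, a plain string standing in for a prepared query); the actual implementation is QueryCache.LookupOrBuild in the file below.

package main

import (
	"fmt"
	"sync"

	"golang.org/x/sync/singleflight"
)

// scriptCache is a stand-in for QueryCache: it maps script checksums to
// "prepared queries" (plain strings here, rego.PreparedEvalQuery in the real code).
type scriptCache struct {
	mu sync.RWMutex
	m  map[string]string
	sf singleflight.Group
}

func (c *scriptCache) lookupOrBuild(checksum string, build func() (string, error)) (string, error) {
	c.mu.RLock()
	q, ok := c.m[checksum]
	c.mu.RUnlock()
	if ok {
		return q, nil // reuse the query built earlier for an identical script
	}
	v, err, _ := c.sf.Do(checksum, func() (any, error) {
		q, err := build() // the expensive compile step runs at most once per checksum
		if err == nil {
			c.mu.Lock()
			c.m[checksum] = q
			c.mu.Unlock()
		}
		return q, err
	})
	if err != nil {
		return "", err
	}
	return v.(string), nil
}

func main() {
	c := &scriptCache{m: map[string]string{}}
	q, _ := c.lookupOrBuild("checksum-of-script", func() (string, error) { return "prepared query", nil })
	fmt.Println(q)
}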
// Package evaluator contains rego evaluators for evaluating authorize policy.
package evaluator

import (
	"bytes"
	"context"
	"encoding/base64"
	"fmt"
	"maps"
	"math/bits"
	"net/http"
	"net/url"
	"runtime"
	rttrace "runtime/trace"
	"sync"
	"sync/atomic"
	"time"
	"unsafe"

	"github.com/go-jose/go-jose/v3"
	"github.com/hashicorp/go-set/v3"
	"github.com/open-policy-agent/opa/rego"
	"golang.org/x/sync/errgroup"
	"golang.org/x/sync/singleflight"

	"github.com/pomerium/pomerium/authorize/internal/store"
	"github.com/pomerium/pomerium/config"
	"github.com/pomerium/pomerium/internal/httputil"
	"github.com/pomerium/pomerium/internal/log"
	"github.com/pomerium/pomerium/internal/telemetry/trace"
	"github.com/pomerium/pomerium/pkg/contextutil"
	"github.com/pomerium/pomerium/pkg/cryptutil"
	"github.com/pomerium/pomerium/pkg/policy/criteria"
)

// Request contains the inputs needed for evaluation.
type Request struct {
	IsInternal bool
	Policy     *config.Policy
	HTTP       RequestHTTP
	Session    RequestSession
}

// RequestHTTP is the HTTP field in the request.
type RequestHTTP struct {
	Method            string                `json:"method"`
	Hostname          string                `json:"hostname"`
	Path              string                `json:"path"`
	URL               string                `json:"url"`
	Headers           map[string]string     `json:"headers"`
	ClientCertificate ClientCertificateInfo `json:"client_certificate"`
	IP                string                `json:"ip"`
}

// NewRequestHTTP creates a new RequestHTTP.
func NewRequestHTTP(
	method string,
	requestURL url.URL,
	headers map[string]string,
	clientCertificate ClientCertificateInfo,
	ip string,
) RequestHTTP {
	return RequestHTTP{
		Method:            method,
		Hostname:          requestURL.Hostname(),
		Path:              requestURL.Path,
		URL:               requestURL.String(),
		Headers:           headers,
		ClientCertificate: clientCertificate,
		IP:                ip,
	}
}

// ClientCertificateInfo contains information about the certificate presented
// by the client (if any).
type ClientCertificateInfo struct {
	// Presented is true if the client presented a certificate.
	Presented bool `json:"presented"`

	// Leaf contains the leaf client certificate (unvalidated).
	Leaf string `json:"leaf,omitempty"`

	// Intermediates contains the remainder of the client certificate chain as
	// it was originally presented by the client (unvalidated).
	Intermediates string `json:"intermediates,omitempty"`
}

// RequestSession is the session field in the request.
type RequestSession struct {
	ID string `json:"id"`
}

// Result is the result of evaluation.
type Result struct {
	Allow   RuleResult
	Deny    RuleResult
	Headers http.Header
	Traces  []contextutil.PolicyEvaluationTrace
}

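// PolicyEvaluatorCacheStats reports cumulative cache hit and miss counts for a
// PolicyEvaluatorCache.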
type PolicyEvaluatorCacheStats struct {
	CacheHits   int64
	CacheMisses int64
}

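// PolicyEvaluatorCache caches compiled PolicyEvaluators by route ID so that
// evaluators for unchanged policies can be reused across configuration updates.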
type PolicyEvaluatorCache struct {
	evalsMu             sync.RWMutex
	evaluatorsByRouteID map[uint64]*PolicyEvaluator

	cacheHits   atomic.Int64
	cacheMisses atomic.Int64
}

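// QueryCacheStats reports cumulative cache hit, miss, and build counters for a
// QueryCache.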
type QueryCacheStats struct {
	CacheHits       int64
	CacheMisses     int64
	BuildsSucceeded int64
	BuildsFailed    int64
	BuildsShared    int64
}

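// QueryCache caches prepared rego queries keyed by a checksum of their script
// contents, so that queries for identical scripts are only built once.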
type QueryCache struct {
	queriesMu               sync.RWMutex
	queriesByScriptChecksum map[string]rego.PreparedEvalQuery
	sf                      singleflight.Group

	cacheHits       atomic.Int64
	cacheMisses     atomic.Int64
	buildsSucceeded atomic.Int64
	buildsFailed    atomic.Int64
	buildsShared    atomic.Int64
}

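// NewPolicyEvaluatorCache creates a new PolicyEvaluatorCache sized for
// initialSize evaluators.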
func NewPolicyEvaluatorCache(initialSize int) *PolicyEvaluatorCache {
	return &PolicyEvaluatorCache{
		evaluatorsByRouteID: make(map[uint64]*PolicyEvaluator, initialSize),
	}
}

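// NewQueryCache creates a new QueryCache sized for initialSize queries.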
func NewQueryCache(initialSize int) *QueryCache {
	return &QueryCache{
		queriesByScriptChecksum: make(map[string]rego.PreparedEvalQuery, initialSize),
	}
}

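// NumCachedEvaluators returns the number of evaluators currently stored in the cache.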
func (c *PolicyEvaluatorCache) NumCachedEvaluators() int {
	c.evalsMu.RLock()
	defer c.evalsMu.RUnlock()
	return len(c.evaluatorsByRouteID)
}

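// StoreEvaluator stores the evaluator for the given route ID, replacing any
// existing entry.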
func (c *PolicyEvaluatorCache) StoreEvaluator(routeID uint64, eval *PolicyEvaluator) {
	c.evalsMu.Lock()
	defer c.evalsMu.Unlock()
	c.evaluatorsByRouteID[routeID] = eval
}

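// LookupEvaluator returns the cached evaluator for the given route ID, if
// present, and records a cache hit or miss.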
func (c *PolicyEvaluatorCache) LookupEvaluator(routeID uint64) (*PolicyEvaluator, bool) {
	c.evalsMu.RLock()
	defer c.evalsMu.RUnlock()
	eval, ok := c.evaluatorsByRouteID[routeID]
	if ok {
		c.cacheHits.Add(1)
	} else {
		c.cacheMisses.Add(1)
	}
	return eval, ok
}

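// Stats returns a snapshot of the cache hit/miss counters.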
func (c *PolicyEvaluatorCache) Stats() PolicyEvaluatorCacheStats {
	return PolicyEvaluatorCacheStats{
		CacheHits:   c.cacheHits.Load(),
		CacheMisses: c.cacheMisses.Load(),
	}
}

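// NumCachedQueries returns the number of prepared queries currently stored in the cache.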
func (c *QueryCache) NumCachedQueries() int {
	c.queriesMu.RLock()
	defer c.queriesMu.RUnlock()
	return len(c.queriesByScriptChecksum)
}

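// Stats returns a snapshot of the query cache counters.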
func (c *QueryCache) Stats() QueryCacheStats {
	return QueryCacheStats{
		CacheHits:       c.cacheHits.Load(),
		CacheMisses:     c.cacheMisses.Load(),
		BuildsSucceeded: c.buildsSucceeded.Load(),
		BuildsFailed:    c.buildsFailed.Load(),
		BuildsShared:    c.buildsShared.Load(),
	}
}

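// LookupOrBuild returns the cached prepared query for q's script checksum, or
// builds it using builder, collapsing concurrent builds for the same checksum
// into a single call via singleflight. The boolean return value reports whether
// the query came from the cache.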
func (c *QueryCache) LookupOrBuild(q *policyQuery, builder func() (rego.PreparedEvalQuery, error)) (rego.PreparedEvalQuery, bool, error) {
	checksum := q.checksum()
	c.queriesMu.RLock()
	cached, ok := c.queriesByScriptChecksum[checksum]
	c.queriesMu.RUnlock()
	if ok {
		c.cacheHits.Add(1)
		return cached, true, nil
	}
	c.cacheMisses.Add(1)
	var ours bool
	pq, err, shared := c.sf.Do(checksum, func() (any, error) {
		ours = true
		res, err := builder()
		if err == nil {
			c.queriesMu.Lock()
			c.queriesByScriptChecksum[checksum] = res
			c.queriesMu.Unlock()
			c.buildsSucceeded.Add(1)
		} else {
			c.buildsFailed.Add(1)
		}
		return res, err
	})
	if err != nil {
		return rego.PreparedEvalQuery{}, false, err
	}
	if shared && !ours {
		c.buildsShared.Add(1)
	}
	return pq.(rego.PreparedEvalQuery), false, nil
}

// An Evaluator evaluates policies.
type Evaluator struct {
	opts             *evaluatorConfig
	store            *store.Store
	evalCache        *PolicyEvaluatorCache
	queryCache       *QueryCache
	headersEvaluator *HeadersEvaluator
}

// New creates a new Evaluator.
func New(
	ctx context.Context,
	store *store.Store,
	previous *Evaluator,
	options ...Option,
) (*Evaluator, error) {
	ctx, task := rttrace.NewTask(ctx, "evaluator.New")
	defer task.End()
	defer rttrace.StartRegion(ctx, "evaluator.New").End()

	var opts evaluatorConfig
	opts.apply(options...)

	e := &Evaluator{
		opts:  &opts,
		store: store,
	}

	if previous == nil || opts.cacheKey() != previous.opts.cacheKey() || store != previous.store {
		var err error
		rttrace.WithRegion(ctx, "update store", func() {
			err = updateStore(ctx, store, &opts, previous)
		})
		if err != nil {
			return nil, err
		}

		rttrace.WithRegion(ctx, "create headers evaluator", func() {
			e.headersEvaluator, err = NewHeadersEvaluator(ctx, store)
		})
		if err != nil {
			return nil, err
		}

		e.evalCache = NewPolicyEvaluatorCache(len(opts.Policies))
		e.queryCache = NewQueryCache(len(opts.Policies))
	} else {
		// If there is a previous Evaluator constructed from the same settings, we
		// can reuse the HeadersEvaluator along with any PolicyEvaluators for
		// unchanged policies.
		e.headersEvaluator = previous.headersEvaluator
		e.evalCache = previous.evalCache
		e.queryCache = previous.queryCache
	}

	var err error
	rttrace.WithRegion(ctx, "update policy evaluators", func() {
		err = getOrCreatePolicyEvaluators(ctx, &opts, store, e.evalCache, e.queryCache)
	})
	if err != nil {
		return nil, err
	}

	return e, nil
}

var (
	workerPoolSize      = runtime.NumCPU() - 1
	workerPoolMu        sync.Mutex
	workerPoolTaskQueue = make(chan func(), (workerPoolSize+1)*2)
)

func init() {
	for i := 0; i < workerPoolSize; i++ {
		// the worker function is separate so that it shows up in stack traces as
		// 'worker' instead of an anonymous function in init()
		go worker()
	}
}

func worker() {
	workerPoolMu.Lock()
	queue := workerPoolTaskQueue
	workerPoolMu.Unlock()
	for fn := range queue {
		fn()
	}
}

type chunkSizes interface {
	uint8 | uint16 | uint32 | uint64
}

// returns the size in bits for any allowed type in chunkSizes
func chunkSize[T chunkSizes]() int {
	return int(unsafe.Sizeof(T(0))) * 8
}

type workerContext[T chunkSizes] struct {
	context.Context
	Cfg        *evaluatorConfig
	Store      *store.Store
	StatusBits []T
	Evaluators []routeEvaluator
	EvalCache  *PolicyEvaluatorCache
	QueryCache *QueryCache
	Errs       *sync.Map
}

type routeEvaluator struct {
	ID               uint64           // route id
	Evaluator        *PolicyEvaluator // the compiled evaluator
	ComputedChecksum uint64           // cached evaluator checksum
}

// partition represents a range of chunks (fixed-size blocks of policies)
// corresponding to the Cfg.Policies, StatusBits, and Evaluators fields in
// the workerContext. It is the slice [Begin*chunkSize:End*chunkSize] w.r.t.
// those fields, but each index represents a unit of work that can be done in
// parallel with work on other chunks.
type partition struct{ Begin, End int }

// computeChecksums is a worker task that computes policy checksums and updates
// StatusBits to flag policies that need to be rebuilt. It operates on entire
// chunks (fixed-size blocks of policies), given by the start and end indexes
// in the partition argument, and updates the corresponding indexes of the
// StatusBits field of the worker context for those chunks.
func computeChecksums[T chunkSizes](wctx *workerContext[T], part partition) {
	defer rttrace.StartRegion(wctx, "worker-checksum").End()
	for chunkIdx := part.Begin; chunkIdx < part.End; chunkIdx++ {
		var chunkStatus T
		chunkSize := chunkSize[T]()
		off := chunkIdx * chunkSize // chunk offset
		// If there are fewer than chunkSize policies remaining in the actual list,
		// don't go beyond the end
		limit := min(chunkSize, len(wctx.Cfg.Policies)-off)
		popcount := 0
		for i := range limit {
			p := wctx.Cfg.Policies[off+i]
			// Compute the route id; this value is reused later as the route name
			// when computing the checksum
			id, err := p.RouteID()
			if err != nil {
				wctx.Errs.Store(off+i, fmt.Errorf("authorize: error computing policy route id: %w", err))
				continue
			}
			eval := &wctx.Evaluators[off+i]
			eval.ID = id
			// Compute the policy checksum and cache it in the evaluator, reusing
			// the route ID from before (to avoid needing to compute it again)
			eval.ComputedChecksum = p.Checksum()
			// eval.ComputedChecksum = p.ChecksumWithID(id) // TODO: update this when merged

			// Check if there is an existing evaluator cached for the route ID
			// NB: the route ID is composed of a subset of fields of the Policy; this
			// means the cache will hit if the route ID fields are the same, even if
			// other fields in the policy differ.
			cached, ok := wctx.EvalCache.LookupEvaluator(id)
			if !ok {
				rttrace.Logf(wctx, "", "policy for route ID %d not found in cache", id)
				chunkStatus |= T(1 << i)
				popcount++
			} else if cached.policyChecksum != eval.ComputedChecksum {
				// Route ID is the same, but the full checksum differs
				rttrace.Logf(wctx, "", "policy for route ID %d changed", id)
				chunkStatus |= T(1 << i)
				popcount++
			}
			// On a cache hit, chunkStatus for the ith bit stays at 0
		}
		// Set chunkStatus bitmask all at once (for better locality)
		wctx.StatusBits[chunkIdx] = chunkStatus
		rttrace.Logf(wctx, "", "chunk %d: %d/%d changed", chunkIdx, popcount, limit)
	}
}

// buildEvaluators is a worker task that creates new policy evaluators. It
// operates on entire chunks (fixed-size blocks of policies), given by the start
// and end indexes in the partition argument, and updates the corresponding
// indexes of the Evaluators field of the worker context for those chunks.
func buildEvaluators[T chunkSizes](wctx *workerContext[T], part partition) {
	chunkSize := chunkSize[T]()
	defer rttrace.StartRegion(wctx, "worker-build").End()
	addDefaultCert := wctx.Cfg.AddDefaultClientCertificateRule
	var err error
	for chunkIdx := part.Begin; chunkIdx < part.End; chunkIdx++ {
		// Obtain the bitmask computed by computeChecksums for this chunk
		stat := wctx.StatusBits[chunkIdx]
		rttrace.Logf(wctx, "", "chunk %d: status: %0*b", chunkIdx, chunkSize, stat)

		// Iterate over all the set bits in stat. This works by finding the
		// lowest set bit, zeroing it, and repeating. The go compiler will
		// replace [bits.TrailingZeros64] with intrinsics on most platforms.
		for stat != 0 {
			bit := bits.TrailingZeros64(uint64(stat)) // find the lowest set bit
			stat &= (stat - 1)                        // clear the lowest set bit
			idx := (chunkSize * chunkIdx) + bit
			p := wctx.Cfg.Policies[idx]
			eval := &wctx.Evaluators[idx]
			eval.Evaluator, err = NewPolicyEvaluator(wctx, wctx.Store, p, eval.ComputedChecksum, addDefaultCert, wctx.QueryCache)
			if err != nil {
				wctx.Errs.Store(idx, err)
			}
		}
	}
}

// bestChunkSize determines the chunk size (8, 16, 32, or 64) to use for the
// given number of policies and workers.
func bestChunkSize(numPolicies, numWorkers int) int {
	// use the chunk size that results in the largest number of chunks without
	// going past the number of workers. this results in the following behavior:
	// - as the number of policies increases, chunk size tends to increase
	// - as the number of workers increases, chunk size tends to decrease
	sizes := []int{64, 32, 16, 8}
	sizeIdx := 0
	for i, size := range sizes {
		if float64(numPolicies)/float64(size) > float64(numWorkers) {
			break
		}
		sizeIdx = i
	}
	return sizes[sizeIdx]
}

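// getOrCreatePolicyEvaluators builds or reuses a PolicyEvaluator for every
// policy in cfg, dispatching to the generic implementation parameterized by the
// chunk size chosen by bestChunkSize.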
func getOrCreatePolicyEvaluators(
	ctx context.Context,
	cfg *evaluatorConfig,
	store *store.Store,
	evalCache *PolicyEvaluatorCache,
	queryCache *QueryCache,
) error {
	chunkSize := bestChunkSize(len(cfg.Policies), workerPoolSize)
	switch chunkSize {
	case 8:
		return getOrCreatePolicyEvaluatorsT[uint8](ctx, cfg, store, evalCache, queryCache)
	case 16:
		return getOrCreatePolicyEvaluatorsT[uint16](ctx, cfg, store, evalCache, queryCache)
	case 32:
		return getOrCreatePolicyEvaluatorsT[uint32](ctx, cfg, store, evalCache, queryCache)
	case 64:
		return getOrCreatePolicyEvaluatorsT[uint64](ctx, cfg, store, evalCache, queryCache)
	}
	panic("unreachable")
}

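// getOrCreatePolicyEvaluatorsT is the chunk-size-parameterized implementation of
// getOrCreatePolicyEvaluators: it partitions the policy list into chunks,
// computes checksums and builds evaluators in parallel on the worker pool, and
// stores any newly built evaluators in the cache.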
func getOrCreatePolicyEvaluatorsT[T chunkSizes](
	ctx context.Context,
	cfg *evaluatorConfig,
	store *store.Store,
	evalCache *PolicyEvaluatorCache,
	queryCache *QueryCache,
) error {
	chunkSize := bestChunkSize(len(cfg.Policies), workerPoolSize)
	rttrace.Logf(ctx, "", "eval cache size: %d; query cache size: %d; chunk size: %d",
		evalCache.NumCachedEvaluators(), queryCache.NumCachedQueries(), chunkSize)
	now := time.Now()

	// Split the policy list into chunks which can individually be operated on in
	// parallel with other chunks. Each chunk has a corresponding bitmask in the
	// statusBits list which is used to indicate to workers which policy
	// evaluators (at indexes in the chunk corresponding to set bits) need to be
	// built, or rebuilt due to changes.
	numChunks := len(cfg.Policies) / chunkSize
	if len(cfg.Policies)%chunkSize != 0 {
		numChunks++
	}
	statusBits := make([]T, numChunks) // bits map directly to policy indexes
	evaluators := make([]routeEvaluator, len(cfg.Policies))
	if len(evaluators) == 0 {
		return nil // nothing to do
	}
	// Limit the number of workers to the size of the worker pool; since we are
	// manually distributing chunks between workers, we can avoid spawning more
	// goroutines than we need, and instead giving each worker additional chunks.
	numWorkers := min(workerPoolSize, numChunks)
	// Each worker is given a minimum number of chunks, then the remainder are
	// spread evenly between workers.
	minChunksPerWorker := numChunks / numWorkers
	overflow := numChunks % numWorkers // number of workers which get an additional chunk

	wctx := &workerContext[T]{
		Context:    ctx,
		Cfg:        cfg,
		Store:      store,
		StatusBits: statusBits,
		Evaluators: evaluators,
		EvalCache:  evalCache,
		QueryCache: queryCache,
		Errs:       &sync.Map{}, // policy index->error
	}

	// First, build a list of partitions (start/end chunk indexes) to send to
	// each worker.
	partitions := make([]partition, numWorkers)
	rttrace.WithRegion(ctx, "partitioning", func() {
		chunkIdx := 0
		for workerIdx := range numWorkers {
			chunkStart := chunkIdx
			chunkEnd := chunkStart + minChunksPerWorker
			if workerIdx < overflow {
				chunkEnd++
			}
			chunkIdx = chunkEnd
			partitions[workerIdx] = partition{chunkStart, chunkEnd}
		}
	})

	// Compute all route checksums in parallel to determine which routes need to
	// be rebuilt.
	rttrace.WithRegion(ctx, "computing checksums", func() {
		var wg sync.WaitGroup
		for _, part := range partitions {
			wg.Add(1)
			workerPoolTaskQueue <- func() {
				defer wg.Done()
				computeChecksums(wctx, part)
			}
		}
		wg.Wait()
	})

	hasErrs := false
	wctx.Errs.Range(func(key, value any) bool {
		log.Ctx(ctx).Error().Int("policy-index", key.(int)).Msg(value.(error).Error())
		hasErrs = true
		return true
	})
	if hasErrs {
		return fmt.Errorf("authorize: error computing one or more policy route IDs")
	}

	// After all checksums are computed and status bits populated, build the
	// required evaluators.
	rttrace.WithRegion(ctx, "building evaluators", func() {
		var wg sync.WaitGroup
		for _, part := range partitions {
			// Adjust the partition to skip over chunks with 0 bits set
			for part.Begin < part.End && statusBits[part.Begin] == 0 {
				part.Begin++
			}
			for part.Begin < (part.End)-1 && statusBits[part.End-1] == 0 {
				part.End--
			}
			if part.Begin == part.End {
				continue
			}
			wg.Add(1)
			workerPoolTaskQueue <- func() {
				defer wg.Done()
				buildEvaluators(wctx, part)
			}
		}
		wg.Wait()
	})

	hasErrs = false
	wctx.Errs.Range(func(key, value any) bool {
		log.Ctx(ctx).Error().Int("policy-index", key.(int)).Msg(value.(error).Error())
		hasErrs = true
		return true
	})
	if hasErrs {
		return fmt.Errorf("authorize: error building policy evaluators")
	}

	// Store updated evaluators in the cache
	updatedCount := 0
	for _, p := range evaluators {
		if p.Evaluator != nil { // these are only set when modified
			updatedCount++
			evalCache.StoreEvaluator(p.ID, p.Evaluator)
		}
	}

	log.Ctx(ctx).Debug().
		Dur("duration", time.Since(now)).
		Int("reused-policies", len(cfg.Policies)-updatedCount).
		Int("created-policies", updatedCount).
		Msg("updated policy evaluators")
	return nil
}

// Evaluate evaluates the rego for the given policy and generates the identity headers.
func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error) {
	ctx, span := trace.StartSpan(ctx, "authorize.Evaluator.Evaluate")
	defer span.End()

	eg, ctx := errgroup.WithContext(ctx)

	var policyOutput *PolicyResponse
	eg.Go(func() error {
		var err error
		if req.IsInternal {
			policyOutput, err = e.evaluateInternal(ctx, req)
		} else {
			policyOutput, err = e.evaluatePolicy(ctx, req)
		}
		return err
	})

	var headersOutput *HeadersResponse
	eg.Go(func() error {
		var err error
		headersOutput, err = e.evaluateHeaders(ctx, req)
		return err
	})

	err := eg.Wait()
	if err != nil {
		return nil, err
	}

	res := &Result{
		Allow:   policyOutput.Allow,
		Deny:    policyOutput.Deny,
		Headers: headersOutput.Headers,
		Traces:  policyOutput.Traces,
	}
	return res, nil
}

// Internal endpoints that require a logged-in user.
var internalPathsNeedingLogin = set.From([]string{
	"/.pomerium/jwt",
	"/.pomerium/user",
	"/.pomerium/webauthn",
})

func (e *Evaluator) evaluateInternal(_ context.Context, req *Request) (*PolicyResponse, error) {
	if internalPathsNeedingLogin.Contains(req.HTTP.Path) {
		if req.Session.ID == "" {
			return &PolicyResponse{
				Allow: NewRuleResult(false, criteria.ReasonUserUnauthenticated),
			}, nil
		}
	}

	return &PolicyResponse{
		Allow: NewRuleResult(true, criteria.ReasonPomeriumRoute),
	}, nil
}

func (e *Evaluator) evaluatePolicy(ctx context.Context, req *Request) (*PolicyResponse, error) {
	if req.Policy == nil {
		return &PolicyResponse{
			Deny: NewRuleResult(true, criteria.ReasonRouteNotFound),
		}, nil
	}

	id, err := req.Policy.RouteID()
	if err != nil {
		return nil, fmt.Errorf("authorize: error computing policy route id: %w", err)
	}

	policyEvaluator, ok := e.evalCache.LookupEvaluator(id)
	if !ok {
		return &PolicyResponse{
			Deny: NewRuleResult(true, criteria.ReasonRouteNotFound),
		}, nil
	}

	clientCA, err := e.getClientCA(req.Policy)
	if err != nil {
		return nil, err
	}

	isValidClientCertificate, err := isValidClientCertificate(
		clientCA, string(e.opts.ClientCRL), req.HTTP.ClientCertificate, e.opts.ClientCertConstraints)
	if err != nil {
		return nil, fmt.Errorf("authorize: error validating client certificate: %w", err)
	}

	return policyEvaluator.Evaluate(ctx, &PolicyRequest{
		HTTP:                     req.HTTP,
		Session:                  req.Session,
		IsValidClientCertificate: isValidClientCertificate,
	})
}

func (e *Evaluator) evaluateHeaders(ctx context.Context, req *Request) (*HeadersResponse, error) {
	headersReq, err := NewHeadersRequestFromPolicy(req.Policy, req.HTTP)
	if err != nil {
		return nil, err
	}
	headersReq.Session = req.Session
	res, err := e.headersEvaluator.Evaluate(ctx, headersReq)
	if err != nil {
		return nil, err
	}

	carryOverJWTAssertion(res.Headers, req.HTTP.Headers)

	return res, nil
}

func (e *Evaluator) getClientCA(policy *config.Policy) (string, error) {
	if policy != nil && policy.TLSDownstreamClientCA != "" {
		bs, err := base64.StdEncoding.DecodeString(policy.TLSDownstreamClientCA)
		if err != nil {
			return "", err
		}
		return string(bs), nil
	}

	return string(e.opts.ClientCA), nil
}

func updateStore(ctx context.Context, store *store.Store, cfg *evaluatorConfig, previous *Evaluator) error {
	if previous == nil || !bytes.Equal(cfg.SigningKey, previous.opts.SigningKey) {
		jwk, err := getJWK(ctx, cfg.SigningKey)
		if err != nil {
			return fmt.Errorf("authorize: couldn't create signer: %w", err)
		}
		store.UpdateSigningKey(jwk)
	}

	if previous == nil || cfg.GoogleCloudServerlessAuthenticationServiceAccount != previous.opts.GoogleCloudServerlessAuthenticationServiceAccount {
		store.UpdateGoogleCloudServerlessAuthenticationServiceAccount(
			cfg.GoogleCloudServerlessAuthenticationServiceAccount,
		)
	}

	if previous == nil || !maps.Equal(cfg.JWTClaimsHeaders, previous.opts.JWTClaimsHeaders) {
		store.UpdateJWTClaimHeaders(cfg.JWTClaimsHeaders)
	}

	store.UpdateRoutePolicies(cfg.Policies)
	return nil
}

func getJWK(ctx context.Context, signingKey []byte) (*jose.JSONWebKey, error) {
	var decodedCert []byte
	// if we don't have a signing key, generate one
	if len(signingKey) == 0 {
		key, err := cryptutil.NewSigningKey()
		if err != nil {
			return nil, fmt.Errorf("couldn't generate signing key: %w", err)
		}
		decodedCert, err = cryptutil.EncodePrivateKey(key)
		if err != nil {
			return nil, fmt.Errorf("bad signing key: %w", err)
		}
	} else {
		decodedCert = signingKey
	}

	jwk, err := cryptutil.PrivateJWKFromBytes(decodedCert)
	if err != nil {
		return nil, fmt.Errorf("couldn't generate signing key: %w", err)
	}
	log.Ctx(ctx).Info().Str("Algorithm", jwk.Algorithm).
		Str("KeyID", jwk.KeyID).
		Interface("Public Key", jwk.Public()).
		Msg("authorize: signing key")

	return jwk, nil
}

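// safeEval evaluates a prepared rego query, recovering from any panic raised
// during evaluation and returning it as an error.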
func safeEval(ctx context.Context, q rego.PreparedEvalQuery, options ...rego.EvalOption) (resultSet rego.ResultSet, err error) {
	defer func() {
		if e := recover(); e != nil {
			err = fmt.Errorf("%v", e)
		}
	}()
	resultSet, err = q.Eval(ctx, options...)
	return resultSet, err
}

// carryOverJWTAssertion copies assertion JWT from request to response
// note that src keys are expected to be http.CanonicalHeaderKey
func carryOverJWTAssertion(dst http.Header, src map[string]string) {
	jwtForKey := httputil.CanonicalHeaderKey(httputil.HeaderPomeriumJWTAssertionFor)
	jwtFor, ok := src[jwtForKey]
	if ok && jwtFor != "" {
		dst.Add(jwtForKey, jwtFor)
		return
	}
	jwtFor, ok = src[httputil.CanonicalHeaderKey(httputil.HeaderPomeriumJWTAssertion)]
	if ok && jwtFor != "" {
		dst.Add(jwtForKey, jwtFor)
	}
}