authorize: cache warming (#5439)

* authorize: cache warming

* add Authorize to test?

* remove tracing querier

* only update connection when it changes
This commit is contained in:
Caleb Doxsey 2025-01-22 09:27:22 -07:00 committed by GitHub
parent b674d5c19d
commit 6e1fabec0b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 254 additions and 186 deletions

View file

@ -10,7 +10,10 @@ import (
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
"github.com/pomerium/datasource/pkg/directory"
"github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
@ -21,16 +24,16 @@ import (
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
oteltrace "go.opentelemetry.io/otel/trace"
) )
// Authorize struct holds // Authorize struct holds
type Authorize struct { type Authorize struct {
state *atomicutil.Value[*authorizeState] state *atomicutil.Value[*authorizeState]
store *store.Store store *store.Store
currentOptions *atomicutil.Value[*config.Options] currentOptions *atomicutil.Value[*config.Options]
accessTracker *AccessTracker accessTracker *AccessTracker
globalCache storage.Cache globalCache storage.Cache
groupsCacheWarmer *cacheWarmer
// The stateLock prevents updating the evaluator store simultaneously with an evaluation. // The stateLock prevents updating the evaluator store simultaneously with an evaluation.
// This should provide a consistent view of the data at a given server/record version and // This should provide a consistent view of the data at a given server/record version and
@ -60,6 +63,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
} }
a.state = atomicutil.NewValue(state) a.state = atomicutil.NewValue(state)
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, a.globalCache, directory.GroupRecordType)
return a, nil return a, nil
} }
@ -70,8 +74,16 @@ func (a *Authorize) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli
// Run runs the authorize service. // Run runs the authorize service.
func (a *Authorize) Run(ctx context.Context) error { func (a *Authorize) Run(ctx context.Context) error {
a.accessTracker.Run(ctx) eg, ctx := errgroup.WithContext(ctx)
return nil eg.Go(func() error {
a.accessTracker.Run(ctx)
return nil
})
eg.Go(func() error {
a.groupsCacheWarmer.Run(ctx)
return nil
})
return eg.Wait()
} }
func validateOptions(o *config.Options) error { func validateOptions(o *config.Options) error {
@ -150,9 +162,13 @@ func newPolicyEvaluator(
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) { func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
currentState := a.state.Load() currentState := a.state.Load()
a.currentOptions.Store(cfg.Options) a.currentOptions.Store(cfg.Options)
if state, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil { if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state") log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
} else { } else {
a.state.Store(state) a.state.Store(newState)
if currentState.dataBrokerClientConnection != newState.dataBrokerClientConnection {
a.groupsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection)
}
} }
} }

122
authorize/cache_warmer.go Normal file
View file

@ -0,0 +1,122 @@
package authorize
import (
"context"
"time"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
type cacheWarmer struct {
cc *grpc.ClientConn
cache storage.Cache
typeURL string
updatedCC chan *grpc.ClientConn
}
func newCacheWarmer(
cc *grpc.ClientConn,
cache storage.Cache,
typeURL string,
) *cacheWarmer {
return &cacheWarmer{
cc: cc,
cache: cache,
typeURL: typeURL,
updatedCC: make(chan *grpc.ClientConn, 1),
}
}
func (cw *cacheWarmer) UpdateConn(cc *grpc.ClientConn) {
for {
select {
case cw.updatedCC <- cc:
return
default:
}
select {
case <-cw.updatedCC:
default:
}
}
}
func (cw *cacheWarmer) Run(ctx context.Context) {
// Run a syncer for the cache warmer until the underlying databroker connection is changed.
// When that happens cancel the currently running syncer and start a new one.
runCtx, runCancel := context.WithCancel(ctx)
go cw.run(runCtx, cw.cc)
for {
select {
case <-ctx.Done():
runCancel()
return
case cc := <-cw.updatedCC:
log.Ctx(ctx).Info().Msg("cache-warmer: received updated databroker client connection, restarting syncer")
cw.cc = cc
runCancel()
runCtx, runCancel = context.WithCancel(ctx)
go cw.run(runCtx, cw.cc)
}
}
}
func (cw *cacheWarmer) run(ctx context.Context, cc *grpc.ClientConn) {
log.Ctx(ctx).Debug().Str("type-url", cw.typeURL).Msg("cache-warmer: running databroker syncer to warm cache")
_ = databroker.NewSyncer(ctx, "cache-warmer", cacheWarmerSyncerHandler{
client: databroker.NewDataBrokerServiceClient(cc),
cache: cw.cache,
}, databroker.WithTypeURL(cw.typeURL)).Run(ctx)
}
type cacheWarmerSyncerHandler struct {
client databroker.DataBrokerServiceClient
cache storage.Cache
}
func (h cacheWarmerSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return h.client
}
func (h cacheWarmerSyncerHandler) ClearRecords(_ context.Context) {
h.cache.InvalidateAll()
}
func (h cacheWarmerSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) {
for _, record := range records {
req := &databroker.QueryRequest{
Type: record.Type,
Limit: 1,
}
req.SetFilterByIDOrIndex(record.Id)
res := &databroker.QueryResponse{
Records: []*databroker.Record{record},
TotalCount: 1,
ServerVersion: serverVersion,
RecordVersion: record.Version,
}
expiry := time.Now().Add(time.Hour * 24 * 365)
key, err := storage.MarshalQueryRequest(req)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query request")
continue
}
value, err := storage.MarshalQueryResponse(res)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query response")
continue
}
h.cache.Set(expiry, key, value)
}
}

View file

@ -0,0 +1,52 @@
package authorize
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace/noop"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/internal/testutil"
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage"
)
func TestCacheWarmer(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, 10*time.Minute)
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider()))
})
t.Cleanup(func() { cc.Close() })
client := databrokerpb.NewDataBrokerServiceClient(cc)
_, err := client.Put(ctx, &databrokerpb.PutRequest{
Records: []*databrokerpb.Record{
{Type: "example.com/record", Id: "e1", Data: protoutil.NewAnyBool(true)},
{Type: "example.com/record", Id: "e2", Data: protoutil.NewAnyBool(true)},
},
})
require.NoError(t, err)
cache := storage.NewGlobalCache(time.Minute)
cw := newCacheWarmer(cc, cache, "example.com/record")
go cw.Run(ctx)
assert.Eventually(t, func() bool {
req := &databrokerpb.QueryRequest{
Type: "example.com/record",
Limit: 1,
}
req.SetFilterByIDOrIndex("e1")
res, err := storage.NewCachingQuerier(storage.NewStaticQuerier(), cache).Query(ctx, req)
require.NoError(t, err)
return len(res.GetRecords()) == 1
}, 10*time.Second, time.Millisecond*100)
}

View file

@ -37,10 +37,8 @@ func Test_getDataBrokerRecord(t *testing.T) {
s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)} s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)}
sq := storage.NewStaticQuerier(s1) sq := storage.NewStaticQuerier(s1)
tsq := storage.NewTracingQuerier(sq) cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
cq := storage.NewCachingQuerier(tsq, storage.NewLocalCache()) qctx := storage.WithQuerier(ctx, cq)
tcq := storage.NewTracingQuerier(cq)
qctx := storage.WithQuerier(ctx, tcq)
s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
assert.NoError(t, err) assert.NoError(t, err)
@ -49,11 +47,6 @@ func Test_getDataBrokerRecord(t *testing.T) {
s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, s) assert.NotNil(t, s)
assert.Len(t, tsq.Traces(), tc.underlyingQueryCount,
"should have %d traces to the underlying querier", tc.underlyingQueryCount)
assert.Len(t, tcq.Traces(), tc.cachedQueryCount,
"should have %d traces to the cached querier", tc.cachedQueryCount)
}) })
} }
} }

View file

@ -31,14 +31,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check") ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check")
defer span.End() defer span.End()
querier := storage.NewTracingQuerier( querier := storage.NewCachingQuerier(
storage.NewCachingQuerier( storage.NewQuerier(a.state.Load().dataBrokerClient),
storage.NewCachingQuerier( a.globalCache,
storage.NewQuerier(a.state.Load().dataBrokerClient),
a.globalCache,
),
storage.NewLocalCache(),
),
) )
ctx = storage.WithQuerier(ctx, querier) ctx = storage.WithQuerier(ctx, querier)

View file

@ -12,13 +12,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/scenarios"
"github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@ -27,6 +20,14 @@ import (
sdktrace "go.opentelemetry.io/otel/sdk/trace" sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.17.0" semconv "go.opentelemetry.io/otel/semconv/v1.17.0"
oteltrace "go.opentelemetry.io/otel/trace" oteltrace "go.opentelemetry.io/otel/trace"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/scenarios"
"github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
) )
func otlpTraceReceiverOrFromEnv(t *testing.T) (modifier testenv.Modifier, newRemoteClient func() otlptrace.Client, getResults func() *TraceResults) { func otlpTraceReceiverOrFromEnv(t *testing.T) (modifier testenv.Modifier, newRemoteClient func() otlptrace.Client, getResults func() *TraceResults) {
@ -116,7 +117,7 @@ func TestOTLPTracing(t *testing.T) {
Exact: true, Exact: true,
CheckDetachedSpans: true, CheckDetachedSpans: true,
}, },
Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Test Environment", "Control Plane", "Data Broker"}}, Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Authorize", "Test Environment", "Control Plane", "Data Broker"}},
Match{Name: testEnvironmentAuthenticate, TraceCount: 1, Services: allServices}, Match{Name: testEnvironmentAuthenticate, TraceCount: 1, Services: allServices},
Match{Name: authenticateOAuth2Client, TraceCount: Greater(0)}, Match{Name: authenticateOAuth2Client, TraceCount: Greater(0)},
Match{Name: idpServerGetUserinfo, TraceCount: EqualToMatch(authenticateOAuth2Client)}, Match{Name: idpServerGetUserinfo, TraceCount: EqualToMatch(authenticateOAuth2Client)},

View file

@ -18,64 +18,8 @@ type Cache interface {
update func(ctx context.Context) ([]byte, error), update func(ctx context.Context) ([]byte, error),
) ([]byte, error) ) ([]byte, error)
Invalidate(key []byte) Invalidate(key []byte)
} InvalidateAll()
Set(expiry time.Time, key, value []byte)
type localCache struct {
singleflight singleflight.Group
mu sync.RWMutex
m map[string][]byte
}
// NewLocalCache creates a new Cache backed by a map.
func NewLocalCache() Cache {
return &localCache{
m: make(map[string][]byte),
}
}
func (cache *localCache) GetOrUpdate(
ctx context.Context,
key []byte,
update func(ctx context.Context) ([]byte, error),
) ([]byte, error) {
strkey := string(key)
cache.mu.RLock()
cached, ok := cache.m[strkey]
cache.mu.RUnlock()
if ok {
return cached, nil
}
v, err, _ := cache.singleflight.Do(strkey, func() (any, error) {
cache.mu.RLock()
cached, ok := cache.m[strkey]
cache.mu.RUnlock()
if ok {
return cached, nil
}
result, err := update(ctx)
if err != nil {
return nil, err
}
cache.mu.Lock()
cache.m[strkey] = result
cache.mu.Unlock()
return result, nil
})
if err != nil {
return nil, err
}
return v.([]byte), nil
}
func (cache *localCache) Invalidate(key []byte) {
cache.mu.Lock()
delete(cache.m, string(key))
cache.mu.Unlock()
} }
type globalCache struct { type globalCache struct {
@ -115,7 +59,7 @@ func (cache *globalCache) GetOrUpdate(
if err != nil { if err != nil {
return nil, err return nil, err
} }
cache.set(key, value) cache.set(time.Now().Add(cache.ttl), key, value)
return value, nil return value, nil
}) })
if err != nil { if err != nil {
@ -130,6 +74,16 @@ func (cache *globalCache) Invalidate(key []byte) {
cache.mu.Unlock() cache.mu.Unlock()
} }
func (cache *globalCache) InvalidateAll() {
cache.mu.Lock()
cache.fastcache.Reset()
cache.mu.Unlock()
}
func (cache *globalCache) Set(expiry time.Time, key, value []byte) {
cache.set(expiry, key, value)
}
func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool) { func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool) {
cache.mu.RLock() cache.mu.RLock()
item := cache.fastcache.Get(nil, k) item := cache.fastcache.Get(nil, k)
@ -143,13 +97,13 @@ func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool)
return data, expiry, true return data, expiry, true
} }
func (cache *globalCache) set(k, v []byte) { func (cache *globalCache) set(expiry time.Time, key, value []byte) {
unix := time.Now().Add(cache.ttl).UnixMilli() unix := expiry.UnixMilli()
item := make([]byte, len(v)+8) item := make([]byte, len(value)+8)
binary.LittleEndian.PutUint64(item, uint64(unix)) binary.LittleEndian.PutUint64(item, uint64(unix))
copy(item[8:], v) copy(item[8:], value)
cache.mu.Lock() cache.mu.Lock()
cache.fastcache.Set(k, item) cache.fastcache.Set(key, item)
cache.mu.Unlock() cache.mu.Unlock()
} }

View file

@ -9,34 +9,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestLocalCache(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
callCount := 0
update := func(_ context.Context) ([]byte, error) {
callCount++
return []byte("v1"), nil
}
c := NewLocalCache()
v, err := c.GetOrUpdate(ctx, []byte("k1"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v1"), v)
assert.Equal(t, 1, callCount)
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v1"), v)
assert.Equal(t, 1, callCount)
c.Invalidate([]byte("k1"))
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v1"), v)
assert.Equal(t, 2, callCount)
}
func TestGlobalCache(t *testing.T) { func TestGlobalCache(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout() defer clearTimeout()
@ -70,4 +42,10 @@ func TestGlobalCache(t *testing.T) {
}) })
return err != nil return err != nil
}, time.Second, time.Millisecond*10, "should honor TTL") }, time.Second, time.Millisecond*10, "should honor TTL")
c.Set(time.Now().Add(time.Hour), []byte("k2"), []byte("v2"))
v, err = c.GetOrUpdate(ctx, []byte("k2"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v2"), v)
assert.Equal(t, 2, callCount)
} }

View file

@ -4,7 +4,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"strconv" "strconv"
"sync"
"github.com/google/uuid" "github.com/google/uuid"
grpc "google.golang.org/grpc" grpc "google.golang.org/grpc"
@ -12,7 +11,6 @@ import (
status "google.golang.org/grpc/status" status "google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb" timestamppb "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
@ -162,59 +160,6 @@ func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
return q.client.Query(ctx, in, opts...) return q.client.Query(ctx, in, opts...)
} }
// A TracingQuerier records calls to Query.
type TracingQuerier struct {
underlying Querier
mu sync.Mutex
traces []QueryTrace
}
// A QueryTrace traces a call to Query.
type QueryTrace struct {
ServerVersion, RecordVersion uint64
RecordType string
Query string
Filter *structpb.Struct
}
// NewTracingQuerier creates a new TracingQuerier.
func NewTracingQuerier(q Querier) *TracingQuerier {
return &TracingQuerier{
underlying: q,
}
}
// InvalidateCache invalidates the cache.
func (q *TracingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {
q.underlying.InvalidateCache(ctx, in)
}
// Query queries for records.
func (q *TracingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
res, err := q.underlying.Query(ctx, in, opts...)
if err == nil {
q.mu.Lock()
q.traces = append(q.traces, QueryTrace{
RecordType: in.GetType(),
Query: in.GetQuery(),
Filter: in.GetFilter(),
})
q.mu.Unlock()
}
return res, err
}
// Traces returns all the traces.
func (q *TracingQuerier) Traces() []QueryTrace {
q.mu.Lock()
traces := make([]QueryTrace, len(q.traces))
copy(traces, q.traces)
q.mu.Unlock()
return traces
}
type cachingQuerier struct { type cachingQuerier struct {
q Querier q Querier
cache Cache cache Cache
@ -240,9 +185,7 @@ func (q *cachingQuerier) InvalidateCache(ctx context.Context, in *databroker.Que
} }
func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
key, err := (&proto.MarshalOptions{ key, err := MarshalQueryRequest(in)
Deterministic: true,
}).Marshal(in)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -252,7 +195,7 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
if err != nil { if err != nil {
return nil, err return nil, err
} }
return proto.Marshal(res) return MarshalQueryResponse(res)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -265,3 +208,17 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
} }
return &res, nil return &res, nil
} }
// MarshalQueryRequest marshales the query request.
func MarshalQueryRequest(req *databroker.QueryRequest) ([]byte, error) {
return (&proto.MarshalOptions{
Deterministic: true,
}).Marshal(req)
}
// MarshalQueryResponse marshals the query response.
func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) {
return (&proto.MarshalOptions{
Deterministic: true,
}).Marshal(res)
}