authorize: cache warming (#5439)

* authorize: cache warming

* add Authorize to test?

* remove tracing querier

* only update connection when it changes
This commit is contained in:
Caleb Doxsey 2025-01-22 09:27:22 -07:00 committed by GitHub
parent b674d5c19d
commit 6e1fabec0b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 254 additions and 186 deletions

View file

@ -10,7 +10,10 @@ import (
"time"
"github.com/rs/zerolog"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
"github.com/pomerium/datasource/pkg/directory"
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config"
@ -21,16 +24,16 @@ import (
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
oteltrace "go.opentelemetry.io/otel/trace"
)
// Authorize struct holds the state, options, and caches used by the authorize service.
type Authorize struct {
state *atomicutil.Value[*authorizeState]
store *store.Store
currentOptions *atomicutil.Value[*config.Options]
accessTracker *AccessTracker
globalCache storage.Cache
state *atomicutil.Value[*authorizeState]
store *store.Store
currentOptions *atomicutil.Value[*config.Options]
accessTracker *AccessTracker
globalCache storage.Cache
groupsCacheWarmer *cacheWarmer
// The stateLock prevents updating the evaluator store simultaneously with an evaluation.
// This should provide a consistent view of the data at a given server/record version and
@ -60,6 +63,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
}
a.state = atomicutil.NewValue(state)
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, a.globalCache, directory.GroupRecordType)
return a, nil
}
@ -70,8 +74,16 @@ func (a *Authorize) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli
// Run runs the authorize service.
func (a *Authorize) Run(ctx context.Context) error {
a.accessTracker.Run(ctx)
return nil
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
a.accessTracker.Run(ctx)
return nil
})
eg.Go(func() error {
a.groupsCacheWarmer.Run(ctx)
return nil
})
return eg.Wait()
}
func validateOptions(o *config.Options) error {
@ -150,9 +162,13 @@ func newPolicyEvaluator(
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
currentState := a.state.Load()
a.currentOptions.Store(cfg.Options)
if state, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
} else {
a.state.Store(state)
a.state.Store(newState)
if currentState.dataBrokerClientConnection != newState.dataBrokerClientConnection {
a.groupsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection)
}
}
}

122
authorize/cache_warmer.go Normal file
View file

@ -0,0 +1,122 @@
package authorize
import (
"context"
"time"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
// cacheWarmer pre-populates a storage.Cache with databroker records of one
// type by running a databroker syncer over a gRPC client connection. The
// connection can be swapped at runtime through UpdateConn.
type cacheWarmer struct {
// cc is the connection the syncer is started with; Run overwrites it when a
// replacement arrives on updatedCC.
cc *grpc.ClientConn
// cache is handed to the syncer handler, which writes the warmed entries.
cache storage.Cache
// typeURL is passed to the syncer via databroker.WithTypeURL.
typeURL string
// updatedCC carries replacement connections from UpdateConn to Run. It is
// created with a buffer of one in newCacheWarmer.
updatedCC chan *grpc.ClientConn
}
// newCacheWarmer builds a cacheWarmer for the given connection, cache and
// record type. Nothing is started here; call Run to begin syncing.
func newCacheWarmer(
	cc *grpc.ClientConn,
	cache storage.Cache,
	typeURL string,
) *cacheWarmer {
	cw := new(cacheWarmer)
	cw.cc = cc
	cw.cache = cache
	cw.typeURL = typeURL
	// A single buffered slot lets UpdateConn hand over a connection without
	// waiting for Run to be ready to receive it.
	cw.updatedCC = make(chan *grpc.ClientConn, 1)
	return cw
}
// UpdateConn publishes a replacement databroker connection for Run to pick up.
// It never blocks on the receiver: when the one-slot buffer is already
// occupied, the pending connection is discarded and the send is retried, so
// the most recently supplied connection is the one left in the channel.
func (cw *cacheWarmer) UpdateConn(cc *grpc.ClientConn) {
	for {
		select {
		case cw.updatedCC <- cc:
			return
		default:
			// Buffer is full: drop the stale pending connection (if it is
			// still there) and loop around to try the send again.
			select {
			case <-cw.updatedCC:
			default:
			}
		}
	}
}
// Run keeps one syncer goroutine active for the current databroker connection
// until ctx is done. Each time UpdateConn delivers a new connection, the
// running syncer's context is cancelled and a fresh syncer is started on the
// new connection. Run does not wait for a cancelled syncer goroutine to exit.
func (cw *cacheWarmer) Run(ctx context.Context) {
	// launch starts a syncer on cw.cc under a child context of ctx and returns
	// the function that cancels it.
	launch := func() context.CancelFunc {
		syncCtx, cancel := context.WithCancel(ctx)
		go cw.run(syncCtx, cw.cc)
		return cancel
	}

	stop := launch()
	for {
		select {
		case <-ctx.Done():
			stop()
			return
		case updated := <-cw.updatedCC:
			log.Ctx(ctx).Info().Msg("cache-warmer: received updated databroker client connection, restarting syncer")
			cw.cc = updated
			stop()
			stop = launch()
		}
	}
}
// run executes one databroker syncer over cc, restricted via
// databroker.WithTypeURL to the warmer's type, and blocks until the syncer's
// Run call returns.
func (cw *cacheWarmer) run(ctx context.Context, cc *grpc.ClientConn) {
	log.Ctx(ctx).Debug().Str("type-url", cw.typeURL).Msg("cache-warmer: running databroker syncer to warm cache")

	handler := cacheWarmerSyncerHandler{
		client: databroker.NewDataBrokerServiceClient(cc),
		cache:  cw.cache,
	}
	syncer := databroker.NewSyncer(ctx, "cache-warmer", handler, databroker.WithTypeURL(cw.typeURL))

	// The syncer's result is discarded: this method is launched with `go`
	// from Run, so there is no caller to report it to.
	_ = syncer.Run(ctx)
}
// cacheWarmerSyncerHandler is the handler given to the databroker syncer by
// cacheWarmer.run. It translates syncer callbacks into cache writes.
type cacheWarmerSyncerHandler struct {
// client is the databroker client returned by GetDataBrokerServiceClient.
client databroker.DataBrokerServiceClient
// cache is the cache that ClearRecords and UpdateRecords operate on.
cache storage.Cache
}
// GetDataBrokerServiceClient returns the databroker client this handler was
// constructed with.
func (h cacheWarmerSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return h.client
}
// ClearRecords empties the cache by calling InvalidateAll on it. The context
// argument is unused.
// NOTE(review): InvalidateAll is not scoped to this handler's record type, so
// this presumably drops unrelated entries sharing the same cache — confirm.
func (h cacheWarmerSyncerHandler) ClearRecords(_ context.Context) {
h.cache.InvalidateAll()
}
// UpdateRecords seeds the cache so that a by-ID lookup of each synced record
// can be answered without contacting the databroker.
//
// For every record it builds the single-record query (record type, limit 1,
// filter by ID or index) and marshals it with storage.MarshalQueryRequest,
// the same function cachingQuerier uses to derive its cache key. Live records
// are stored as a one-record QueryResponse stamped with serverVersion and the
// record's own version. Records carrying a deletion timestamp are not stored;
// their key is invalidated instead so a deleted record is not served as a hit.
//
// Marshaling failures are logged and the affected record is skipped; the
// remaining records are still processed.
func (h cacheWarmerSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) {
	// Warmed entries are refreshed by the syncer rather than by expiry, so they
	// get a long lifetime. The expiry is computed once per batch.
	const warmedEntryLifetime = 365 * 24 * time.Hour
	expiry := time.Now().Add(warmedEntryLifetime)

	for _, record := range records {
		req := &databroker.QueryRequest{
			Type:  record.Type,
			Limit: 1,
		}
		req.SetFilterByIDOrIndex(record.Id)

		key, err := storage.MarshalQueryRequest(req)
		if err != nil {
			log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query request")
			continue
		}

		// NOTE(review): assumes the syncer reports deletions as records with
		// DeletedAt set — confirm against the syncer implementation.
		if record.GetDeletedAt() != nil {
			h.cache.Invalidate(key)
			continue
		}

		res := &databroker.QueryResponse{
			Records:       []*databroker.Record{record},
			TotalCount:    1,
			ServerVersion: serverVersion,
			RecordVersion: record.Version,
		}
		value, err := storage.MarshalQueryResponse(res)
		if err != nil {
			log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query response")
			continue
		}

		h.cache.Set(expiry, key, value)
	}
}

View file

@ -0,0 +1,52 @@
package authorize
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace/noop"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/internal/testutil"
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage"
)
// TestCacheWarmer verifies that a cacheWarmer running against an in-process
// databroker populates the cache: after two records are stored, a by-ID query
// for one of them must eventually be answered through a caching querier whose
// underlying static querier holds no records.
func TestCacheWarmer(t *testing.T) {
	t.Parallel()

	ctx := testutil.GetContext(t, 10*time.Minute)
	cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
		databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider()))
	})
	t.Cleanup(func() { cc.Close() })

	client := databrokerpb.NewDataBrokerServiceClient(cc)
	_, err := client.Put(ctx, &databrokerpb.PutRequest{
		Records: []*databrokerpb.Record{
			{Type: "example.com/record", Id: "e1", Data: protoutil.NewAnyBool(true)},
			{Type: "example.com/record", Id: "e2", Data: protoutil.NewAnyBool(true)},
		},
	})
	require.NoError(t, err)

	cache := storage.NewGlobalCache(time.Minute)

	cw := newCacheWarmer(cc, cache, "example.com/record")
	go cw.Run(ctx)

	assert.Eventually(t, func() bool {
		req := &databrokerpb.QueryRequest{
			Type:  "example.com/record",
			Limit: 1,
		}
		req.SetFilterByIDOrIndex("e1")
		res, err := storage.NewCachingQuerier(storage.NewStaticQuerier(), cache).Query(ctx, req)
		// The condition runs on a goroutine started by Eventually, so it must
		// not call require (t.FailNow). Record the failure with assert and
		// report the condition as unmet instead.
		if !assert.NoError(t, err) {
			return false
		}
		return len(res.GetRecords()) == 1
	}, 10*time.Second, time.Millisecond*100)
}

View file

@ -37,10 +37,8 @@ func Test_getDataBrokerRecord(t *testing.T) {
s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)}
sq := storage.NewStaticQuerier(s1)
tsq := storage.NewTracingQuerier(sq)
cq := storage.NewCachingQuerier(tsq, storage.NewLocalCache())
tcq := storage.NewTracingQuerier(cq)
qctx := storage.WithQuerier(ctx, tcq)
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
qctx := storage.WithQuerier(ctx, cq)
s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
assert.NoError(t, err)
@ -49,11 +47,6 @@ func Test_getDataBrokerRecord(t *testing.T) {
s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
assert.NoError(t, err)
assert.NotNil(t, s)
assert.Len(t, tsq.Traces(), tc.underlyingQueryCount,
"should have %d traces to the underlying querier", tc.underlyingQueryCount)
assert.Len(t, tcq.Traces(), tc.cachedQueryCount,
"should have %d traces to the cached querier", tc.cachedQueryCount)
})
}
}

View file

@ -31,14 +31,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check")
defer span.End()
querier := storage.NewTracingQuerier(
storage.NewCachingQuerier(
storage.NewCachingQuerier(
storage.NewQuerier(a.state.Load().dataBrokerClient),
a.globalCache,
),
storage.NewLocalCache(),
),
querier := storage.NewCachingQuerier(
storage.NewQuerier(a.state.Load().dataBrokerClient),
a.globalCache,
)
ctx = storage.WithQuerier(ctx, querier)

View file

@ -12,13 +12,6 @@ import (
"testing"
"time"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/scenarios"
"github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
@ -27,6 +20,14 @@ import (
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.17.0"
oteltrace "go.opentelemetry.io/otel/trace"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/scenarios"
"github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
)
func otlpTraceReceiverOrFromEnv(t *testing.T) (modifier testenv.Modifier, newRemoteClient func() otlptrace.Client, getResults func() *TraceResults) {
@ -116,7 +117,7 @@ func TestOTLPTracing(t *testing.T) {
Exact: true,
CheckDetachedSpans: true,
},
Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Test Environment", "Control Plane", "Data Broker"}},
Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Authorize", "Test Environment", "Control Plane", "Data Broker"}},
Match{Name: testEnvironmentAuthenticate, TraceCount: 1, Services: allServices},
Match{Name: authenticateOAuth2Client, TraceCount: Greater(0)},
Match{Name: idpServerGetUserinfo, TraceCount: EqualToMatch(authenticateOAuth2Client)},

View file

@ -18,64 +18,8 @@ type Cache interface {
update func(ctx context.Context) ([]byte, error),
) ([]byte, error)
Invalidate(key []byte)
}
type localCache struct {
singleflight singleflight.Group
mu sync.RWMutex
m map[string][]byte
}
// NewLocalCache creates a new Cache backed by a map.
func NewLocalCache() Cache {
return &localCache{
m: make(map[string][]byte),
}
}
func (cache *localCache) GetOrUpdate(
ctx context.Context,
key []byte,
update func(ctx context.Context) ([]byte, error),
) ([]byte, error) {
strkey := string(key)
cache.mu.RLock()
cached, ok := cache.m[strkey]
cache.mu.RUnlock()
if ok {
return cached, nil
}
v, err, _ := cache.singleflight.Do(strkey, func() (any, error) {
cache.mu.RLock()
cached, ok := cache.m[strkey]
cache.mu.RUnlock()
if ok {
return cached, nil
}
result, err := update(ctx)
if err != nil {
return nil, err
}
cache.mu.Lock()
cache.m[strkey] = result
cache.mu.Unlock()
return result, nil
})
if err != nil {
return nil, err
}
return v.([]byte), nil
}
func (cache *localCache) Invalidate(key []byte) {
cache.mu.Lock()
delete(cache.m, string(key))
cache.mu.Unlock()
InvalidateAll()
Set(expiry time.Time, key, value []byte)
}
type globalCache struct {
@ -115,7 +59,7 @@ func (cache *globalCache) GetOrUpdate(
if err != nil {
return nil, err
}
cache.set(key, value)
cache.set(time.Now().Add(cache.ttl), key, value)
return value, nil
})
if err != nil {
@ -130,6 +74,16 @@ func (cache *globalCache) Invalidate(key []byte) {
cache.mu.Unlock()
}
// InvalidateAll removes every entry by resetting the underlying fastcache
// while holding the write lock.
func (cache *globalCache) InvalidateAll() {
	cache.mu.Lock()
	defer cache.mu.Unlock()

	cache.fastcache.Reset()
}
// Set stores value under key with the given expiry time. It is the exported
// entry point for the unexported set method and passes its arguments through
// unchanged.
func (cache *globalCache) Set(expiry time.Time, key, value []byte) {
cache.set(expiry, key, value)
}
func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool) {
cache.mu.RLock()
item := cache.fastcache.Get(nil, k)
@ -143,13 +97,13 @@ func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool)
return data, expiry, true
}
func (cache *globalCache) set(k, v []byte) {
unix := time.Now().Add(cache.ttl).UnixMilli()
item := make([]byte, len(v)+8)
func (cache *globalCache) set(expiry time.Time, key, value []byte) {
unix := expiry.UnixMilli()
item := make([]byte, len(value)+8)
binary.LittleEndian.PutUint64(item, uint64(unix))
copy(item[8:], v)
copy(item[8:], value)
cache.mu.Lock()
cache.fastcache.Set(k, item)
cache.fastcache.Set(key, item)
cache.mu.Unlock()
}

View file

@ -9,34 +9,6 @@ import (
"github.com/stretchr/testify/assert"
)
func TestLocalCache(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
callCount := 0
update := func(_ context.Context) ([]byte, error) {
callCount++
return []byte("v1"), nil
}
c := NewLocalCache()
v, err := c.GetOrUpdate(ctx, []byte("k1"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v1"), v)
assert.Equal(t, 1, callCount)
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v1"), v)
assert.Equal(t, 1, callCount)
c.Invalidate([]byte("k1"))
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v1"), v)
assert.Equal(t, 2, callCount)
}
func TestGlobalCache(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
@ -70,4 +42,10 @@ func TestGlobalCache(t *testing.T) {
})
return err != nil
}, time.Second, time.Millisecond*10, "should honor TTL")
c.Set(time.Now().Add(time.Hour), []byte("k2"), []byte("v2"))
v, err = c.GetOrUpdate(ctx, []byte("k2"), update)
assert.NoError(t, err)
assert.Equal(t, []byte("v2"), v)
assert.Equal(t, 2, callCount)
}

View file

@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"strconv"
"sync"
"github.com/google/uuid"
grpc "google.golang.org/grpc"
@ -12,7 +11,6 @@ import (
status "google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/cryptutil"
@ -162,59 +160,6 @@ func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
return q.client.Query(ctx, in, opts...)
}
// A TracingQuerier records calls to Query.
type TracingQuerier struct {
underlying Querier
mu sync.Mutex
traces []QueryTrace
}
// A QueryTrace traces a call to Query.
type QueryTrace struct {
ServerVersion, RecordVersion uint64
RecordType string
Query string
Filter *structpb.Struct
}
// NewTracingQuerier creates a new TracingQuerier.
func NewTracingQuerier(q Querier) *TracingQuerier {
return &TracingQuerier{
underlying: q,
}
}
// InvalidateCache invalidates the cache.
func (q *TracingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {
q.underlying.InvalidateCache(ctx, in)
}
// Query queries for records.
func (q *TracingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
res, err := q.underlying.Query(ctx, in, opts...)
if err == nil {
q.mu.Lock()
q.traces = append(q.traces, QueryTrace{
RecordType: in.GetType(),
Query: in.GetQuery(),
Filter: in.GetFilter(),
})
q.mu.Unlock()
}
return res, err
}
// Traces returns all the traces.
func (q *TracingQuerier) Traces() []QueryTrace {
q.mu.Lock()
traces := make([]QueryTrace, len(q.traces))
copy(traces, q.traces)
q.mu.Unlock()
return traces
}
type cachingQuerier struct {
q Querier
cache Cache
@ -240,9 +185,7 @@ func (q *cachingQuerier) InvalidateCache(ctx context.Context, in *databroker.Que
}
func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
key, err := (&proto.MarshalOptions{
Deterministic: true,
}).Marshal(in)
key, err := MarshalQueryRequest(in)
if err != nil {
return nil, err
}
@ -252,7 +195,7 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
if err != nil {
return nil, err
}
return proto.Marshal(res)
return MarshalQueryResponse(res)
})
if err != nil {
return nil, err
@ -265,3 +208,17 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
}
return &res, nil
}
// MarshalQueryRequest marshals the query request with deterministic protobuf
// encoding enabled. The resulting bytes are what cachingQuerier uses as its
// cache key.
func MarshalQueryRequest(req *databroker.QueryRequest) ([]byte, error) {
	opts := proto.MarshalOptions{Deterministic: true}
	return opts.Marshal(req)
}
// MarshalQueryResponse marshals the query response with deterministic
// protobuf encoding enabled. The resulting bytes are what cachingQuerier
// stores as the cached value.
func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) {
	opts := proto.MarshalOptions{Deterministic: true}
	return opts.Marshal(res)
}