diff --git a/authorize/authorize.go b/authorize/authorize.go index 4cd3dd496..01960792f 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -10,7 +10,10 @@ import ( "time" "github.com/rs/zerolog" + oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/errgroup" + "github.com/pomerium/datasource/pkg/directory" "github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" @@ -21,16 +24,16 @@ import ( "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" - oteltrace "go.opentelemetry.io/otel/trace" ) // Authorize struct holds type Authorize struct { - state *atomicutil.Value[*authorizeState] - store *store.Store - currentOptions *atomicutil.Value[*config.Options] - accessTracker *AccessTracker - globalCache storage.Cache + state *atomicutil.Value[*authorizeState] + store *store.Store + currentOptions *atomicutil.Value[*config.Options] + accessTracker *AccessTracker + globalCache storage.Cache + groupsCacheWarmer *cacheWarmer // The stateLock prevents updating the evaluator store simultaneously with an evaluation. // This should provide a consistent view of the data at a given server/record version and @@ -60,6 +63,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) { } a.state = atomicutil.NewValue(state) + a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, a.globalCache, directory.GroupRecordType) return a, nil } @@ -70,8 +74,16 @@ func (a *Authorize) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli // Run runs the authorize service. 
func (a *Authorize) Run(ctx context.Context) error { - a.accessTracker.Run(ctx) - return nil + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + a.accessTracker.Run(ctx) + return nil + }) + eg.Go(func() error { + a.groupsCacheWarmer.Run(ctx) + return nil + }) + return eg.Wait() } func validateOptions(o *config.Options) error { @@ -150,9 +162,13 @@ func newPolicyEvaluator( func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) { currentState := a.state.Load() a.currentOptions.Store(cfg.Options) - if state, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil { + if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil { log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state") } else { - a.state.Store(state) + a.state.Store(newState) + + if currentState.dataBrokerClientConnection != newState.dataBrokerClientConnection { + a.groupsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection) + } } } diff --git a/authorize/cache_warmer.go b/authorize/cache_warmer.go new file mode 100644 index 000000000..41c1d0ae2 --- /dev/null +++ b/authorize/cache_warmer.go @@ -0,0 +1,122 @@ +package authorize + +import ( + "context" + "time" + + "google.golang.org/grpc" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/storage" +) + +type cacheWarmer struct { + cc *grpc.ClientConn + cache storage.Cache + typeURL string + + updatedCC chan *grpc.ClientConn +} + +func newCacheWarmer( + cc *grpc.ClientConn, + cache storage.Cache, + typeURL string, +) *cacheWarmer { + return &cacheWarmer{ + cc: cc, + cache: cache, + typeURL: typeURL, + + updatedCC: make(chan *grpc.ClientConn, 1), + } +} + +func (cw *cacheWarmer) UpdateConn(cc *grpc.ClientConn) { + for { + select { + case cw.updatedCC <- cc: + return + default: + } + select { + case 
<-cw.updatedCC: + default: + } + } +} + +func (cw *cacheWarmer) Run(ctx context.Context) { + // Run a syncer for the cache warmer until the underlying databroker connection is changed. + // When that happens cancel the currently running syncer and start a new one. + + runCtx, runCancel := context.WithCancel(ctx) + go cw.run(runCtx, cw.cc) + + for { + select { + case <-ctx.Done(): + runCancel() + return + case cc := <-cw.updatedCC: + log.Ctx(ctx).Info().Msg("cache-warmer: received updated databroker client connection, restarting syncer") + cw.cc = cc + runCancel() + runCtx, runCancel = context.WithCancel(ctx) + go cw.run(runCtx, cw.cc) + } + } +} + +func (cw *cacheWarmer) run(ctx context.Context, cc *grpc.ClientConn) { + log.Ctx(ctx).Debug().Str("type-url", cw.typeURL).Msg("cache-warmer: running databroker syncer to warm cache") + _ = databroker.NewSyncer(ctx, "cache-warmer", cacheWarmerSyncerHandler{ + client: databroker.NewDataBrokerServiceClient(cc), + cache: cw.cache, + }, databroker.WithTypeURL(cw.typeURL)).Run(ctx) +} + +type cacheWarmerSyncerHandler struct { + client databroker.DataBrokerServiceClient + cache storage.Cache +} + +func (h cacheWarmerSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { + return h.client +} + +func (h cacheWarmerSyncerHandler) ClearRecords(_ context.Context) { + h.cache.InvalidateAll() +} + +func (h cacheWarmerSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) { + for _, record := range records { + req := &databroker.QueryRequest{ + Type: record.Type, + Limit: 1, + } + req.SetFilterByIDOrIndex(record.Id) + + res := &databroker.QueryResponse{ + Records: []*databroker.Record{record}, + TotalCount: 1, + ServerVersion: serverVersion, + RecordVersion: record.Version, + } + + expiry := time.Now().Add(time.Hour * 24 * 365) + key, err := storage.MarshalQueryRequest(req) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal 
query request") + continue + } + value, err := storage.MarshalQueryResponse(res) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query response") + continue + } + + h.cache.Set(expiry, key, value) + } +} diff --git a/authorize/cache_warmer_test.go b/authorize/cache_warmer_test.go new file mode 100644 index 000000000..58df29018 --- /dev/null +++ b/authorize/cache_warmer_test.go @@ -0,0 +1,52 @@ +package authorize + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" + "google.golang.org/grpc" + + "github.com/pomerium/pomerium/internal/databroker" + "github.com/pomerium/pomerium/internal/testutil" + databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/protoutil" + "github.com/pomerium/pomerium/pkg/storage" +) + +func TestCacheWarmer(t *testing.T) { + t.Parallel() + + ctx := testutil.GetContext(t, 10*time.Minute) + cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) { + databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider())) + }) + t.Cleanup(func() { cc.Close() }) + + client := databrokerpb.NewDataBrokerServiceClient(cc) + _, err := client.Put(ctx, &databrokerpb.PutRequest{ + Records: []*databrokerpb.Record{ + {Type: "example.com/record", Id: "e1", Data: protoutil.NewAnyBool(true)}, + {Type: "example.com/record", Id: "e2", Data: protoutil.NewAnyBool(true)}, + }, + }) + require.NoError(t, err) + + cache := storage.NewGlobalCache(time.Minute) + + cw := newCacheWarmer(cc, cache, "example.com/record") + go cw.Run(ctx) + + assert.Eventually(t, func() bool { + req := &databrokerpb.QueryRequest{ + Type: "example.com/record", + Limit: 1, + } + req.SetFilterByIDOrIndex("e1") + res, err := storage.NewCachingQuerier(storage.NewStaticQuerier(), cache).Query(ctx, req) + require.NoError(t, err) + return len(res.GetRecords()) == 1 + }, 10*time.Second, 
time.Millisecond*100) +} diff --git a/authorize/databroker_test.go b/authorize/databroker_test.go index 1c773fc21..eefb5987b 100644 --- a/authorize/databroker_test.go +++ b/authorize/databroker_test.go @@ -37,10 +37,8 @@ func Test_getDataBrokerRecord(t *testing.T) { s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)} sq := storage.NewStaticQuerier(s1) - tsq := storage.NewTracingQuerier(sq) - cq := storage.NewCachingQuerier(tsq, storage.NewLocalCache()) - tcq := storage.NewTracingQuerier(cq) - qctx := storage.WithQuerier(ctx, tcq) + cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute)) + qctx := storage.WithQuerier(ctx, cq) s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) assert.NoError(t, err) @@ -49,11 +47,6 @@ func Test_getDataBrokerRecord(t *testing.T) { s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) assert.NoError(t, err) assert.NotNil(t, s) - - assert.Len(t, tsq.Traces(), tc.underlyingQueryCount, - "should have %d traces to the underlying querier", tc.underlyingQueryCount) - assert.Len(t, tcq.Traces(), tc.cachedQueryCount, - "should have %d traces to the cached querier", tc.cachedQueryCount) }) } } diff --git a/authorize/grpc.go b/authorize/grpc.go index 8172226f9..69b26a167 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -31,14 +31,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check") defer span.End() - querier := storage.NewTracingQuerier( - storage.NewCachingQuerier( - storage.NewCachingQuerier( - storage.NewQuerier(a.state.Load().dataBrokerClient), - a.globalCache, - ), - storage.NewLocalCache(), - ), + querier := storage.NewCachingQuerier( + storage.NewQuerier(a.state.Load().dataBrokerClient), + a.globalCache, ) ctx = storage.WithQuerier(ctx, querier) diff --git a/internal/testenv/selftests/tracing_test.go 
b/internal/testenv/selftests/tracing_test.go index 0a799d872..4e59e1d6c 100644 --- a/internal/testenv/selftests/tracing_test.go +++ b/internal/testenv/selftests/tracing_test.go @@ -12,13 +12,6 @@ import ( "testing" "time" - "github.com/pomerium/pomerium/config" - "github.com/pomerium/pomerium/internal/telemetry/trace" - "github.com/pomerium/pomerium/internal/testenv" - "github.com/pomerium/pomerium/internal/testenv/scenarios" - "github.com/pomerium/pomerium/internal/testenv/snippets" - "github.com/pomerium/pomerium/internal/testenv/upstreams" - . "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -27,6 +20,14 @@ import ( sdktrace "go.opentelemetry.io/otel/sdk/trace" semconv "go.opentelemetry.io/otel/semconv/v1.17.0" oteltrace "go.opentelemetry.io/otel/trace" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/scenarios" + "github.com/pomerium/pomerium/internal/testenv/snippets" + "github.com/pomerium/pomerium/internal/testenv/upstreams" + . 
"github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive ) func otlpTraceReceiverOrFromEnv(t *testing.T) (modifier testenv.Modifier, newRemoteClient func() otlptrace.Client, getResults func() *TraceResults) { @@ -116,7 +117,7 @@ func TestOTLPTracing(t *testing.T) { Exact: true, CheckDetachedSpans: true, }, - Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Test Environment", "Control Plane", "Data Broker"}}, + Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Authorize", "Test Environment", "Control Plane", "Data Broker"}}, Match{Name: testEnvironmentAuthenticate, TraceCount: 1, Services: allServices}, Match{Name: authenticateOAuth2Client, TraceCount: Greater(0)}, Match{Name: idpServerGetUserinfo, TraceCount: EqualToMatch(authenticateOAuth2Client)}, diff --git a/pkg/storage/cache.go b/pkg/storage/cache.go index ee75cb631..1f0e391b4 100644 --- a/pkg/storage/cache.go +++ b/pkg/storage/cache.go @@ -18,64 +18,8 @@ type Cache interface { update func(ctx context.Context) ([]byte, error), ) ([]byte, error) Invalidate(key []byte) -} - -type localCache struct { - singleflight singleflight.Group - mu sync.RWMutex - m map[string][]byte -} - -// NewLocalCache creates a new Cache backed by a map. 
-func NewLocalCache() Cache { - return &localCache{ - m: make(map[string][]byte), - } -} - -func (cache *localCache) GetOrUpdate( - ctx context.Context, - key []byte, - update func(ctx context.Context) ([]byte, error), -) ([]byte, error) { - strkey := string(key) - - cache.mu.RLock() - cached, ok := cache.m[strkey] - cache.mu.RUnlock() - if ok { - return cached, nil - } - - v, err, _ := cache.singleflight.Do(strkey, func() (any, error) { - cache.mu.RLock() - cached, ok := cache.m[strkey] - cache.mu.RUnlock() - if ok { - return cached, nil - } - - result, err := update(ctx) - if err != nil { - return nil, err - } - - cache.mu.Lock() - cache.m[strkey] = result - cache.mu.Unlock() - - return result, nil - }) - if err != nil { - return nil, err - } - return v.([]byte), nil -} - -func (cache *localCache) Invalidate(key []byte) { - cache.mu.Lock() - delete(cache.m, string(key)) - cache.mu.Unlock() + InvalidateAll() + Set(expiry time.Time, key, value []byte) } type globalCache struct { @@ -115,7 +59,7 @@ func (cache *globalCache) GetOrUpdate( if err != nil { return nil, err } - cache.set(key, value) + cache.set(time.Now().Add(cache.ttl), key, value) return value, nil }) if err != nil { @@ -130,6 +74,16 @@ func (cache *globalCache) Invalidate(key []byte) { cache.mu.Unlock() } +func (cache *globalCache) InvalidateAll() { + cache.mu.Lock() + cache.fastcache.Reset() + cache.mu.Unlock() +} + +func (cache *globalCache) Set(expiry time.Time, key, value []byte) { + cache.set(expiry, key, value) +} + func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool) { cache.mu.RLock() item := cache.fastcache.Get(nil, k) @@ -143,13 +97,13 @@ func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool) return data, expiry, true } -func (cache *globalCache) set(k, v []byte) { - unix := time.Now().Add(cache.ttl).UnixMilli() - item := make([]byte, len(v)+8) +func (cache *globalCache) set(expiry time.Time, key, value []byte) { + unix := 
expiry.UnixMilli() + item := make([]byte, len(value)+8) binary.LittleEndian.PutUint64(item, uint64(unix)) - copy(item[8:], v) + copy(item[8:], value) cache.mu.Lock() - cache.fastcache.Set(k, item) + cache.fastcache.Set(key, item) cache.mu.Unlock() } diff --git a/pkg/storage/cache_test.go b/pkg/storage/cache_test.go index 132dc201a..e33cf8dc7 100644 --- a/pkg/storage/cache_test.go +++ b/pkg/storage/cache_test.go @@ -9,34 +9,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestLocalCache(t *testing.T) { - ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) - defer clearTimeout() - - callCount := 0 - update := func(_ context.Context) ([]byte, error) { - callCount++ - return []byte("v1"), nil - } - c := NewLocalCache() - v, err := c.GetOrUpdate(ctx, []byte("k1"), update) - assert.NoError(t, err) - assert.Equal(t, []byte("v1"), v) - assert.Equal(t, 1, callCount) - - v, err = c.GetOrUpdate(ctx, []byte("k1"), update) - assert.NoError(t, err) - assert.Equal(t, []byte("v1"), v) - assert.Equal(t, 1, callCount) - - c.Invalidate([]byte("k1")) - - v, err = c.GetOrUpdate(ctx, []byte("k1"), update) - assert.NoError(t, err) - assert.Equal(t, []byte("v1"), v) - assert.Equal(t, 2, callCount) -} - func TestGlobalCache(t *testing.T) { ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) defer clearTimeout() @@ -70,4 +42,10 @@ func TestGlobalCache(t *testing.T) { }) return err != nil }, time.Second, time.Millisecond*10, "should honor TTL") + + c.Set(time.Now().Add(time.Hour), []byte("k2"), []byte("v2")) + v, err = c.GetOrUpdate(ctx, []byte("k2"), update) + assert.NoError(t, err) + assert.Equal(t, []byte("v2"), v) + assert.Equal(t, 2, callCount) } diff --git a/pkg/storage/querier.go b/pkg/storage/querier.go index 1de0b195d..c70247e5c 100644 --- a/pkg/storage/querier.go +++ b/pkg/storage/querier.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "strconv" - "sync" "github.com/google/uuid" grpc "google.golang.org/grpc" @@ 
-12,7 +11,6 @@ import ( status "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/structpb" timestamppb "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/pkg/cryptutil" @@ -162,59 +160,6 @@ func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest, return q.client.Query(ctx, in, opts...) } -// A TracingQuerier records calls to Query. -type TracingQuerier struct { - underlying Querier - - mu sync.Mutex - traces []QueryTrace -} - -// A QueryTrace traces a call to Query. -type QueryTrace struct { - ServerVersion, RecordVersion uint64 - - RecordType string - Query string - Filter *structpb.Struct -} - -// NewTracingQuerier creates a new TracingQuerier. -func NewTracingQuerier(q Querier) *TracingQuerier { - return &TracingQuerier{ - underlying: q, - } -} - -// InvalidateCache invalidates the cache. -func (q *TracingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) { - q.underlying.InvalidateCache(ctx, in) -} - -// Query queries for records. -func (q *TracingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { - res, err := q.underlying.Query(ctx, in, opts...) - if err == nil { - q.mu.Lock() - q.traces = append(q.traces, QueryTrace{ - RecordType: in.GetType(), - Query: in.GetQuery(), - Filter: in.GetFilter(), - }) - q.mu.Unlock() - } - return res, err -} - -// Traces returns all the traces. 
-func (q *TracingQuerier) Traces() []QueryTrace { - q.mu.Lock() - traces := make([]QueryTrace, len(q.traces)) - copy(traces, q.traces) - q.mu.Unlock() - return traces -} - type cachingQuerier struct { q Querier cache Cache @@ -240,9 +185,7 @@ func (q *cachingQuerier) InvalidateCache(ctx context.Context, in *databroker.Que } func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { - key, err := (&proto.MarshalOptions{ - Deterministic: true, - }).Marshal(in) + key, err := MarshalQueryRequest(in) if err != nil { return nil, err } @@ -252,7 +195,7 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, if err != nil { return nil, err } - return proto.Marshal(res) + return MarshalQueryResponse(res) }) if err != nil { return nil, err @@ -265,3 +208,17 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, } return &res, nil } + +// MarshalQueryRequest marshals the query request. +func MarshalQueryRequest(req *databroker.QueryRequest) ([]byte, error) { + return (&proto.MarshalOptions{ + Deterministic: true, + }).Marshal(req) +} + +// MarshalQueryResponse marshals the query response. +func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) { + return (&proto.MarshalOptions{ + Deterministic: true, + }).Marshal(res) +}