From c9117a0274a935d795ab37dad07cb34a7b30d1b4 Mon Sep 17 00:00:00 2001 From: Joe Kralicky Date: Wed, 20 Nov 2024 16:29:33 -0500 Subject: [PATCH] authorize: use pooled query request objects with more efficient cache keys --- authorize/databroker.go | 17 +++--- authorize/internal/store/request_pool.go | 70 ++++++++++++++++++++++++ authorize/internal/store/store.go | 13 ++--- pkg/storage/querier.go | 46 ++++++++++++---- 4 files changed, 120 insertions(+), 26 deletions(-) create mode 100644 authorize/internal/store/request_pool.go diff --git a/authorize/databroker.go b/authorize/databroker.go index a65db4e27..29962d9d9 100644 --- a/authorize/databroker.go +++ b/authorize/databroker.go @@ -3,6 +3,7 @@ package authorize import ( "context" + "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" @@ -24,13 +25,13 @@ func getDataBrokerRecord( ) (*databroker.Record, error) { q := storage.GetQuerier(ctx) - req := &databroker.QueryRequest{ - Type: recordType, - Limit: 1, - } - req.SetFilterByIDOrIndex(recordID) + req := store.GetPooledQueryRequest() + req.SetRecordType(recordType) + req.SetIDOrIndex(recordID) + ctx = storage.ContextWithCacheKey(ctx, req.CacheKey()) + defer req.Release() - res, err := q.Query(ctx, req) + res, err := q.Query(ctx, req.Request()) if err != nil { return nil, err } @@ -40,13 +41,13 @@ func getDataBrokerRecord( // if the current record version is less than the lowest we'll accept, invalidate the cache if res.GetRecords()[0].GetVersion() < lowestRecordVersion { - q.InvalidateCache(ctx, req) + q.InvalidateCache(ctx, req.Request()) } else { return res.GetRecords()[0], nil } // retry with an up to date cache - res, err = q.Query(ctx, req) + res, err = q.Query(ctx, req.Request()) if err != nil { return nil, err } diff --git a/authorize/internal/store/request_pool.go b/authorize/internal/store/request_pool.go new file mode 100644 index 000000000..3322b9448 --- /dev/null +++ b/authorize/internal/store/request_pool.go @@ -0,0 +1,70 @@ +package store + +import ( + "encoding/binary" + "sync" + + "github.com/cespare/xxhash/v2" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "google.golang.org/protobuf/types/known/structpb" +) + +var queryRequestPool = sync.Pool{ + New: func() any { + idOrIndex := &structpb.Value_StringValue{} + pqr := &PooledQueryRequest{ + qr: &databroker.QueryRequest{ + Limit: 1, + Filter: &structpb.Struct{Fields: map[string]*structpb.Value{ + "$or": structpb.NewListValue(&structpb.ListValue{Values: []*structpb.Value{ + structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{ + "id": {Kind: idOrIndex}, + }}), + structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{ + "$index": {Kind: idOrIndex}, + }}), + }}), + }}, + }, + idOrIndex: idOrIndex, + } + return pqr + }, +} + +type PooledQueryRequest struct { + qr *databroker.QueryRequest + cacheKey [16]byte + + idOrIndex *structpb.Value_StringValue +} + +func (pqr *PooledQueryRequest) SetRecordType(recordType string) { + pqr.qr.Type = recordType + binary.LittleEndian.PutUint64(pqr.cacheKey[0:8], xxhash.Sum64String(recordType)) +} + +func (pqr *PooledQueryRequest) SetIDOrIndex(idOrIndex string) { + pqr.idOrIndex.StringValue = idOrIndex + binary.LittleEndian.PutUint64(pqr.cacheKey[8:16], xxhash.Sum64String(idOrIndex)) +} + +func (pqr *PooledQueryRequest) Request() *databroker.QueryRequest { + return pqr.qr +} + +func (pqr *PooledQueryRequest) CacheKey() []byte { + return pqr.cacheKey[:] +} + +func (pqr *PooledQueryRequest) Release() { + queryRequestPool.Put(pqr) +} + +func GetPooledQueryRequest() *PooledQueryRequest { + pqr := queryRequestPool.Get().(*PooledQueryRequest) + pqr.qr.Type = "" + pqr.idOrIndex.StringValue = "" + clear(pqr.cacheKey[:]) + return pqr +} diff --git a/authorize/internal/store/store.go b/authorize/internal/store/store.go index 6d77baaca..c75502585 100644 --- a/authorize/internal/store/store.go +++ b/authorize/internal/store/store.go @@ -21,7 +21,6 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/trace" - "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" ) @@ -163,13 +162,13 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) { } func (s *Store) GetDataBrokerRecord(ctx context.Context, recordType, recordIDOrIndex string) proto.Message { - req := &databroker.QueryRequest{ - Type: recordType, - Limit: 1, - } - req.SetFilterByIDOrIndex(recordIDOrIndex) + req := GetPooledQueryRequest() + req.SetRecordType(recordType) + req.SetIDOrIndex(recordIDOrIndex) + ctx = storage.ContextWithCacheKey(ctx, req.CacheKey()) + defer req.Release() - res, err := storage.GetQuerier(ctx).Query(ctx, req) + res, err := storage.GetQuerier(ctx).Query(ctx, req.Request()) if err != nil { log.Ctx(ctx).Error().Err(err).Msg("authorize/store: error retrieving record") return nil diff --git a/pkg/storage/querier.go b/pkg/storage/querier.go index 1de0b195d..8f6e84d12 100644 --- a/pkg/storage/querier.go +++ b/pkg/storage/querier.go @@ -215,6 +215,19 @@ func (q *TracingQuerier) Traces() []QueryTrace { return traces } +type cacheKeyType struct{} + +var cacheKeyKey cacheKeyType + +func ContextWithCacheKey(ctx context.Context, cacheKey []byte) context.Context { + return context.WithValue(ctx, cacheKeyKey, cacheKey) +} + +func CacheKeyFromContext(ctx context.Context) ([]byte, bool) { + v, ok := ctx.Value(cacheKeyKey).([]byte) + return v, ok +} + type cachingQuerier struct { q Querier cache Cache @@ -229,24 +242,35 @@ func NewCachingQuerier(q Querier, cache Cache) Querier { } func (q *cachingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) { - key, err := (&proto.MarshalOptions{ - Deterministic: true, - }).Marshal(in) - if err != nil { - return + var key []byte + if k, ok := CacheKeyFromContext(ctx); ok { + key = k + } else { + var err error + key, err = (&proto.MarshalOptions{ + Deterministic: true, + }).Marshal(in) + if err != nil { + return + } } q.cache.Invalidate(key) q.q.InvalidateCache(ctx, in) } func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { - key, err := (&proto.MarshalOptions{ - Deterministic: true, - }).Marshal(in) - if err != nil { - return nil, err + var key []byte + if k, ok := CacheKeyFromContext(ctx); ok { + key = k + } else { + var err error + key, err = (&proto.MarshalOptions{ + Deterministic: true, + }).Marshal(in) + if err != nil { + return nil, err + } } - rawResult, err := q.cache.GetOrUpdate(ctx, key, func(ctx context.Context) ([]byte, error) { res, err := q.q.Query(ctx, in, opts...) if err != nil {