authorize: use pooled query request objects with more efficient cache keys

This commit is contained in:
Joe Kralicky 2024-11-20 16:29:33 -05:00
parent 20a9be891f
commit c9117a0274
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
4 changed files with 120 additions and 26 deletions

View file

@ -3,6 +3,7 @@ package authorize
import ( import (
"context" "context"
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/session"
@ -24,13 +25,13 @@ func getDataBrokerRecord(
) (*databroker.Record, error) { ) (*databroker.Record, error) {
q := storage.GetQuerier(ctx) q := storage.GetQuerier(ctx)
req := &databroker.QueryRequest{ req := store.GetPooledQueryRequest()
Type: recordType, req.SetRecordType(recordType)
Limit: 1, req.SetIDOrIndex(recordID)
} ctx = storage.ContextWithCacheKey(ctx, req.CacheKey())
req.SetFilterByIDOrIndex(recordID) defer req.Release()
res, err := q.Query(ctx, req) res, err := q.Query(ctx, req.Request())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -40,13 +41,13 @@ func getDataBrokerRecord(
// if the current record version is less than the lowest we'll accept, invalidate the cache // if the current record version is less than the lowest we'll accept, invalidate the cache
if res.GetRecords()[0].GetVersion() < lowestRecordVersion { if res.GetRecords()[0].GetVersion() < lowestRecordVersion {
q.InvalidateCache(ctx, req) q.InvalidateCache(ctx, req.Request())
} else { } else {
return res.GetRecords()[0], nil return res.GetRecords()[0], nil
} }
// retry with an up to date cache // retry with an up to date cache
res, err = q.Query(ctx, req) res, err = q.Query(ctx, req.Request())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -0,0 +1,70 @@
package store
import (
"encoding/binary"
"sync"
"github.com/cespare/xxhash/v2"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"google.golang.org/protobuf/types/known/structpb"
)
var queryRequestPool = sync.Pool{
New: func() any {
idOrIndex := &structpb.Value_StringValue{}
pqr := &PooledQueryRequest{
qr: &databroker.QueryRequest{
Limit: 1,
Filter: &structpb.Struct{Fields: map[string]*structpb.Value{
"$or": structpb.NewListValue(&structpb.ListValue{Values: []*structpb.Value{
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"id": {Kind: idOrIndex},
}}),
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"$index": {Kind: idOrIndex},
}}),
}}),
}},
},
idOrIndex: idOrIndex,
}
return pqr
},
}
type PooledQueryRequest struct {
qr *databroker.QueryRequest
cacheKey [16]byte
idOrIndex *structpb.Value_StringValue
}
func (pqr *PooledQueryRequest) SetRecordType(recordType string) {
pqr.qr.Type = recordType
binary.LittleEndian.PutUint64(pqr.cacheKey[0:8], xxhash.Sum64String(recordType))
}
func (pqr *PooledQueryRequest) SetIDOrIndex(idOrIndex string) {
pqr.idOrIndex.StringValue = idOrIndex
binary.LittleEndian.PutUint64(pqr.cacheKey[8:16], xxhash.Sum64String(idOrIndex))
}
func (pqr *PooledQueryRequest) Request() *databroker.QueryRequest {
return pqr.qr
}
func (pqr *PooledQueryRequest) CacheKey() []byte {
return pqr.cacheKey[:]
}
func (pqr *PooledQueryRequest) Release() {
queryRequestPool.Put(pqr)
}
func GetPooledQueryRequest() *PooledQueryRequest {
pqr := queryRequestPool.Get().(*PooledQueryRequest)
pqr.qr.Type = ""
pqr.idOrIndex.StringValue = ""
clear(pqr.cacheKey[:])
return pqr
}

View file

@ -21,7 +21,6 @@ import (
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
) )
@ -163,13 +162,13 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
} }
func (s *Store) GetDataBrokerRecord(ctx context.Context, recordType, recordIDOrIndex string) proto.Message { func (s *Store) GetDataBrokerRecord(ctx context.Context, recordType, recordIDOrIndex string) proto.Message {
req := &databroker.QueryRequest{ req := GetPooledQueryRequest()
Type: recordType, req.SetRecordType(recordType)
Limit: 1, req.SetIDOrIndex(recordIDOrIndex)
} ctx = storage.ContextWithCacheKey(ctx, req.CacheKey())
req.SetFilterByIDOrIndex(recordIDOrIndex) defer req.Release()
res, err := storage.GetQuerier(ctx).Query(ctx, req) res, err := storage.GetQuerier(ctx).Query(ctx, req.Request())
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/store: error retrieving record") log.Ctx(ctx).Error().Err(err).Msg("authorize/store: error retrieving record")
return nil return nil

View file

@ -215,6 +215,19 @@ func (q *TracingQuerier) Traces() []QueryTrace {
return traces return traces
} }
type cacheKeyType struct{}
var cacheKeyKey cacheKeyType
func ContextWithCacheKey(ctx context.Context, cacheKey []byte) context.Context {
return context.WithValue(ctx, cacheKeyKey, cacheKey)
}
func CacheKeyFromContext(ctx context.Context) ([]byte, bool) {
v, ok := ctx.Value(cacheKeyKey).([]byte)
return v, ok
}
type cachingQuerier struct { type cachingQuerier struct {
q Querier q Querier
cache Cache cache Cache
@ -229,24 +242,35 @@ func NewCachingQuerier(q Querier, cache Cache) Querier {
} }
func (q *cachingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) { func (q *cachingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {
key, err := (&proto.MarshalOptions{ var key []byte
if k, ok := CacheKeyFromContext(ctx); ok {
key = k
} else {
var err error
key, err = (&proto.MarshalOptions{
Deterministic: true, Deterministic: true,
}).Marshal(in) }).Marshal(in)
if err != nil { if err != nil {
return return
} }
}
q.cache.Invalidate(key) q.cache.Invalidate(key)
q.q.InvalidateCache(ctx, in) q.q.InvalidateCache(ctx, in)
} }
func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
key, err := (&proto.MarshalOptions{ var key []byte
if k, ok := CacheKeyFromContext(ctx); ok {
key = k
} else {
var err error
key, err = (&proto.MarshalOptions{
Deterministic: true, Deterministic: true,
}).Marshal(in) }).Marshal(in)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
rawResult, err := q.cache.GetOrUpdate(ctx, key, func(ctx context.Context) ([]byte, error) { rawResult, err := q.cache.GetOrUpdate(ctx, key, func(ctx context.Context) ([]byte, error) {
res, err := q.q.Query(ctx, in, opts...) res, err := q.q.Query(ctx, in, opts...)
if err != nil { if err != nil {