authorize: use pooled query request objects with more efficient cache keys

This commit is contained in:
Joe Kralicky 2024-11-20 16:29:33 -05:00
parent 20a9be891f
commit c9117a0274
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
4 changed files with 120 additions and 26 deletions

View file

@ -3,6 +3,7 @@ package authorize
import (
"context"
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
@ -24,13 +25,13 @@ func getDataBrokerRecord(
) (*databroker.Record, error) {
q := storage.GetQuerier(ctx)
req := &databroker.QueryRequest{
Type: recordType,
Limit: 1,
}
req.SetFilterByIDOrIndex(recordID)
req := store.GetPooledQueryRequest()
req.SetRecordType(recordType)
req.SetIDOrIndex(recordID)
ctx = storage.ContextWithCacheKey(ctx, req.CacheKey())
defer req.Release()
res, err := q.Query(ctx, req)
res, err := q.Query(ctx, req.Request())
if err != nil {
return nil, err
}
@ -40,13 +41,13 @@ func getDataBrokerRecord(
// if the current record version is less than the lowest we'll accept, invalidate the cache
if res.GetRecords()[0].GetVersion() < lowestRecordVersion {
q.InvalidateCache(ctx, req)
q.InvalidateCache(ctx, req.Request())
} else {
return res.GetRecords()[0], nil
}
// retry with an up to date cache
res, err = q.Query(ctx, req)
res, err = q.Query(ctx, req.Request())
if err != nil {
return nil, err
}

View file

@ -0,0 +1,70 @@
package store
import (
"encoding/binary"
"sync"
"github.com/cespare/xxhash/v2"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"google.golang.org/protobuf/types/known/structpb"
)
var queryRequestPool = sync.Pool{
New: func() any {
idOrIndex := &structpb.Value_StringValue{}
pqr := &PooledQueryRequest{
qr: &databroker.QueryRequest{
Limit: 1,
Filter: &structpb.Struct{Fields: map[string]*structpb.Value{
"$or": structpb.NewListValue(&structpb.ListValue{Values: []*structpb.Value{
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"id": {Kind: idOrIndex},
}}),
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"$index": {Kind: idOrIndex},
}}),
}}),
}},
},
idOrIndex: idOrIndex,
}
return pqr
},
}
type PooledQueryRequest struct {
qr *databroker.QueryRequest
cacheKey [16]byte
idOrIndex *structpb.Value_StringValue
}
func (pqr *PooledQueryRequest) SetRecordType(recordType string) {
pqr.qr.Type = recordType
binary.LittleEndian.PutUint64(pqr.cacheKey[0:8], xxhash.Sum64String(recordType))
}
func (pqr *PooledQueryRequest) SetIDOrIndex(idOrIndex string) {
pqr.idOrIndex.StringValue = idOrIndex
binary.LittleEndian.PutUint64(pqr.cacheKey[8:16], xxhash.Sum64String(idOrIndex))
}
func (pqr *PooledQueryRequest) Request() *databroker.QueryRequest {
return pqr.qr
}
func (pqr *PooledQueryRequest) CacheKey() []byte {
return pqr.cacheKey[:]
}
func (pqr *PooledQueryRequest) Release() {
queryRequestPool.Put(pqr)
}
func GetPooledQueryRequest() *PooledQueryRequest {
pqr := queryRequestPool.Get().(*PooledQueryRequest)
pqr.qr.Type = ""
pqr.idOrIndex.StringValue = ""
clear(pqr.cacheKey[:])
return pqr
}

View file

@ -21,7 +21,6 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
@ -163,13 +162,13 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) {
}
func (s *Store) GetDataBrokerRecord(ctx context.Context, recordType, recordIDOrIndex string) proto.Message {
req := &databroker.QueryRequest{
Type: recordType,
Limit: 1,
}
req.SetFilterByIDOrIndex(recordIDOrIndex)
req := GetPooledQueryRequest()
req.SetRecordType(recordType)
req.SetIDOrIndex(recordIDOrIndex)
ctx = storage.ContextWithCacheKey(ctx, req.CacheKey())
defer req.Release()
res, err := storage.GetQuerier(ctx).Query(ctx, req)
res, err := storage.GetQuerier(ctx).Query(ctx, req.Request())
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize/store: error retrieving record")
return nil

View file

@ -215,6 +215,19 @@ func (q *TracingQuerier) Traces() []QueryTrace {
return traces
}
type cacheKeyType struct{}
var cacheKeyKey cacheKeyType
func ContextWithCacheKey(ctx context.Context, cacheKey []byte) context.Context {
return context.WithValue(ctx, cacheKeyKey, cacheKey)
}
func CacheKeyFromContext(ctx context.Context) ([]byte, bool) {
v, ok := ctx.Value(cacheKeyKey).([]byte)
return v, ok
}
type cachingQuerier struct {
q Querier
cache Cache
@ -229,24 +242,35 @@ func NewCachingQuerier(q Querier, cache Cache) Querier {
}
func (q *cachingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {
key, err := (&proto.MarshalOptions{
Deterministic: true,
}).Marshal(in)
if err != nil {
return
var key []byte
if k, ok := CacheKeyFromContext(ctx); ok {
key = k
} else {
var err error
key, err = (&proto.MarshalOptions{
Deterministic: true,
}).Marshal(in)
if err != nil {
return
}
}
q.cache.Invalidate(key)
q.q.InvalidateCache(ctx, in)
}
func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
key, err := (&proto.MarshalOptions{
Deterministic: true,
}).Marshal(in)
if err != nil {
return nil, err
var key []byte
if k, ok := CacheKeyFromContext(ctx); ok {
key = k
} else {
var err error
key, err = (&proto.MarshalOptions{
Deterministic: true,
}).Marshal(in)
if err != nil {
return nil, err
}
}
rawResult, err := q.cache.GetOrUpdate(ctx, key, func(ctx context.Context) ([]byte, error) {
res, err := q.q.Query(ctx, in, opts...)
if err != nil {