pomerium/pkg/storage/querier_caching.go
Caleb Doxsey 8738066ce4
storage: add sync querier (#5570)
* storage: add fallback querier

* storage: add sync querier

* storage: add typed querier

* use synced querier
2025-04-23 10:15:48 -06:00

84 lines
2 KiB
Go

package storage
import (
"context"
grpc "google.golang.org/grpc"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// cachingQuerier is a Querier that wraps another Querier and keeps marshaled
// query responses in a Cache, so identical requests can be answered from the
// cache.
type cachingQuerier struct {
	// q is the wrapped querier: it is called to produce responses for the
	// cache and receives propagated InvalidateCache calls.
	q Querier
	// cache holds marshaled QueryResponse bytes keyed by the marshaled
	// QueryRequest (with MinimumRecordVersionHint cleared).
	cache Cache
}
// NewCachingQuerier returns a Querier that answers requests through cache,
// falling back to q to produce the responses that get cached.
//
// Cache keys are derived from the request with its MinimumRecordVersionHint
// cleared, so requests that differ only in that hint share a cache entry.
func NewCachingQuerier(q Querier, cache Cache) Querier {
	cq := new(cachingQuerier)
	cq.q = q
	cq.cache = cache
	return cq
}
// InvalidateCache evicts the cached response for in (if any) and then asks
// the wrapped querier to invalidate its own cached state for in.
//
// The cache key ignores MinimumRecordVersionHint. If the key cannot be
// computed (the request fails to marshal), there is nothing to evict from
// the local cache, but invalidation is still passed to the wrapped querier
// so a downstream cache is not left holding a stale entry.
func (q *cachingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {
	if key, err := q.getCacheKey(in); err == nil {
		q.cache.Invalidate(key)
	}
	// Always propagate downstream, regardless of whether a local key exists.
	q.q.InvalidateCache(ctx, in)
}
// Query returns the response for in, served through the cache.
//
// If in carries a MinimumRecordVersionHint and the response's RecordVersion
// is below it, the entry is invalidated (locally and in the wrapped querier)
// and the request is issued exactly once more. The second result is returned
// as-is, even if it still does not meet the hint.
func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
	res, err := q.query(ctx, in, opts...)
	if err != nil {
		return nil, err
	}

	hint := in.MinimumRecordVersionHint
	if hint == nil || res.RecordVersion >= *hint {
		return res, nil
	}

	// The cached response is older than the caller requires: drop it and retry.
	q.InvalidateCache(ctx, in)
	return q.query(ctx, in, opts...)
}
// Stop is a no-op. The caching querier does not stop the wrapped querier.
func (*cachingQuerier) Stop() {
	// intentionally empty
}
// getCacheKey returns the marshaled form of in with MinimumRecordVersionHint
// cleared, so requests that differ only in that hint map to the same key.
// The hint is cleared on a deep copy; the caller's request is not modified.
func (*cachingQuerier) getCacheKey(in *databroker.QueryRequest) ([]byte, error) {
	normalized := proto.Clone(in).(*databroker.QueryRequest)
	normalized.MinimumRecordVersionHint = nil
	return MarshalQueryRequest(normalized)
}
// query looks in up through the cache and decodes the result.
//
// cache.GetOrUpdate is given a fill function that queries the wrapped
// querier with the original request (hint included) and marshals its
// response. Whatever bytes GetOrUpdate returns are unmarshaled into a fresh
// QueryResponse. On any error, nil and the error are returned.
func (q *cachingQuerier) query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
	key, err := q.getCacheKey(in)
	if err != nil {
		return nil, err
	}

	fill := func(ctx context.Context) ([]byte, error) {
		upstream, err := q.q.Query(ctx, in, opts...)
		if err != nil {
			return nil, err
		}
		return MarshalQueryResponse(upstream)
	}

	raw, err := q.cache.GetOrUpdate(ctx, key, fill)
	if err != nil {
		return nil, err
	}

	out := new(databroker.QueryResponse)
	if err := proto.Unmarshal(raw, out); err != nil {
		return nil, err
	}
	return out, nil
}