mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-30 17:37:25 +02:00
authorize: cache warming (#5439)
* authorize: cache warming * add Authorize to test? * remove tracing querier * only update connection when it changes
This commit is contained in:
parent
b674d5c19d
commit
6e1fabec0b
9 changed files with 254 additions and 186 deletions
|
@ -10,7 +10,10 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/pomerium/datasource/pkg/directory"
|
||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
|
@ -21,16 +24,16 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// Authorize struct holds
|
||||
type Authorize struct {
|
||||
state *atomicutil.Value[*authorizeState]
|
||||
store *store.Store
|
||||
currentOptions *atomicutil.Value[*config.Options]
|
||||
accessTracker *AccessTracker
|
||||
globalCache storage.Cache
|
||||
state *atomicutil.Value[*authorizeState]
|
||||
store *store.Store
|
||||
currentOptions *atomicutil.Value[*config.Options]
|
||||
accessTracker *AccessTracker
|
||||
globalCache storage.Cache
|
||||
groupsCacheWarmer *cacheWarmer
|
||||
|
||||
// The stateLock prevents updating the evaluator store simultaneously with an evaluation.
|
||||
// This should provide a consistent view of the data at a given server/record version and
|
||||
|
@ -60,6 +63,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
|||
}
|
||||
a.state = atomicutil.NewValue(state)
|
||||
|
||||
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, a.globalCache, directory.GroupRecordType)
|
||||
return a, nil
|
||||
}
|
||||
|
||||
|
@ -70,8 +74,16 @@ func (a *Authorize) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli
|
|||
|
||||
// Run runs the authorize service.
|
||||
func (a *Authorize) Run(ctx context.Context) error {
|
||||
a.accessTracker.Run(ctx)
|
||||
return nil
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error {
|
||||
a.accessTracker.Run(ctx)
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
a.groupsCacheWarmer.Run(ctx)
|
||||
return nil
|
||||
})
|
||||
return eg.Wait()
|
||||
}
|
||||
|
||||
func validateOptions(o *config.Options) error {
|
||||
|
@ -150,9 +162,13 @@ func newPolicyEvaluator(
|
|||
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||
currentState := a.state.Load()
|
||||
a.currentOptions.Store(cfg.Options)
|
||||
if state, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
|
||||
if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
|
||||
} else {
|
||||
a.state.Store(state)
|
||||
a.state.Store(newState)
|
||||
|
||||
if currentState.dataBrokerClientConnection != newState.dataBrokerClientConnection {
|
||||
a.groupsCacheWarmer.UpdateConn(newState.dataBrokerClientConnection)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
122
authorize/cache_warmer.go
Normal file
122
authorize/cache_warmer.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
package authorize
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
type cacheWarmer struct {
|
||||
cc *grpc.ClientConn
|
||||
cache storage.Cache
|
||||
typeURL string
|
||||
|
||||
updatedCC chan *grpc.ClientConn
|
||||
}
|
||||
|
||||
func newCacheWarmer(
|
||||
cc *grpc.ClientConn,
|
||||
cache storage.Cache,
|
||||
typeURL string,
|
||||
) *cacheWarmer {
|
||||
return &cacheWarmer{
|
||||
cc: cc,
|
||||
cache: cache,
|
||||
typeURL: typeURL,
|
||||
|
||||
updatedCC: make(chan *grpc.ClientConn, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (cw *cacheWarmer) UpdateConn(cc *grpc.ClientConn) {
|
||||
for {
|
||||
select {
|
||||
case cw.updatedCC <- cc:
|
||||
return
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-cw.updatedCC:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cw *cacheWarmer) Run(ctx context.Context) {
|
||||
// Run a syncer for the cache warmer until the underlying databroker connection is changed.
|
||||
// When that happens cancel the currently running syncer and start a new one.
|
||||
|
||||
runCtx, runCancel := context.WithCancel(ctx)
|
||||
go cw.run(runCtx, cw.cc)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
runCancel()
|
||||
return
|
||||
case cc := <-cw.updatedCC:
|
||||
log.Ctx(ctx).Info().Msg("cache-warmer: received updated databroker client connection, restarting syncer")
|
||||
cw.cc = cc
|
||||
runCancel()
|
||||
runCtx, runCancel = context.WithCancel(ctx)
|
||||
go cw.run(runCtx, cw.cc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cw *cacheWarmer) run(ctx context.Context, cc *grpc.ClientConn) {
|
||||
log.Ctx(ctx).Debug().Str("type-url", cw.typeURL).Msg("cache-warmer: running databroker syncer to warm cache")
|
||||
_ = databroker.NewSyncer(ctx, "cache-warmer", cacheWarmerSyncerHandler{
|
||||
client: databroker.NewDataBrokerServiceClient(cc),
|
||||
cache: cw.cache,
|
||||
}, databroker.WithTypeURL(cw.typeURL)).Run(ctx)
|
||||
}
|
||||
|
||||
type cacheWarmerSyncerHandler struct {
|
||||
client databroker.DataBrokerServiceClient
|
||||
cache storage.Cache
|
||||
}
|
||||
|
||||
func (h cacheWarmerSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
||||
return h.client
|
||||
}
|
||||
|
||||
func (h cacheWarmerSyncerHandler) ClearRecords(_ context.Context) {
|
||||
h.cache.InvalidateAll()
|
||||
}
|
||||
|
||||
func (h cacheWarmerSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) {
|
||||
for _, record := range records {
|
||||
req := &databroker.QueryRequest{
|
||||
Type: record.Type,
|
||||
Limit: 1,
|
||||
}
|
||||
req.SetFilterByIDOrIndex(record.Id)
|
||||
|
||||
res := &databroker.QueryResponse{
|
||||
Records: []*databroker.Record{record},
|
||||
TotalCount: 1,
|
||||
ServerVersion: serverVersion,
|
||||
RecordVersion: record.Version,
|
||||
}
|
||||
|
||||
expiry := time.Now().Add(time.Hour * 24 * 365)
|
||||
key, err := storage.MarshalQueryRequest(req)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query request")
|
||||
continue
|
||||
}
|
||||
value, err := storage.MarshalQueryResponse(res)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("cache-warmer: failed to marshal query response")
|
||||
continue
|
||||
}
|
||||
|
||||
h.cache.Set(expiry, key, value)
|
||||
}
|
||||
}
|
52
authorize/cache_warmer_test.go
Normal file
52
authorize/cache_warmer_test.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package authorize
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/trace/noop"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/databroker"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
func TestCacheWarmer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.GetContext(t, 10*time.Minute)
|
||||
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
|
||||
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider()))
|
||||
})
|
||||
t.Cleanup(func() { cc.Close() })
|
||||
|
||||
client := databrokerpb.NewDataBrokerServiceClient(cc)
|
||||
_, err := client.Put(ctx, &databrokerpb.PutRequest{
|
||||
Records: []*databrokerpb.Record{
|
||||
{Type: "example.com/record", Id: "e1", Data: protoutil.NewAnyBool(true)},
|
||||
{Type: "example.com/record", Id: "e2", Data: protoutil.NewAnyBool(true)},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
cache := storage.NewGlobalCache(time.Minute)
|
||||
|
||||
cw := newCacheWarmer(cc, cache, "example.com/record")
|
||||
go cw.Run(ctx)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
req := &databrokerpb.QueryRequest{
|
||||
Type: "example.com/record",
|
||||
Limit: 1,
|
||||
}
|
||||
req.SetFilterByIDOrIndex("e1")
|
||||
res, err := storage.NewCachingQuerier(storage.NewStaticQuerier(), cache).Query(ctx, req)
|
||||
require.NoError(t, err)
|
||||
return len(res.GetRecords()) == 1
|
||||
}, 10*time.Second, time.Millisecond*100)
|
||||
}
|
|
@ -37,10 +37,8 @@ func Test_getDataBrokerRecord(t *testing.T) {
|
|||
s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)}
|
||||
|
||||
sq := storage.NewStaticQuerier(s1)
|
||||
tsq := storage.NewTracingQuerier(sq)
|
||||
cq := storage.NewCachingQuerier(tsq, storage.NewLocalCache())
|
||||
tcq := storage.NewTracingQuerier(cq)
|
||||
qctx := storage.WithQuerier(ctx, tcq)
|
||||
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
||||
qctx := storage.WithQuerier(ctx, cq)
|
||||
|
||||
s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
||||
assert.NoError(t, err)
|
||||
|
@ -49,11 +47,6 @@ func Test_getDataBrokerRecord(t *testing.T) {
|
|||
s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
|
||||
assert.Len(t, tsq.Traces(), tc.underlyingQueryCount,
|
||||
"should have %d traces to the underlying querier", tc.underlyingQueryCount)
|
||||
assert.Len(t, tcq.Traces(), tc.cachedQueryCount,
|
||||
"should have %d traces to the cached querier", tc.cachedQueryCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,14 +31,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check")
|
||||
defer span.End()
|
||||
|
||||
querier := storage.NewTracingQuerier(
|
||||
storage.NewCachingQuerier(
|
||||
storage.NewCachingQuerier(
|
||||
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
||||
a.globalCache,
|
||||
),
|
||||
storage.NewLocalCache(),
|
||||
),
|
||||
querier := storage.NewCachingQuerier(
|
||||
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
||||
a.globalCache,
|
||||
)
|
||||
ctx = storage.WithQuerier(ctx, querier)
|
||||
|
||||
|
|
|
@ -12,13 +12,6 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
||||
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
@ -27,6 +20,14 @@ import (
|
|||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.17.0"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
||||
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
|
||||
)
|
||||
|
||||
func otlpTraceReceiverOrFromEnv(t *testing.T) (modifier testenv.Modifier, newRemoteClient func() otlptrace.Client, getResults func() *TraceResults) {
|
||||
|
@ -116,7 +117,7 @@ func TestOTLPTracing(t *testing.T) {
|
|||
Exact: true,
|
||||
CheckDetachedSpans: true,
|
||||
},
|
||||
Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Test Environment", "Control Plane", "Data Broker"}},
|
||||
Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Authorize", "Test Environment", "Control Plane", "Data Broker"}},
|
||||
Match{Name: testEnvironmentAuthenticate, TraceCount: 1, Services: allServices},
|
||||
Match{Name: authenticateOAuth2Client, TraceCount: Greater(0)},
|
||||
Match{Name: idpServerGetUserinfo, TraceCount: EqualToMatch(authenticateOAuth2Client)},
|
||||
|
|
|
@ -18,64 +18,8 @@ type Cache interface {
|
|||
update func(ctx context.Context) ([]byte, error),
|
||||
) ([]byte, error)
|
||||
Invalidate(key []byte)
|
||||
}
|
||||
|
||||
type localCache struct {
|
||||
singleflight singleflight.Group
|
||||
mu sync.RWMutex
|
||||
m map[string][]byte
|
||||
}
|
||||
|
||||
// NewLocalCache creates a new Cache backed by a map.
|
||||
func NewLocalCache() Cache {
|
||||
return &localCache{
|
||||
m: make(map[string][]byte),
|
||||
}
|
||||
}
|
||||
|
||||
func (cache *localCache) GetOrUpdate(
|
||||
ctx context.Context,
|
||||
key []byte,
|
||||
update func(ctx context.Context) ([]byte, error),
|
||||
) ([]byte, error) {
|
||||
strkey := string(key)
|
||||
|
||||
cache.mu.RLock()
|
||||
cached, ok := cache.m[strkey]
|
||||
cache.mu.RUnlock()
|
||||
if ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
v, err, _ := cache.singleflight.Do(strkey, func() (any, error) {
|
||||
cache.mu.RLock()
|
||||
cached, ok := cache.m[strkey]
|
||||
cache.mu.RUnlock()
|
||||
if ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
result, err := update(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cache.mu.Lock()
|
||||
cache.m[strkey] = result
|
||||
cache.mu.Unlock()
|
||||
|
||||
return result, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return v.([]byte), nil
|
||||
}
|
||||
|
||||
func (cache *localCache) Invalidate(key []byte) {
|
||||
cache.mu.Lock()
|
||||
delete(cache.m, string(key))
|
||||
cache.mu.Unlock()
|
||||
InvalidateAll()
|
||||
Set(expiry time.Time, key, value []byte)
|
||||
}
|
||||
|
||||
type globalCache struct {
|
||||
|
@ -115,7 +59,7 @@ func (cache *globalCache) GetOrUpdate(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cache.set(key, value)
|
||||
cache.set(time.Now().Add(cache.ttl), key, value)
|
||||
return value, nil
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -130,6 +74,16 @@ func (cache *globalCache) Invalidate(key []byte) {
|
|||
cache.mu.Unlock()
|
||||
}
|
||||
|
||||
func (cache *globalCache) InvalidateAll() {
|
||||
cache.mu.Lock()
|
||||
cache.fastcache.Reset()
|
||||
cache.mu.Unlock()
|
||||
}
|
||||
|
||||
func (cache *globalCache) Set(expiry time.Time, key, value []byte) {
|
||||
cache.set(expiry, key, value)
|
||||
}
|
||||
|
||||
func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool) {
|
||||
cache.mu.RLock()
|
||||
item := cache.fastcache.Get(nil, k)
|
||||
|
@ -143,13 +97,13 @@ func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool)
|
|||
return data, expiry, true
|
||||
}
|
||||
|
||||
func (cache *globalCache) set(k, v []byte) {
|
||||
unix := time.Now().Add(cache.ttl).UnixMilli()
|
||||
item := make([]byte, len(v)+8)
|
||||
func (cache *globalCache) set(expiry time.Time, key, value []byte) {
|
||||
unix := expiry.UnixMilli()
|
||||
item := make([]byte, len(value)+8)
|
||||
binary.LittleEndian.PutUint64(item, uint64(unix))
|
||||
copy(item[8:], v)
|
||||
copy(item[8:], value)
|
||||
|
||||
cache.mu.Lock()
|
||||
cache.fastcache.Set(k, item)
|
||||
cache.fastcache.Set(key, item)
|
||||
cache.mu.Unlock()
|
||||
}
|
||||
|
|
|
@ -9,34 +9,6 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLocalCache(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
||||
callCount := 0
|
||||
update := func(_ context.Context) ([]byte, error) {
|
||||
callCount++
|
||||
return []byte("v1"), nil
|
||||
}
|
||||
c := NewLocalCache()
|
||||
v, err := c.GetOrUpdate(ctx, []byte("k1"), update)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("v1"), v)
|
||||
assert.Equal(t, 1, callCount)
|
||||
|
||||
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("v1"), v)
|
||||
assert.Equal(t, 1, callCount)
|
||||
|
||||
c.Invalidate([]byte("k1"))
|
||||
|
||||
v, err = c.GetOrUpdate(ctx, []byte("k1"), update)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("v1"), v)
|
||||
assert.Equal(t, 2, callCount)
|
||||
}
|
||||
|
||||
func TestGlobalCache(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
@ -70,4 +42,10 @@ func TestGlobalCache(t *testing.T) {
|
|||
})
|
||||
return err != nil
|
||||
}, time.Second, time.Millisecond*10, "should honor TTL")
|
||||
|
||||
c.Set(time.Now().Add(time.Hour), []byte("k2"), []byte("v2"))
|
||||
v, err = c.GetOrUpdate(ctx, []byte("k2"), update)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("v2"), v)
|
||||
assert.Equal(t, 2, callCount)
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
grpc "google.golang.org/grpc"
|
||||
|
@ -12,7 +11,6 @@ import (
|
|||
status "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
|
@ -162,59 +160,6 @@ func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
|
|||
return q.client.Query(ctx, in, opts...)
|
||||
}
|
||||
|
||||
// A TracingQuerier records calls to Query.
|
||||
type TracingQuerier struct {
|
||||
underlying Querier
|
||||
|
||||
mu sync.Mutex
|
||||
traces []QueryTrace
|
||||
}
|
||||
|
||||
// A QueryTrace traces a call to Query.
|
||||
type QueryTrace struct {
|
||||
ServerVersion, RecordVersion uint64
|
||||
|
||||
RecordType string
|
||||
Query string
|
||||
Filter *structpb.Struct
|
||||
}
|
||||
|
||||
// NewTracingQuerier creates a new TracingQuerier.
|
||||
func NewTracingQuerier(q Querier) *TracingQuerier {
|
||||
return &TracingQuerier{
|
||||
underlying: q,
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateCache invalidates the cache.
|
||||
func (q *TracingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {
|
||||
q.underlying.InvalidateCache(ctx, in)
|
||||
}
|
||||
|
||||
// Query queries for records.
|
||||
func (q *TracingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||
res, err := q.underlying.Query(ctx, in, opts...)
|
||||
if err == nil {
|
||||
q.mu.Lock()
|
||||
q.traces = append(q.traces, QueryTrace{
|
||||
RecordType: in.GetType(),
|
||||
Query: in.GetQuery(),
|
||||
Filter: in.GetFilter(),
|
||||
})
|
||||
q.mu.Unlock()
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
|
||||
// Traces returns all the traces.
|
||||
func (q *TracingQuerier) Traces() []QueryTrace {
|
||||
q.mu.Lock()
|
||||
traces := make([]QueryTrace, len(q.traces))
|
||||
copy(traces, q.traces)
|
||||
q.mu.Unlock()
|
||||
return traces
|
||||
}
|
||||
|
||||
type cachingQuerier struct {
|
||||
q Querier
|
||||
cache Cache
|
||||
|
@ -240,9 +185,7 @@ func (q *cachingQuerier) InvalidateCache(ctx context.Context, in *databroker.Que
|
|||
}
|
||||
|
||||
func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||
key, err := (&proto.MarshalOptions{
|
||||
Deterministic: true,
|
||||
}).Marshal(in)
|
||||
key, err := MarshalQueryRequest(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -252,7 +195,7 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return proto.Marshal(res)
|
||||
return MarshalQueryResponse(res)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -265,3 +208,17 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
|
|||
}
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
// MarshalQueryRequest marshales the query request.
|
||||
func MarshalQueryRequest(req *databroker.QueryRequest) ([]byte, error) {
|
||||
return (&proto.MarshalOptions{
|
||||
Deterministic: true,
|
||||
}).Marshal(req)
|
||||
}
|
||||
|
||||
// MarshalQueryResponse marshals the query response.
|
||||
func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) {
|
||||
return (&proto.MarshalOptions{
|
||||
Deterministic: true,
|
||||
}).Marshal(res)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue