storage: add sync querier

This commit is contained in:
Caleb Doxsey 2025-04-10 11:35:22 -06:00
parent 04a5506d1b
commit e31ed435be
7 changed files with 277 additions and 0 deletions

View file

@ -22,6 +22,7 @@ var ErrUnavailable = errors.New("unavailable")
type Querier interface {
// InvalidateCache is given the request whose cached result should be dropped.
// NOTE(review): implementations without a cache (e.g. syncQuerier) do nothing;
// confirm the exact contract against the caching implementation.
InvalidateCache(ctx context.Context, in *databroker.QueryRequest)
// Query answers a databroker query. The signature mirrors the gRPC client's
// Query method so that a client can be wrapped directly.
Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error)
// Stop stops the querier. For syncQuerier this cancels the background sync
// context; the other implementations in this package treat it as a no-op or
// forward it to the queriers they wrap.
Stop()
}
// nilQuerier always returns NotFound.
@ -33,6 +34,8 @@ func (nilQuerier) Query(_ context.Context, _ *databroker.QueryRequest, _ ...grpc
return nil, errors.Join(ErrUnavailable, status.Error(codes.NotFound, "not found"))
}
// Stop is a no-op; it exists so that nilQuerier satisfies the Querier interface.
func (nilQuerier) Stop() {
	// nothing to stop
}
// querierKey is an empty struct type.
// NOTE(review): presumably the context key under which the Querier is stored
// and looked up by GetQuerier — confirm against its uses.
type querierKey struct{}
// GetQuerier gets the databroker Querier from the context.

View file

@ -50,6 +50,8 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
return res, nil
}
// Stop is a no-op; it exists so that cachingQuerier satisfies the Querier interface.
func (*cachingQuerier) Stop() {
	// nothing to stop
}
func (q *cachingQuerier) getCacheKey(in *databroker.QueryRequest) ([]byte, error) {
in = proto.Clone(in).(*databroker.QueryRequest)
in.MinimumRecordVersionHint = nil

View file

@ -23,3 +23,5 @@ func (q *clientQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRe
// Query forwards the request, unchanged, to the wrapped databroker client and
// returns whatever the client returns.
func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
	res, err := q.client.Query(ctx, in, opts...)
	return res, err
}
// Stop is a no-op; it exists so that clientQuerier satisfies the Querier interface.
func (*clientQuerier) Stop() {
	// nothing to stop
}

View file

@ -40,3 +40,10 @@ func (q fallbackQuerier) Query(ctx context.Context, req *databroker.QueryRequest
}
return nil, merr
}
// Stop calls Stop on every querier contained in q.
func (q fallbackQuerier) Stop() {
	for i := range q {
		q[i].Stop()
	}
}

View file

@ -81,3 +81,5 @@ func (q *staticQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRe
// Query evaluates req against the querier's record collections by handing
// them to QueryRecordCollections. The context and call options are ignored.
func (q *staticQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
	collections := q.records
	return QueryRecordCollections(collections, req)
}
// Stop is a no-op; it exists so that staticQuerier satisfies the Querier interface.
func (*staticQuerier) Stop() {
	// nothing to stop
}

172
pkg/storage/querier_sync.go Normal file
View file

@ -0,0 +1,172 @@
package storage
import (
"context"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
grpc "google.golang.org/grpc"
"google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// syncQuerier is a Querier for a single record type. It serves queries from an
// in-memory RecordCollection that a background goroutine (run) fills via the
// databroker SyncLatest and Sync streams.
type syncQuerier struct {
// client is the databroker client used to open the SyncLatest and Sync streams.
client databroker.DataBrokerServiceClient
// recordType is the only record type synced and the only type Query will answer.
recordType string
// cancel cancels the context of the background goroutine; called by Stop.
cancel context.CancelFunc
// serverVersion is the server version reported by SyncLatest; zero means a
// full SyncLatest is required before Sync can be started.
serverVersion uint64
// latestRecordVersion is the highest record version seen so far; Query
// compares it against the request's MinimumRecordVersionHint.
latestRecordVersion uint64
// mu is held when ready and records are read or written.
mu sync.RWMutex
// ready is true once a SyncLatest stream has completed; Query returns
// ErrUnavailable while it is false.
ready bool
// records holds the synced records for recordType.
records RecordCollection
}
// NewSyncQuerier returns a Querier that answers queries for recordType from an
// in-memory record collection. The collection is populated by a background
// goroutine that syncs from the databroker using client; the goroutine runs
// until Stop is called on the returned Querier.
func NewSyncQuerier(
	client databroker.DataBrokerServiceClient,
	recordType string,
) Querier {
	// The sync loop outlives any single request, so it gets its own
	// cancellable context rather than one supplied by a caller.
	ctx, cancel := context.WithCancel(context.Background())

	q := &syncQuerier{
		client:     client,
		recordType: recordType,
		cancel:     cancel,
		records:    NewRecordCollection(),
	}
	go q.run(ctx)

	return q
}
// InvalidateCache is a no-op for the sync querier: both arguments are ignored.
func (q *syncQuerier) InvalidateCache(context.Context, *databroker.QueryRequest) {
	// intentionally empty
}
// Query answers req from the in-memory record collection. It returns
// ErrUnavailable when canHandleQueryLocked rejects the request (not yet
// synced, wrong record type, or records older than the requested minimum
// version). The context and call options are ignored.
func (q *syncQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
	// Hold the read lock for the whole call so the collection cannot change
	// between the eligibility check and the query itself.
	q.mu.RLock()
	defer q.mu.RUnlock()

	if !q.canHandleQueryLocked(req) {
		return nil, ErrUnavailable
	}

	collections := map[string]RecordCollection{
		q.recordType: q.records,
	}
	return QueryRecordCollections(collections, req)
}
// Stop cancels the context that was handed to the background sync goroutine
// in NewSyncQuerier. It does not wait for that goroutine to exit.
func (q *syncQuerier) Stop() {
	stop := q.cancel
	stop()
}
// canHandleQueryLocked reports whether req can be answered from the in-memory
// collection. The caller must hold q.mu (read or write).
//
// A request is rejected when the initial sync has not completed, when it asks
// for a record type other than q.recordType, or when it carries a
// MinimumRecordVersionHint newer than the latest record version seen.
func (q *syncQuerier) canHandleQueryLocked(req *databroker.QueryRequest) bool {
	switch {
	case !q.ready:
		return false
	case req.GetType() != q.recordType:
		return false
	case req.MinimumRecordVersionHint != nil && q.latestRecordVersion < *req.MinimumRecordVersionHint:
		return false
	default:
		return true
	}
}
// run is the background sync loop. Each attempt performs a full SyncLatest
// when no server version is known and then follows the incremental Sync
// stream. Failed attempts are retried with exponential backoff (no elapsed
// time limit) until ctx is cancelled.
func (q *syncQuerier) run(ctx context.Context) {
	attempt := func() error {
		if q.serverVersion == 0 {
			if err := q.syncLatest(ctx); err != nil {
				return err
			}
		}
		return q.sync(ctx)
	}

	onRetry := func(err error, d time.Duration) {
		log.Ctx(ctx).Error().
			Err(err).
			Dur("delay", d).
			Msg("storage/sync-querier: error syncing records")
	}

	expBackoff := backoff.NewExponentialBackOff(backoff.WithMaxElapsedTime(0))

	// sync only ever returns with an error, so RetryNotify returns only once
	// the context-bound backoff stops; that final error carries no new
	// information and is deliberately discarded.
	_ = backoff.RetryNotify(attempt, backoff.WithContext(expBackoff, ctx), onRetry)
}
// syncLatest replaces the in-memory collection with the current set of records
// of q.recordType, as streamed by the databroker SyncLatest call.
//
// While the stream is being consumed the querier is marked not ready. The
// server and record versions reported by the stream are committed, together
// with ready = true, only after the stream has ended cleanly. If the stream
// fails part-way, q.serverVersion is therefore left untouched, so the caller
// (run) performs a fresh SyncLatest on the next attempt instead of starting
// an incremental Sync on top of a partially filled collection.
//
// It returns an error if the stream cannot be started, if receiving fails, or
// if an unknown message type is received.
func (q *syncQuerier) syncLatest(ctx context.Context) error {
	stream, err := q.client.SyncLatest(ctx, &databroker.SyncLatestRequest{
		Type: q.recordType,
	})
	if err != nil {
		return fmt.Errorf("error starting sync latest stream: %w", err)
	}

	q.mu.Lock()
	q.ready = false
	q.records.Clear()
	// Start from the current values so that a stream which never sends a
	// Versions message leaves them unchanged.
	serverVersion, latestRecordVersion := q.serverVersion, q.latestRecordVersion
	q.mu.Unlock()

	for {
		res, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return fmt.Errorf("error receiving sync latest message: %w", err)
		}

		switch res := res.Response.(type) {
		case *databroker.SyncLatestResponse_Record:
			q.mu.Lock()
			q.records.Put(res.Record)
			q.mu.Unlock()
		case *databroker.SyncLatestResponse_Versions:
			// Held locally until the stream completes; see the commit below.
			serverVersion = res.Versions.ServerVersion
			latestRecordVersion = res.Versions.LatestRecordVersion
		default:
			return fmt.Errorf("unknown message type from sync latest: %T", res)
		}
	}

	// Publish the versions and the ready flag in one critical section so a
	// concurrent Query sees a consistent snapshot.
	q.mu.Lock()
	q.serverVersion = serverVersion
	q.latestRecordVersion = latestRecordVersion
	q.ready = true
	q.mu.Unlock()

	return nil
}
// sync follows the databroker Sync stream for q.recordType, starting from the
// currently known server and record versions, and applies every received
// record to the in-memory collection.
//
// It never returns nil: it runs until the stream fails. When the stream is
// aborted (codes.Aborted), which indicates the server version changed, the
// known versions are reset to zero so the caller performs a full SyncLatest
// on the next attempt.
//
// q.latestRecordVersion is read by Query under q.mu, so every write to it
// here is made while holding q.mu. The version and the record are updated in
// the same critical section so that a query whose minimum-version hint is
// satisfied is guaranteed to also see the corresponding record.
func (q *syncQuerier) sync(ctx context.Context) error {
	stream, err := q.client.Sync(ctx, &databroker.SyncRequest{
		ServerVersion: q.serverVersion,
		RecordVersion: q.latestRecordVersion,
		Type:          q.recordType,
	})
	if err != nil {
		return fmt.Errorf("error starting sync stream: %w", err)
	}

	for {
		res, err := stream.Recv()
		if status.Code(err) == codes.Aborted {
			// this indicates the server version changed, so we need to reset
			q.mu.Lock()
			q.serverVersion = 0
			q.latestRecordVersion = 0
			q.mu.Unlock()
			return fmt.Errorf("stream was aborted due to mismatched server versions: %w", err)
		}
		if err != nil {
			return fmt.Errorf("error receiving sync message: %w", err)
		}

		q.mu.Lock()
		q.latestRecordVersion = max(q.latestRecordVersion, res.Record.Version)
		q.records.Put(res.Record)
		q.mu.Unlock()
	}
}

View file

@ -0,0 +1,89 @@
package storage_test
import (
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace/noop"
grpc "google.golang.org/grpc"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/internal/testutil"
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage"
)
// TestSyncQuerier checks that a sync querier serves records that existed
// before it was created and picks up records written afterwards.
func TestSyncQuerier(t *testing.T) {
	t.Parallel()

	ctx := testutil.GetContext(t, 10*time.Minute)
	cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
		databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider()))
	})
	t.Cleanup(func() { cc.Close() })

	client := databrokerpb.NewDataBrokerServiceClient(cc)

	// The two records carry different payloads so that the data comparison
	// below can tell them apart.
	r1 := &databrokerpb.Record{
		Type: "t1",
		Id:   "r1",
		Data: protoutil.ToAny("q1"),
	}
	_, err := client.Put(ctx, &databrokerpb.PutRequest{
		Records: []*databrokerpb.Record{r1},
	})
	require.NoError(t, err)

	r2 := &databrokerpb.Record{
		Type: "t1",
		Id:   "r2",
		Data: protoutil.ToAny("q2"),
	}

	// newStruct uses require (FailNow), which must run on the test goroutine;
	// the EventuallyWithT condition runs on a separate goroutine, so the
	// filters are built here rather than inside the conditions.
	filterR1 := newStruct(t, map[string]any{
		"id": "r1",
	})
	filterR2 := newStruct(t, map[string]any{
		"id": "r2",
	})

	q := storage.NewSyncQuerier(client, "t1")
	t.Cleanup(q.Stop)

	assert.EventuallyWithT(t, func(c *assert.CollectT) {
		res, err := q.Query(ctx, &databrokerpb.QueryRequest{
			Type:   "t1",
			Filter: filterR1,
			Limit:  1,
		})
		if assert.NoError(c, err) && assert.Len(c, res.Records, 1) {
			assert.Empty(c, cmp.Diff(r1.Data, res.Records[0].Data, protocmp.Transform()))
		}
	}, time.Second*10, time.Millisecond*50, "should sync records")

	_, err = client.Put(ctx, &databrokerpb.PutRequest{
		Records: []*databrokerpb.Record{r2},
	})
	require.NoError(t, err)

	assert.EventuallyWithT(t, func(c *assert.CollectT) {
		res, err := q.Query(ctx, &databrokerpb.QueryRequest{
			Type:   "t1",
			Filter: filterR2,
			Limit:  1,
		})
		if assert.NoError(c, err) && assert.Len(c, res.Records, 1) {
			assert.Empty(c, cmp.Diff(r2.Data, res.Records[0].Data, protocmp.Transform()))
		}
	}, time.Second*10, time.Millisecond*50, "should pick up changes")
}
// newStruct converts m into a structpb.Struct and fails the test immediately
// if the conversion returns an error.
func newStruct(t *testing.T, m map[string]any) *structpb.Struct {
	t.Helper()

	converted, convErr := structpb.NewStruct(m)
	require.NoError(t, convErr)

	return converted
}