diff --git a/pkg/storage/querier.go b/pkg/storage/querier.go index 9de1147c2..18c0eca92 100644 --- a/pkg/storage/querier.go +++ b/pkg/storage/querier.go @@ -23,6 +23,7 @@ import ( type Querier interface { InvalidateCache(ctx context.Context, in *databroker.QueryRequest) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) + Stop() } // nilQuerier always returns NotFound. @@ -34,6 +35,8 @@ func (nilQuerier) Query(_ context.Context, _ *databroker.QueryRequest, _ ...grpc return nil, status.Error(codes.NotFound, "not found") } +func (nilQuerier) Stop() {} + type querierKey struct{} // GetQuerier gets the databroker Querier from the context. @@ -117,6 +120,8 @@ func (q *staticQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ return QueryRecordCollections(q.records, req) } +func (q *staticQuerier) Stop() {} + type clientQuerier struct { client databroker.DataBrokerServiceClient } @@ -133,6 +138,8 @@ func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest, return q.client.Query(ctx, in, opts...) } +func (q *clientQuerier) Stop() {} + type cachingQuerier struct { q Querier cache Cache @@ -182,6 +189,10 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, return &res, nil } +func (q *cachingQuerier) Stop() { + q.cache.InvalidateAll() +} + // MarshalQueryRequest marshales the query request. func MarshalQueryRequest(req *databroker.QueryRequest) ([]byte, error) { return (&proto.MarshalOptions{ diff --git a/pkg/storage/querier_sync.go b/pkg/storage/querier_sync.go new file mode 100644 index 000000000..2fa33b2e3 --- /dev/null +++ b/pkg/storage/querier_sync.go @@ -0,0 +1,176 @@ +package storage + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + "time" + + "github.com/cenkalti/backoff/v4" + grpc "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +type syncQuerier struct { + client databroker.DataBrokerServiceClient + recordType string + fallback Querier + + cancel context.CancelFunc + serverVersion uint64 + latestRecordVersion uint64 + + mu sync.RWMutex + ready bool + records RecordCollection +} + +// NewSyncQuerier creates a new Querier backed by an in-memory record collection +// filled via sync calls to the databroker. +func NewSyncQuerier( + client databroker.DataBrokerServiceClient, + recordType string, + fallback Querier, +) Querier { + q := &syncQuerier{ + client: client, + recordType: recordType, + fallback: fallback, + records: NewRecordCollection(), + } + + ctx, cancel := context.WithCancel(context.Background()) + q.cancel = cancel + go q.run(ctx) + + return q +} + +func (q *syncQuerier) InvalidateCache( + ctx context.Context, + req *databroker.QueryRequest, +) { + q.mu.RLock() + ready := q.ready + q.mu.RUnlock() + + // only invalidate the fallback querier if we aren't ready yet + if ready { + q.fallback.InvalidateCache(ctx, req) + } +} + +func (q *syncQuerier) Query( + ctx context.Context, + req *databroker.QueryRequest, + opts ...grpc.CallOption, +) (*databroker.QueryResponse, error) { + q.mu.RLock() + if !q.ready || req.GetType() != q.recordType { + q.mu.RUnlock() + return q.fallback.Query(ctx, req, opts...) + } + defer q.mu.RUnlock() + return QueryRecordCollections(map[string]RecordCollection{ + q.recordType: q.records, + }, req) +} + +func (q *syncQuerier) Stop() { + q.cancel() +} + +func (q *syncQuerier) run(ctx context.Context) { + bo := backoff.WithContext(backoff.NewExponentialBackOff(backoff.WithMaxElapsedTime(0)), ctx) + _ = backoff.RetryNotify(func() error { + if q.serverVersion == 0 { + err := q.syncLatest(ctx) + if err != nil { + return err + } + } + + return q.sync(ctx) + }, bo, func(err error, d time.Duration) { + log.Ctx(ctx).Error(). + Err(err). + Dur("delay", d). + Msg("storage/sync-querier: error syncing records") + }) +} + +func (q *syncQuerier) syncLatest(ctx context.Context) error { + stream, err := q.client.SyncLatest(ctx, &databroker.SyncLatestRequest{ + Type: q.recordType, + }) + if err != nil { + return fmt.Errorf("error starting sync latest stream: %w", err) + } + + q.mu.Lock() + q.ready = false + q.records.Clear() + q.mu.Unlock() + + for { + res, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return fmt.Errorf("error receiving sync latest message: %w", err) + } + + switch res := res.Response.(type) { + case *databroker.SyncLatestResponse_Record: + q.mu.Lock() + q.records.Put(res.Record) + q.mu.Unlock() + case *databroker.SyncLatestResponse_Versions: + q.serverVersion = res.Versions.ServerVersion + q.latestRecordVersion = res.Versions.LatestRecordVersion + default: + return fmt.Errorf("unknown message type from sync latest: %T", res) + } + } + + q.mu.Lock() + q.ready = true + q.mu.Unlock() + + return nil +} + +func (q *syncQuerier) sync(ctx context.Context) error { + stream, err := q.client.Sync(ctx, &databroker.SyncRequest{ + ServerVersion: q.serverVersion, + RecordVersion: q.latestRecordVersion, + Type: q.recordType, + }) + if err != nil { + return fmt.Errorf("error starting sync stream: %w", err) + } + + for { + res, err := stream.Recv() + if status.Code(err) == codes.Aborted { + // this indicates the server version changed, so we need to reset + q.serverVersion = 0 + q.latestRecordVersion = 0 + return fmt.Errorf("stream was aborted due to mismatched server versions: %w", err) + } else if err != nil { + return fmt.Errorf("error receiving sync message: %w", err) + } + + q.latestRecordVersion = max(q.latestRecordVersion, res.Record.Version) + + q.mu.Lock() + q.records.Put(res.Record) + q.mu.Unlock() + } +} diff --git a/pkg/storage/querier_sync_test.go b/pkg/storage/querier_sync_test.go new file mode 100644 index 000000000..b01824e3c --- /dev/null +++ b/pkg/storage/querier_sync_test.go @@ -0,0 +1,105 @@ +package storage_test + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" + grpc "google.golang.org/grpc" + "google.golang.org/protobuf/testing/protocmp" + + "github.com/pomerium/pomerium/internal/databroker" + "github.com/pomerium/pomerium/internal/testutil" + databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/protoutil" + "github.com/pomerium/pomerium/pkg/storage" +) + +func TestSyncQuerier(t *testing.T) { + t.Parallel() + + ctx := testutil.GetContext(t, 10*time.Minute) + cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) { + databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider())) + }) + t.Cleanup(func() { cc.Close() }) + + client := databrokerpb.NewDataBrokerServiceClient(cc) + + q1r1 := &databrokerpb.Record{ + Type: "t1", + Id: "r1", + Data: protoutil.ToAny("q1"), + } + q1r2 := &databrokerpb.Record{ + Type: "t2", + Id: "r2", + Data: protoutil.ToAny("q2"), + } + q1 := storage.NewStaticQuerier(q1r1, q1r2) + + q2r1 := &databrokerpb.Record{ + Type: "t1", + Id: "r1", + Data: protoutil.ToAny("q2"), + } + _, err := client.Put(ctx, &databrokerpb.PutRequest{ + Records: []*databrokerpb.Record{q2r1}, + }) + require.NoError(t, err) + + q2r2 := &databrokerpb.Record{ + Type: "t1", + Id: "r2", + Data: protoutil.ToAny("q2"), + } + + q2 := storage.NewSyncQuerier(client, "t1", q1) + t.Cleanup(q2.Stop) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + res, err := q2.Query(ctx, &databrokerpb.QueryRequest{ + Type: "t1", + Filter: newStruct(t, map[string]any{ + "id": "r1", + }), + Limit: 1, + }) + if assert.NoError(c, err) && assert.Len(c, res.Records, 1) { + assert.Empty(c, cmp.Diff(q2r1.Data, res.Records[0].Data, protocmp.Transform())) + } + }, time.Second*10, time.Millisecond*50, "should sync records") + + res, err := q2.Query(ctx, &databrokerpb.QueryRequest{ + Type: "t2", + Filter: newStruct(t, map[string]any{ + "id": "r2", + }), + Limit: 1, + }) + if assert.NoError(t, err) && assert.Len(t, res.Records, 1) { + assert.Empty(t, cmp.Diff(q1r2.Data, res.Records[0].Data, protocmp.Transform()), + "should use fallback querier for other record types") + } + + _, err = client.Put(ctx, &databrokerpb.PutRequest{ + Records: []*databrokerpb.Record{q2r2}, + }) + require.NoError(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + res, err := q2.Query(ctx, &databrokerpb.QueryRequest{ + Type: "t1", + Filter: newStruct(t, map[string]any{ + "id": "r2", + }), + Limit: 1, + }) + if assert.NoError(c, err) && assert.Len(c, res.Records, 1) { + assert.Empty(c, cmp.Diff(q2r2.Data, res.Records[0].Data, protocmp.Transform())) + } + }, time.Second*10, time.Millisecond*50, "should pick up changes") +} diff --git a/pkg/storage/querier_typed.go b/pkg/storage/querier_typed.go new file mode 100644 index 000000000..dee55882a --- /dev/null +++ b/pkg/storage/querier_typed.go @@ -0,0 +1,45 @@ +package storage + +import ( + "context" + + grpc "google.golang.org/grpc" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +type typedQuerier struct { + defaultQuerier Querier + queriersByType map[string]Querier +} + +// NewTypedQuerier creates a new Querier that dispatches to other queries based on the type. +func NewTypedQuerier(defaultQuerier Querier, queriersByType map[string]Querier) Querier { + return &typedQuerier{ + defaultQuerier: defaultQuerier, + queriersByType: queriersByType, + } +} + +func (q *typedQuerier) InvalidateCache(ctx context.Context, req *databroker.QueryRequest) { + qq, ok := q.queriersByType[req.Type] + if !ok { + qq = q.defaultQuerier + } + qq.InvalidateCache(ctx, req) +} + +func (q *typedQuerier) Query(ctx context.Context, req *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { + qq, ok := q.queriersByType[req.Type] + if !ok { + qq = q.defaultQuerier + } + return qq.Query(ctx, req, opts...) +} + +func (q *typedQuerier) Stop() { + q.defaultQuerier.Stop() + for _, qq := range q.queriersByType { + qq.Stop() + } +} diff --git a/pkg/storage/record_collection.go b/pkg/storage/record_collection.go index 8ea1cd12f..a63fd2a99 100644 --- a/pkg/storage/record_collection.go +++ b/pkg/storage/record_collection.go @@ -20,6 +20,8 @@ import ( type RecordCollection interface { // All returns all of the databroker records as a slice. The slice is in insertion order. All() []*databroker.Record + // Clear removes all the records from the collection. + Clear() // Get returns a record based on the record id. Get(recordID string) (*databroker.Record, bool) // Len returns the number of records stored in the collection. @@ -65,6 +67,12 @@ func (c *recordCollection) All() []*databroker.Record { return l } +func (c *recordCollection) Clear() { + c.cidrIndex = bart.Table[[]string]{} + clear(c.records) + c.insertionOrder = list.New() +} + func (c *recordCollection) Get(recordID string) (*databroker.Record, bool) { node, ok := c.records[recordID] if !ok { diff --git a/pkg/storage/record_collection_test.go b/pkg/storage/record_collection_test.go index 0e3d6dba9..985194c3e 100644 --- a/pkg/storage/record_collection_test.go +++ b/pkg/storage/record_collection_test.go @@ -21,7 +21,7 @@ func TestRecordCollection(t *testing.T) { r1 := &databroker.Record{ Id: "r1", - Data: newStructData(t, map[string]any{ + Data: newStructAny(t, map[string]any{ "$index": map[string]any{ "cidr": "10.0.0.0/24", }, @@ -29,7 +29,7 @@ func TestRecordCollection(t *testing.T) { } r2 := &databroker.Record{ Id: "r2", - Data: newStructData(t, map[string]any{ + Data: newStructAny(t, map[string]any{ "$index": map[string]any{ "cidr": "192.168.0.0/24", }, @@ -37,7 +37,7 @@ func TestRecordCollection(t *testing.T) { } r3 := &databroker.Record{ Id: "r3", - Data: newStructData(t, map[string]any{ + Data: newStructAny(t, map[string]any{ "$index": map[string]any{ "cidr": "10.0.0.0/16", }, @@ -45,7 +45,7 @@ func TestRecordCollection(t *testing.T) { } r4 := &databroker.Record{ Id: "r4", - Data: newStructData(t, map[string]any{ + Data: newStructAny(t, map[string]any{ "$index": map[string]any{ "cidr": "10.0.0.0/24", }, @@ -123,11 +123,16 @@ func TestRecordCollection(t *testing.T) { assert.Empty(t, cmp.Diff([]*databroker.Record{r3}, rs, protocmp.Transform())) } -func newStructData(t *testing.T, m map[string]any) *anypb.Any { +func newStruct(t *testing.T, m map[string]any) *structpb.Struct { t.Helper() - s, err := structpb.NewStruct(m) require.NoError(t, err) + return s +} +func newStructAny(t *testing.T, m map[string]any) *anypb.Any { + t.Helper() + s, err := structpb.NewStruct(m) + require.NoError(t, err) return protoutil.NewAny(s) }