From ba0fcffe81f27f33dd5aa858dba3a1276208ddcd Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Mon, 12 May 2025 13:45:36 -0600 Subject: [PATCH] storage: invalidate sync querier when records are updated (#5612) ## Summary Invalidate the sync querier when records are updated so that we fallback to databroker querying until the sync is complete. ## Related issues For [ENG-2377](https://linear.app/pomerium/issue/ENG-2377/core-initial-access-with-idp-accessidentity-tokens-sometimes-fails) ## Checklist - [x] reference any related issues - [x] updated unit tests - [x] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [x] ready for review --- authorize/grpc.go | 4 +-- pkg/storage/querier.go | 5 ++-- pkg/storage/querier_sync.go | 30 +++++++++++++++++------ pkg/storage/querier_sync_test.go | 42 +++++++++++++++++++++++++++++++- 4 files changed, 68 insertions(+), 13 deletions(-) diff --git a/authorize/grpc.go b/authorize/grpc.go index 792370fa9..a3b3b4300 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -140,13 +140,13 @@ func (a *Authorize) maybeGetSessionFromRequest( return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0) }, func(ctx context.Context, records []*databroker.Record) error { - _, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{ + res, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{ Records: records, }) if err != nil { return err } - storage.InvalidateCacheForDataBrokerRecords(ctx, records...) + storage.InvalidateCacheForDataBrokerRecords(ctx, res.Records...) return nil }, ).CreateSession(ctx, a.currentConfig.Load(), policy, hreq) diff --git a/pkg/storage/querier.go b/pkg/storage/querier.go index e4944328f..aa331c0a1 100644 --- a/pkg/storage/querier.go +++ b/pkg/storage/querier.go @@ -155,8 +155,9 @@ func InvalidateCacheForDataBrokerRecords( ) { for _, record := range records { q := &databroker.QueryRequest{ - Type: record.GetType(), - Limit: 1, + Type: record.GetType(), + Limit: 1, + MinimumRecordVersionHint: proto.Uint64(record.GetVersion()), } q.SetFilterByIDOrIndex(record.GetId()) GetQuerier(ctx).InvalidateCache(ctx, q) diff --git a/pkg/storage/querier_sync.go b/pkg/storage/querier_sync.go index c5c7c98e6..d400404fc 100644 --- a/pkg/storage/querier_sync.go +++ b/pkg/storage/querier_sync.go @@ -21,13 +21,14 @@ type syncQuerier struct { client databroker.DataBrokerServiceClient recordType string - cancel context.CancelFunc - serverVersion uint64 - latestRecordVersion uint64 + cancel context.CancelFunc - mu sync.RWMutex - ready bool - records RecordCollection + mu sync.RWMutex + ready bool + records RecordCollection + serverVersion uint64 + minimumRecordVersion uint64 + latestRecordVersion uint64 } // NewSyncQuerier creates a new Querier backed by an in-memory record collection @@ -49,8 +50,15 @@ func NewSyncQuerier( return q } -func (q *syncQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRequest) { - // do nothing +func (q *syncQuerier) InvalidateCache(_ context.Context, req *databroker.QueryRequest) { + v := req.MinimumRecordVersionHint + if v == nil { + return + } + + q.mu.Lock() + q.minimumRecordVersion = max(q.minimumRecordVersion, *v) + q.mu.Unlock() } func (q *syncQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) { @@ -76,6 +84,11 @@ func (q *syncQuerier) canHandleQueryLocked(req *databroker.QueryRequest) bool { if req.GetType() != q.recordType { return false } + // if the latest record version hasn't reached the minimum version our sync is out-of-date + // so we can't handle queries + if q.latestRecordVersion < q.minimumRecordVersion { + return false + } if req.MinimumRecordVersionHint != nil && q.latestRecordVersion < *req.MinimumRecordVersionHint { return false } @@ -170,6 +183,7 @@ func (q *syncQuerier) sync(ctx context.Context) error { q.mu.Lock() q.serverVersion = 0 q.latestRecordVersion = 0 + q.minimumRecordVersion = 0 q.mu.Unlock() return fmt.Errorf("stream was aborted due to mismatched server versions: %w", err) } else if err != nil { diff --git a/pkg/storage/querier_sync_test.go b/pkg/storage/querier_sync_test.go index fc13f18f7..f98a65b9f 100644 --- a/pkg/storage/querier_sync_test.go +++ b/pkg/storage/querier_sync_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace/noop" grpc "google.golang.org/grpc" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/known/structpb" @@ -46,6 +47,12 @@ func TestSyncQuerier(t *testing.T) { Data: protoutil.ToAny("q2"), } + r2a := &databrokerpb.Record{ + Type: "t1", + Id: "r2", + Data: protoutil.ToAny("q2a"), + } + q := storage.NewSyncQuerier(client, "t1") t.Cleanup(q.Stop) @@ -62,7 +69,7 @@ func TestSyncQuerier(t *testing.T) { } }, time.Second*10, time.Millisecond*50, "should sync records") - _, err = client.Put(ctx, &databrokerpb.PutRequest{ + res, err := client.Put(ctx, &databrokerpb.PutRequest{ Records: []*databrokerpb.Record{r2}, }) require.NoError(t, err) @@ -79,6 +86,39 @@ func TestSyncQuerier(t *testing.T) { assert.Empty(c, cmp.Diff(r2.Data, res.Records[0].Data, protocmp.Transform())) } }, time.Second*10, time.Millisecond*50, "should pick up changes") + + q.InvalidateCache(ctx, &databrokerpb.QueryRequest{ + Type: "t1", + MinimumRecordVersionHint: proto.Uint64(res.GetRecord().GetVersion() + 1), + }) + + _, err = q.Query(ctx, &databrokerpb.QueryRequest{ + Type: "t1", + Filter: newStruct(t, map[string]any{ + "id": "r2", + }), + Limit: 1, + }) + assert.ErrorIs(t, err, storage.ErrUnavailable, + "should return unavailable until record is updated") + + res, err = client.Put(ctx, &databrokerpb.PutRequest{ + Records: []*databrokerpb.Record{r2a}, + }) + require.NoError(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + res, err := q.Query(ctx, &databrokerpb.QueryRequest{ + Type: "t1", + Filter: newStruct(t, map[string]any{ + "id": "r2", + }), + Limit: 1, + }) + if assert.NoError(c, err) && assert.Len(c, res.Records, 1) { + assert.Empty(c, cmp.Diff(r2a.Data, res.Records[0].Data, protocmp.Transform())) + } + }, time.Second*10, time.Millisecond*50, "should pick up changes after invalidation") } func newStruct(t *testing.T, m map[string]any) *structpb.Struct {