storage: invalidate sync querier when records are updated (#5612)

## Summary
Invalidate the sync querier when records are updated so that we fallback
to databroker querying until the sync is complete.

## Related issues
For
[ENG-2377](https://linear.app/pomerium/issue/ENG-2377/core-initial-access-with-idp-accessidentity-tokens-sometimes-fails)


## Checklist

- [x] reference any related issues
- [x] updated unit tests
- [x] add appropriate label (`enhancement`, `bug`, `breaking`,
`dependencies`, `ci`)
- [x] ready for review
This commit is contained in:
Caleb Doxsey 2025-05-12 13:45:36 -06:00 committed by GitHub
parent f6b344fd9e
commit ba0fcffe81
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 68 additions and 13 deletions

View file

@ -140,13 +140,13 @@ func (a *Authorize) maybeGetSessionFromRequest(
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0) return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
}, },
func(ctx context.Context, records []*databroker.Record) error { func(ctx context.Context, records []*databroker.Record) error {
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{ res, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
Records: records, Records: records,
}) })
if err != nil { if err != nil {
return err return err
} }
storage.InvalidateCacheForDataBrokerRecords(ctx, records...) storage.InvalidateCacheForDataBrokerRecords(ctx, res.Records...)
return nil return nil
}, },
).CreateSession(ctx, a.currentConfig.Load(), policy, hreq) ).CreateSession(ctx, a.currentConfig.Load(), policy, hreq)

View file

@ -155,8 +155,9 @@ func InvalidateCacheForDataBrokerRecords(
) { ) {
for _, record := range records { for _, record := range records {
q := &databroker.QueryRequest{ q := &databroker.QueryRequest{
Type: record.GetType(), Type: record.GetType(),
Limit: 1, Limit: 1,
MinimumRecordVersionHint: proto.Uint64(record.GetVersion()),
} }
q.SetFilterByIDOrIndex(record.GetId()) q.SetFilterByIDOrIndex(record.GetId())
GetQuerier(ctx).InvalidateCache(ctx, q) GetQuerier(ctx).InvalidateCache(ctx, q)

View file

@ -21,13 +21,14 @@ type syncQuerier struct {
client databroker.DataBrokerServiceClient client databroker.DataBrokerServiceClient
recordType string recordType string
cancel context.CancelFunc cancel context.CancelFunc
serverVersion uint64
latestRecordVersion uint64
mu sync.RWMutex mu sync.RWMutex
ready bool ready bool
records RecordCollection records RecordCollection
serverVersion uint64
minimumRecordVersion uint64
latestRecordVersion uint64
} }
// NewSyncQuerier creates a new Querier backed by an in-memory record collection // NewSyncQuerier creates a new Querier backed by an in-memory record collection
@ -49,8 +50,15 @@ func NewSyncQuerier(
return q return q
} }
func (q *syncQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRequest) { func (q *syncQuerier) InvalidateCache(_ context.Context, req *databroker.QueryRequest) {
// do nothing v := req.MinimumRecordVersionHint
if v == nil {
return
}
q.mu.Lock()
q.minimumRecordVersion = max(q.minimumRecordVersion, *v)
q.mu.Unlock()
} }
func (q *syncQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) { func (q *syncQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
@ -76,6 +84,11 @@ func (q *syncQuerier) canHandleQueryLocked(req *databroker.QueryRequest) bool {
if req.GetType() != q.recordType { if req.GetType() != q.recordType {
return false return false
} }
// if the latest record version hasn't reached the minimum version our sync is out-of-date
// so we can't handle queries
if q.latestRecordVersion < q.minimumRecordVersion {
return false
}
if req.MinimumRecordVersionHint != nil && q.latestRecordVersion < *req.MinimumRecordVersionHint { if req.MinimumRecordVersionHint != nil && q.latestRecordVersion < *req.MinimumRecordVersionHint {
return false return false
} }
@ -170,6 +183,7 @@ func (q *syncQuerier) sync(ctx context.Context) error {
q.mu.Lock() q.mu.Lock()
q.serverVersion = 0 q.serverVersion = 0
q.latestRecordVersion = 0 q.latestRecordVersion = 0
q.minimumRecordVersion = 0
q.mu.Unlock() q.mu.Unlock()
return fmt.Errorf("stream was aborted due to mismatched server versions: %w", err) return fmt.Errorf("stream was aborted due to mismatched server versions: %w", err)
} else if err != nil { } else if err != nil {

View file

@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace/noop" "go.opentelemetry.io/otel/trace/noop"
grpc "google.golang.org/grpc" grpc "google.golang.org/grpc"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/structpb"
@ -46,6 +47,12 @@ func TestSyncQuerier(t *testing.T) {
Data: protoutil.ToAny("q2"), Data: protoutil.ToAny("q2"),
} }
r2a := &databrokerpb.Record{
Type: "t1",
Id: "r2",
Data: protoutil.ToAny("q2a"),
}
q := storage.NewSyncQuerier(client, "t1") q := storage.NewSyncQuerier(client, "t1")
t.Cleanup(q.Stop) t.Cleanup(q.Stop)
@ -62,7 +69,7 @@ func TestSyncQuerier(t *testing.T) {
} }
}, time.Second*10, time.Millisecond*50, "should sync records") }, time.Second*10, time.Millisecond*50, "should sync records")
_, err = client.Put(ctx, &databrokerpb.PutRequest{ res, err := client.Put(ctx, &databrokerpb.PutRequest{
Records: []*databrokerpb.Record{r2}, Records: []*databrokerpb.Record{r2},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -79,6 +86,39 @@ func TestSyncQuerier(t *testing.T) {
assert.Empty(c, cmp.Diff(r2.Data, res.Records[0].Data, protocmp.Transform())) assert.Empty(c, cmp.Diff(r2.Data, res.Records[0].Data, protocmp.Transform()))
} }
}, time.Second*10, time.Millisecond*50, "should pick up changes") }, time.Second*10, time.Millisecond*50, "should pick up changes")
q.InvalidateCache(ctx, &databrokerpb.QueryRequest{
Type: "t1",
MinimumRecordVersionHint: proto.Uint64(res.GetRecord().GetVersion() + 1),
})
_, err = q.Query(ctx, &databrokerpb.QueryRequest{
Type: "t1",
Filter: newStruct(t, map[string]any{
"id": "r2",
}),
Limit: 1,
})
assert.ErrorIs(t, err, storage.ErrUnavailable,
"should return unavailable until record is updated")
res, err = client.Put(ctx, &databrokerpb.PutRequest{
Records: []*databrokerpb.Record{r2a},
})
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
res, err := q.Query(ctx, &databrokerpb.QueryRequest{
Type: "t1",
Filter: newStruct(t, map[string]any{
"id": "r2",
}),
Limit: 1,
})
if assert.NoError(c, err) && assert.Len(c, res.Records, 1) {
assert.Empty(c, cmp.Diff(r2a.Data, res.Records[0].Data, protocmp.Transform()))
}
}, time.Second*10, time.Millisecond*50, "should pick up changes after invalidation")
} }
func newStruct(t *testing.T, m map[string]any) *structpb.Struct { func newStruct(t *testing.T, m map[string]any) *structpb.Struct {