storage: invalidate sync querier when records are updated (#5612)

## Summary
Invalidate the sync querier when records are updated so that we fallback
to databroker querying until the sync is complete.

## Related issues
For
[ENG-2377](https://linear.app/pomerium/issue/ENG-2377/core-initial-access-with-idp-accessidentity-tokens-sometimes-fails)


## Checklist

- [x] reference any related issues
- [x] updated unit tests
- [x] add appropriate label (`enhancement`, `bug`, `breaking`,
`dependencies`, `ci`)
- [x] ready for review
This commit is contained in:
Caleb Doxsey 2025-05-12 13:45:36 -06:00 committed by GitHub
parent f6b344fd9e
commit ba0fcffe81
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 68 additions and 13 deletions

View file

@ -140,13 +140,13 @@ func (a *Authorize) maybeGetSessionFromRequest(
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0) return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
}, },
func(ctx context.Context, records []*databroker.Record) error { func(ctx context.Context, records []*databroker.Record) error {
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{ res, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
Records: records, Records: records,
}) })
if err != nil { if err != nil {
return err return err
} }
storage.InvalidateCacheForDataBrokerRecords(ctx, records...) storage.InvalidateCacheForDataBrokerRecords(ctx, res.Records...)
return nil return nil
}, },
).CreateSession(ctx, a.currentConfig.Load(), policy, hreq) ).CreateSession(ctx, a.currentConfig.Load(), policy, hreq)

View file

@ -157,6 +157,7 @@ func InvalidateCacheForDataBrokerRecords(
q := &databroker.QueryRequest{ q := &databroker.QueryRequest{
Type: record.GetType(), Type: record.GetType(),
Limit: 1, Limit: 1,
MinimumRecordVersionHint: proto.Uint64(record.GetVersion()),
} }
q.SetFilterByIDOrIndex(record.GetId()) q.SetFilterByIDOrIndex(record.GetId())
GetQuerier(ctx).InvalidateCache(ctx, q) GetQuerier(ctx).InvalidateCache(ctx, q)

View file

@ -22,12 +22,13 @@ type syncQuerier struct {
recordType string recordType string
cancel context.CancelFunc cancel context.CancelFunc
serverVersion uint64
latestRecordVersion uint64
mu sync.RWMutex mu sync.RWMutex
ready bool ready bool
records RecordCollection records RecordCollection
serverVersion uint64
minimumRecordVersion uint64
latestRecordVersion uint64
} }
// NewSyncQuerier creates a new Querier backed by an in-memory record collection // NewSyncQuerier creates a new Querier backed by an in-memory record collection
@ -49,8 +50,15 @@ func NewSyncQuerier(
return q return q
} }
func (q *syncQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRequest) { func (q *syncQuerier) InvalidateCache(_ context.Context, req *databroker.QueryRequest) {
// do nothing v := req.MinimumRecordVersionHint
if v == nil {
return
}
q.mu.Lock()
q.minimumRecordVersion = max(q.minimumRecordVersion, *v)
q.mu.Unlock()
} }
func (q *syncQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) { func (q *syncQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
@ -76,6 +84,11 @@ func (q *syncQuerier) canHandleQueryLocked(req *databroker.QueryRequest) bool {
if req.GetType() != q.recordType { if req.GetType() != q.recordType {
return false return false
} }
// if the latest record version hasn't reached the minimum version our sync is out-of-date
// so we can't handle queries
if q.latestRecordVersion < q.minimumRecordVersion {
return false
}
if req.MinimumRecordVersionHint != nil && q.latestRecordVersion < *req.MinimumRecordVersionHint { if req.MinimumRecordVersionHint != nil && q.latestRecordVersion < *req.MinimumRecordVersionHint {
return false return false
} }
@ -170,6 +183,7 @@ func (q *syncQuerier) sync(ctx context.Context) error {
q.mu.Lock() q.mu.Lock()
q.serverVersion = 0 q.serverVersion = 0
q.latestRecordVersion = 0 q.latestRecordVersion = 0
q.minimumRecordVersion = 0
q.mu.Unlock() q.mu.Unlock()
return fmt.Errorf("stream was aborted due to mismatched server versions: %w", err) return fmt.Errorf("stream was aborted due to mismatched server versions: %w", err)
} else if err != nil { } else if err != nil {

View file

@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace/noop" "go.opentelemetry.io/otel/trace/noop"
grpc "google.golang.org/grpc" grpc "google.golang.org/grpc"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/structpb"
@ -46,6 +47,12 @@ func TestSyncQuerier(t *testing.T) {
Data: protoutil.ToAny("q2"), Data: protoutil.ToAny("q2"),
} }
r2a := &databrokerpb.Record{
Type: "t1",
Id: "r2",
Data: protoutil.ToAny("q2a"),
}
q := storage.NewSyncQuerier(client, "t1") q := storage.NewSyncQuerier(client, "t1")
t.Cleanup(q.Stop) t.Cleanup(q.Stop)
@ -62,7 +69,7 @@ func TestSyncQuerier(t *testing.T) {
} }
}, time.Second*10, time.Millisecond*50, "should sync records") }, time.Second*10, time.Millisecond*50, "should sync records")
_, err = client.Put(ctx, &databrokerpb.PutRequest{ res, err := client.Put(ctx, &databrokerpb.PutRequest{
Records: []*databrokerpb.Record{r2}, Records: []*databrokerpb.Record{r2},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -79,6 +86,39 @@ func TestSyncQuerier(t *testing.T) {
assert.Empty(c, cmp.Diff(r2.Data, res.Records[0].Data, protocmp.Transform())) assert.Empty(c, cmp.Diff(r2.Data, res.Records[0].Data, protocmp.Transform()))
} }
}, time.Second*10, time.Millisecond*50, "should pick up changes") }, time.Second*10, time.Millisecond*50, "should pick up changes")
q.InvalidateCache(ctx, &databrokerpb.QueryRequest{
Type: "t1",
MinimumRecordVersionHint: proto.Uint64(res.GetRecord().GetVersion() + 1),
})
_, err = q.Query(ctx, &databrokerpb.QueryRequest{
Type: "t1",
Filter: newStruct(t, map[string]any{
"id": "r2",
}),
Limit: 1,
})
assert.ErrorIs(t, err, storage.ErrUnavailable,
"should return unavailable until record is updated")
res, err = client.Put(ctx, &databrokerpb.PutRequest{
Records: []*databrokerpb.Record{r2a},
})
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
res, err := q.Query(ctx, &databrokerpb.QueryRequest{
Type: "t1",
Filter: newStruct(t, map[string]any{
"id": "r2",
}),
Limit: 1,
})
if assert.NoError(c, err) && assert.Len(c, res.Records, 1) {
assert.Empty(c, cmp.Diff(r2a.Data, res.Records[0].Data, protocmp.Transform()))
}
}, time.Second*10, time.Millisecond*50, "should pick up changes after invalidation")
} }
func newStruct(t *testing.T, m map[string]any) *structpb.Struct { func newStruct(t *testing.T, m map[string]any) *structpb.Struct {