mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-08 13:52:53 +02:00
storage: invalidate sync querier when records are updated (#5612)
## Summary Invalidate the sync querier when records are updated so that we fallback to databroker querying until the sync is complete. ## Related issues For [ENG-2377](https://linear.app/pomerium/issue/ENG-2377/core-initial-access-with-idp-accessidentity-tokens-sometimes-fails) ## Checklist - [x] reference any related issues - [x] updated unit tests - [x] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [x] ready for review
This commit is contained in:
parent
f6b344fd9e
commit
ba0fcffe81
4 changed files with 68 additions and 13 deletions
|
@ -140,13 +140,13 @@ func (a *Authorize) maybeGetSessionFromRequest(
|
||||||
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
|
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||||
},
|
},
|
||||||
func(ctx context.Context, records []*databroker.Record) error {
|
func(ctx context.Context, records []*databroker.Record) error {
|
||||||
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
res, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||||
Records: records,
|
Records: records,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
storage.InvalidateCacheForDataBrokerRecords(ctx, records...)
|
storage.InvalidateCacheForDataBrokerRecords(ctx, res.Records...)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
).CreateSession(ctx, a.currentConfig.Load(), policy, hreq)
|
).CreateSession(ctx, a.currentConfig.Load(), policy, hreq)
|
||||||
|
|
|
@ -155,8 +155,9 @@ func InvalidateCacheForDataBrokerRecords(
|
||||||
) {
|
) {
|
||||||
for _, record := range records {
|
for _, record := range records {
|
||||||
q := &databroker.QueryRequest{
|
q := &databroker.QueryRequest{
|
||||||
Type: record.GetType(),
|
Type: record.GetType(),
|
||||||
Limit: 1,
|
Limit: 1,
|
||||||
|
MinimumRecordVersionHint: proto.Uint64(record.GetVersion()),
|
||||||
}
|
}
|
||||||
q.SetFilterByIDOrIndex(record.GetId())
|
q.SetFilterByIDOrIndex(record.GetId())
|
||||||
GetQuerier(ctx).InvalidateCache(ctx, q)
|
GetQuerier(ctx).InvalidateCache(ctx, q)
|
||||||
|
|
|
@ -21,13 +21,14 @@ type syncQuerier struct {
|
||||||
client databroker.DataBrokerServiceClient
|
client databroker.DataBrokerServiceClient
|
||||||
recordType string
|
recordType string
|
||||||
|
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
serverVersion uint64
|
|
||||||
latestRecordVersion uint64
|
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ready bool
|
ready bool
|
||||||
records RecordCollection
|
records RecordCollection
|
||||||
|
serverVersion uint64
|
||||||
|
minimumRecordVersion uint64
|
||||||
|
latestRecordVersion uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSyncQuerier creates a new Querier backed by an in-memory record collection
|
// NewSyncQuerier creates a new Querier backed by an in-memory record collection
|
||||||
|
@ -49,8 +50,15 @@ func NewSyncQuerier(
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *syncQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRequest) {
|
func (q *syncQuerier) InvalidateCache(_ context.Context, req *databroker.QueryRequest) {
|
||||||
// do nothing
|
v := req.MinimumRecordVersionHint
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
q.mu.Lock()
|
||||||
|
q.minimumRecordVersion = max(q.minimumRecordVersion, *v)
|
||||||
|
q.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *syncQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
func (q *syncQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||||
|
@ -76,6 +84,11 @@ func (q *syncQuerier) canHandleQueryLocked(req *databroker.QueryRequest) bool {
|
||||||
if req.GetType() != q.recordType {
|
if req.GetType() != q.recordType {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
// if the latest record version hasn't reached the minimum version our sync is out-of-date
|
||||||
|
// so we can't handle queries
|
||||||
|
if q.latestRecordVersion < q.minimumRecordVersion {
|
||||||
|
return false
|
||||||
|
}
|
||||||
if req.MinimumRecordVersionHint != nil && q.latestRecordVersion < *req.MinimumRecordVersionHint {
|
if req.MinimumRecordVersionHint != nil && q.latestRecordVersion < *req.MinimumRecordVersionHint {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -170,6 +183,7 @@ func (q *syncQuerier) sync(ctx context.Context) error {
|
||||||
q.mu.Lock()
|
q.mu.Lock()
|
||||||
q.serverVersion = 0
|
q.serverVersion = 0
|
||||||
q.latestRecordVersion = 0
|
q.latestRecordVersion = 0
|
||||||
|
q.minimumRecordVersion = 0
|
||||||
q.mu.Unlock()
|
q.mu.Unlock()
|
||||||
return fmt.Errorf("stream was aborted due to mismatched server versions: %w", err)
|
return fmt.Errorf("stream was aborted due to mismatched server versions: %w", err)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.opentelemetry.io/otel/trace/noop"
|
"go.opentelemetry.io/otel/trace/noop"
|
||||||
grpc "google.golang.org/grpc"
|
grpc "google.golang.org/grpc"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/testing/protocmp"
|
"google.golang.org/protobuf/testing/protocmp"
|
||||||
"google.golang.org/protobuf/types/known/structpb"
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
|
||||||
|
@ -46,6 +47,12 @@ func TestSyncQuerier(t *testing.T) {
|
||||||
Data: protoutil.ToAny("q2"),
|
Data: protoutil.ToAny("q2"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r2a := &databrokerpb.Record{
|
||||||
|
Type: "t1",
|
||||||
|
Id: "r2",
|
||||||
|
Data: protoutil.ToAny("q2a"),
|
||||||
|
}
|
||||||
|
|
||||||
q := storage.NewSyncQuerier(client, "t1")
|
q := storage.NewSyncQuerier(client, "t1")
|
||||||
t.Cleanup(q.Stop)
|
t.Cleanup(q.Stop)
|
||||||
|
|
||||||
|
@ -62,7 +69,7 @@ func TestSyncQuerier(t *testing.T) {
|
||||||
}
|
}
|
||||||
}, time.Second*10, time.Millisecond*50, "should sync records")
|
}, time.Second*10, time.Millisecond*50, "should sync records")
|
||||||
|
|
||||||
_, err = client.Put(ctx, &databrokerpb.PutRequest{
|
res, err := client.Put(ctx, &databrokerpb.PutRequest{
|
||||||
Records: []*databrokerpb.Record{r2},
|
Records: []*databrokerpb.Record{r2},
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -79,6 +86,39 @@ func TestSyncQuerier(t *testing.T) {
|
||||||
assert.Empty(c, cmp.Diff(r2.Data, res.Records[0].Data, protocmp.Transform()))
|
assert.Empty(c, cmp.Diff(r2.Data, res.Records[0].Data, protocmp.Transform()))
|
||||||
}
|
}
|
||||||
}, time.Second*10, time.Millisecond*50, "should pick up changes")
|
}, time.Second*10, time.Millisecond*50, "should pick up changes")
|
||||||
|
|
||||||
|
q.InvalidateCache(ctx, &databrokerpb.QueryRequest{
|
||||||
|
Type: "t1",
|
||||||
|
MinimumRecordVersionHint: proto.Uint64(res.GetRecord().GetVersion() + 1),
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err = q.Query(ctx, &databrokerpb.QueryRequest{
|
||||||
|
Type: "t1",
|
||||||
|
Filter: newStruct(t, map[string]any{
|
||||||
|
"id": "r2",
|
||||||
|
}),
|
||||||
|
Limit: 1,
|
||||||
|
})
|
||||||
|
assert.ErrorIs(t, err, storage.ErrUnavailable,
|
||||||
|
"should return unavailable until record is updated")
|
||||||
|
|
||||||
|
res, err = client.Put(ctx, &databrokerpb.PutRequest{
|
||||||
|
Records: []*databrokerpb.Record{r2a},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||||
|
res, err := q.Query(ctx, &databrokerpb.QueryRequest{
|
||||||
|
Type: "t1",
|
||||||
|
Filter: newStruct(t, map[string]any{
|
||||||
|
"id": "r2",
|
||||||
|
}),
|
||||||
|
Limit: 1,
|
||||||
|
})
|
||||||
|
if assert.NoError(c, err) && assert.Len(c, res.Records, 1) {
|
||||||
|
assert.Empty(c, cmp.Diff(r2a.Data, res.Records[0].Data, protocmp.Transform()))
|
||||||
|
}
|
||||||
|
}, time.Second*10, time.Millisecond*50, "should pick up changes after invalidation")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newStruct(t *testing.T, m map[string]any) *structpb.Struct {
|
func newStruct(t *testing.T, m map[string]any) *structpb.Struct {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue