From fe61a74e1b7cb049d94a13a70297e7fa83742afa Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Fri, 15 Jul 2022 17:27:06 -0600 Subject: [PATCH] authorize: fix device synchronization (#3482) --- authorize/databroker.go | 61 +++++++++++++++++++++++++++++++++++++---- authorize/grpc.go | 2 +- pkg/storage/querier.go | 29 ++++++++++++++++++++ 3 files changed, 86 insertions(+), 6 deletions(-) diff --git a/authorize/databroker.go b/authorize/databroker.go index e0e278f8e..8d6399b72 100644 --- a/authorize/databroker.go +++ b/authorize/databroker.go @@ -4,8 +4,10 @@ import ( "context" "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/storage" ) @@ -13,20 +15,69 @@ type sessionOrServiceAccount interface { GetUserId() string } -func (a *Authorize) getDataBrokerSessionOrServiceAccount(ctx context.Context, sessionID string) (s sessionOrServiceAccount, err error) { +func getDataBrokerRecord( + ctx context.Context, + recordType string, + recordID string, + lowestRecordVersion uint64, +) (*databroker.Record, error) { + q := storage.GetQuerier(ctx) + + req := &databroker.QueryRequest{ + Type: recordType, + Limit: 1, + } + req.SetFilterByIDOrIndex(recordID) + + res, err := q.Query(ctx, req) + if err != nil { + return nil, err + } + if len(res.GetRecords()) == 0 { + return nil, storage.ErrNotFound + } + + // if the current record version is less than the lowest we'll accept, invalidate the cache + if res.GetRecords()[0].GetVersion() < lowestRecordVersion { + q.InvalidateCache(ctx, req) + } else { + return res.GetRecords()[0], nil + } + + // retry with an up to date cache + res, err = q.Query(ctx, req) + if err != nil { + return nil, err + } + if len(res.GetRecords()) == 0 { + return nil, storage.ErrNotFound + } + + return res.GetRecords()[0], nil +} + +func (a *Authorize) getDataBrokerSessionOrServiceAccount( + ctx context.Context, + sessionID string, + dataBrokerRecordVersion uint64, +) (s sessionOrServiceAccount, err error) { ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerSessionOrServiceAccount") defer span.End() - client := a.state.Load().dataBrokerClient - - s, err = session.Get(ctx, client, sessionID) + record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion) if storage.IsNotFound(err) { - s, err = user.GetServiceAccount(ctx, client, sessionID) + record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion) } if err != nil { return nil, err } + msg, err := record.GetData().UnmarshalNew() + if err != nil { + return nil, err + } + s = msg.(sessionOrServiceAccount) + if _, ok := s.(*session.Session); ok { a.accessTracker.TrackSessionAccess(sessionID) } diff --git a/authorize/grpc.go b/authorize/grpc.go index 5c0f6b79e..e8a8ddbb1 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -58,7 +58,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe var s sessionOrServiceAccount var u *user.User if sessionState != nil { - s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID) + s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion) if err != nil { log.Warn(ctx).Err(err).Msg("clearing session due to missing session or service account") sessionState = nil diff --git a/pkg/storage/querier.go b/pkg/storage/querier.go index 4c8a992ab..1eac9a2f5 100644 --- a/pkg/storage/querier.go +++ b/pkg/storage/querier.go @@ -2,6 +2,7 @@ package storage import ( "context" + "strconv" "sync" "github.com/google/uuid" @@ -19,12 +20,15 @@ import ( // A Querier is a read-only subset of the client methods type Querier interface { + InvalidateCache(ctx context.Context, in *databroker.QueryRequest) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) } // nilQuerier always returns NotFound. type nilQuerier struct{} +func (nilQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {} + func (nilQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { return nil, status.Error(codes.NotFound, "not found") } @@ -63,11 +67,18 @@ func NewStaticQuerier(msgs ...proto.Message) Querier { if hasID, ok := msg.(interface{ GetId() string }); ok { record.Id = hasID.GetId() } + if hasVersion, ok := msg.(interface{ GetVersion() string }); ok { + if v, err := strconv.ParseUint(hasVersion.GetVersion(), 10, 64); err == nil { + record.Version = v + } + } getter.records = append(getter.records, record) } return getter } +func (q *staticQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {} + // Query queries for records. func (q *staticQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { expr, err := FilterExpressionFromStruct(in.GetFilter()) @@ -116,6 +127,8 @@ func NewQuerier(client databroker.DataBrokerServiceClient) Querier { return &clientQuerier{client: client} } +func (q *clientQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {} + // Query queries for records. func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { return q.client.Query(ctx, in, opts...) @@ -145,6 +158,11 @@ func NewTracingQuerier(q Querier) *TracingQuerier { } } +// InvalidateCache invalidates the cache. +func (q *TracingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) { + q.underlying.InvalidateCache(ctx, in) +} + // Query queries for records. func (q *TracingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { res, err := q.underlying.Query(ctx, in, opts...) @@ -182,6 +200,17 @@ func NewCachingQuerier(q Querier, cache Cache) Querier { } } +func (q *cachingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) { + key, err := (&proto.MarshalOptions{ + Deterministic: true, + }).Marshal(in) + if err != nil { + return + } + q.cache.Invalidate(key) + q.q.InvalidateCache(ctx, in) +} + func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { key, err := (&proto.MarshalOptions{ Deterministic: true,