diff --git a/authorize/databroker.go b/authorize/databroker.go index a65db4e27..bdef56ac5 100644 --- a/authorize/databroker.go +++ b/authorize/databroker.go @@ -20,7 +20,7 @@ func getDataBrokerRecord( ctx context.Context, recordType string, recordID string, - lowestRecordVersion uint64, + invalidate func(*databroker.Record) bool, ) (*databroker.Record, error) { q := storage.GetQuerier(ctx) @@ -38,14 +38,13 @@ func getDataBrokerRecord( return nil, storage.ErrNotFound } - // if the current record version is less than the lowest we'll accept, invalidate the cache - if res.GetRecords()[0].GetVersion() < lowestRecordVersion { - q.InvalidateCache(ctx, req) - } else { + // Check to see if we should invalidate the cache. + if invalidate == nil || !invalidate(res.GetRecords()[0]) { return res.GetRecords()[0], nil } // retry with an up to date cache + q.InvalidateCache(ctx, req) res, err = q.Query(ctx, req) if err != nil { return nil, err @@ -65,22 +64,30 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount( ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerSessionOrServiceAccount") defer span.End() - record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion) + invalidate := func(record *databroker.Record) bool { + // if the current record version is less than the lowest we'll accept, invalidate the cache + if record.GetVersion() < dataBrokerRecordVersion { + return true + } + + // or if the session or service account is invalid, invalidate the cache + _, err := validateSessionOrServiceAccount(record) + return err != nil + } + + record, err := getDataBrokerRecord( + ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, invalidate) if storage.IsNotFound(err) { - record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion) + record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, invalidate) } if err != nil { return nil, err } - msg, err := record.GetData().UnmarshalNew() + s, err = validateSessionOrServiceAccount(record) if err != nil { return nil, err } - s = msg.(sessionOrServiceAccount) - if err := s.Validate(); err != nil { - return nil, err - } if _, ok := s.(*session.Session); ok { a.accessTracker.TrackSessionAccess(sessionID) @@ -91,6 +98,18 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount( return s, nil } +func validateSessionOrServiceAccount(record *databroker.Record) (sessionOrServiceAccount, error) { + msg, err := record.GetData().UnmarshalNew() + if err != nil { + return nil, err + } + s := msg.(sessionOrServiceAccount) + if err := s.Validate(); err != nil { + return nil, err + } + return s, nil +} + func (a *Authorize) getDataBrokerUser( ctx context.Context, userID string, @@ -98,7 +117,7 @@ func (a *Authorize) getDataBrokerUser( ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerUser") defer span.End() - record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0) + record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, nil) if err != nil { return nil, err } diff --git a/authorize/databroker_test.go b/authorize/databroker_test.go index a9e441c21..1ab6beb9b 100644 --- a/authorize/databroker_test.go +++ b/authorize/databroker_test.go @@ -11,6 +11,7 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/storage" @@ -42,11 +43,15 @@ func Test_getDataBrokerRecord(t *testing.T) { tcq := storage.NewTracingQuerier(cq) qctx := storage.WithQuerier(ctx, tcq) - s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) + invalidate := func(record *databroker.Record) bool { + return record.GetVersion() < tc.queryVersion + } + + s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), invalidate) assert.NoError(t, err) assert.NotNil(t, s) - s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) + s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), invalidate) assert.NoError(t, err) assert.NotNil(t, s)