authorize: retry session fetch if invalid

When the authorize service fetches a session or service account record
from the databroker, if such record is already invalid, bypass the
databroker cache and attempt the fetch again.
This commit is contained in:
Kenneth Jenkins 2023-10-12 15:14:44 -07:00
parent 5f9f46652a
commit 318076c2bf
2 changed files with 39 additions and 15 deletions

View file

@ -20,7 +20,7 @@ func getDataBrokerRecord(
ctx context.Context, ctx context.Context,
recordType string, recordType string,
recordID string, recordID string,
lowestRecordVersion uint64, invalidate func(*databroker.Record) bool,
) (*databroker.Record, error) { ) (*databroker.Record, error) {
q := storage.GetQuerier(ctx) q := storage.GetQuerier(ctx)
@ -38,14 +38,13 @@ func getDataBrokerRecord(
return nil, storage.ErrNotFound return nil, storage.ErrNotFound
} }
// if the current record version is less than the lowest we'll accept, invalidate the cache // Check to see if we should invalidate the cache.
if res.GetRecords()[0].GetVersion() < lowestRecordVersion { if invalidate == nil || !invalidate(res.GetRecords()[0]) {
q.InvalidateCache(ctx, req)
} else {
return res.GetRecords()[0], nil return res.GetRecords()[0], nil
} }
// retry with an up to date cache // retry with an up to date cache
q.InvalidateCache(ctx, req)
res, err = q.Query(ctx, req) res, err = q.Query(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
@ -65,22 +64,30 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount(
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerSessionOrServiceAccount") ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
defer span.End() defer span.End()
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion) invalidate := func(record *databroker.Record) bool {
// if the current record version is less than the lowest we'll accept, invalidate the cache
if record.GetVersion() < dataBrokerRecordVersion {
return true
}
// or if the session or service account is invalid, invalidate the cache
_, err := validateSessionOrServiceAccount(record)
return err != nil
}
record, err := getDataBrokerRecord(
ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, invalidate)
if storage.IsNotFound(err) { if storage.IsNotFound(err) {
record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion) record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, invalidate)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
msg, err := record.GetData().UnmarshalNew() s, err = validateSessionOrServiceAccount(record)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s = msg.(sessionOrServiceAccount)
if err := s.Validate(); err != nil {
return nil, err
}
if _, ok := s.(*session.Session); ok { if _, ok := s.(*session.Session); ok {
a.accessTracker.TrackSessionAccess(sessionID) a.accessTracker.TrackSessionAccess(sessionID)
@ -91,6 +98,18 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount(
return s, nil return s, nil
} }
func validateSessionOrServiceAccount(record *databroker.Record) (sessionOrServiceAccount, error) {
msg, err := record.GetData().UnmarshalNew()
if err != nil {
return nil, err
}
s := msg.(sessionOrServiceAccount)
if err := s.Validate(); err != nil {
return nil, err
}
return s, nil
}
func (a *Authorize) getDataBrokerUser( func (a *Authorize) getDataBrokerUser(
ctx context.Context, ctx context.Context,
userID string, userID string,
@ -98,7 +117,7 @@ func (a *Authorize) getDataBrokerUser(
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerUser") ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerUser")
defer span.End() defer span.End()
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0) record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -11,6 +11,7 @@ import (
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
@ -42,11 +43,15 @@ func Test_getDataBrokerRecord(t *testing.T) {
tcq := storage.NewTracingQuerier(cq) tcq := storage.NewTracingQuerier(cq)
qctx := storage.WithQuerier(ctx, tcq) qctx := storage.WithQuerier(ctx, tcq)
s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) invalidate := func(record *databroker.Record) bool {
return record.GetVersion() < tc.queryVersion
}
s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), invalidate)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, s) assert.NotNil(t, s)
s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), invalidate)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, s) assert.NotNil(t, s)