mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 10:56:28 +02:00
authorize: retry session fetch if invalid
When the authorize service fetches a session or service account record from the databroker, if such record is already invalid, bypass the databroker cache and attempt the fetch again.
This commit is contained in:
parent
5f9f46652a
commit
318076c2bf
2 changed files with 39 additions and 15 deletions
|
@ -20,7 +20,7 @@ func getDataBrokerRecord(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
recordType string,
|
recordType string,
|
||||||
recordID string,
|
recordID string,
|
||||||
lowestRecordVersion uint64,
|
invalidate func(*databroker.Record) bool,
|
||||||
) (*databroker.Record, error) {
|
) (*databroker.Record, error) {
|
||||||
q := storage.GetQuerier(ctx)
|
q := storage.GetQuerier(ctx)
|
||||||
|
|
||||||
|
@ -38,14 +38,13 @@ func getDataBrokerRecord(
|
||||||
return nil, storage.ErrNotFound
|
return nil, storage.ErrNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the current record version is less than the lowest we'll accept, invalidate the cache
|
// Check to see if we should invalidate the cache.
|
||||||
if res.GetRecords()[0].GetVersion() < lowestRecordVersion {
|
if invalidate == nil || !invalidate(res.GetRecords()[0]) {
|
||||||
q.InvalidateCache(ctx, req)
|
|
||||||
} else {
|
|
||||||
return res.GetRecords()[0], nil
|
return res.GetRecords()[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// retry with an up to date cache
|
// retry with an up to date cache
|
||||||
|
q.InvalidateCache(ctx, req)
|
||||||
res, err = q.Query(ctx, req)
|
res, err = q.Query(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -65,22 +64,30 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount(
|
||||||
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
|
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion)
|
invalidate := func(record *databroker.Record) bool {
|
||||||
|
// if the current record version is less than the lowest we'll accept, invalidate the cache
|
||||||
|
if record.GetVersion() < dataBrokerRecordVersion {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// or if the session or service account is invalid, invalidate the cache
|
||||||
|
_, err := validateSessionOrServiceAccount(record)
|
||||||
|
return err != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
record, err := getDataBrokerRecord(
|
||||||
|
ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, invalidate)
|
||||||
if storage.IsNotFound(err) {
|
if storage.IsNotFound(err) {
|
||||||
record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion)
|
record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, invalidate)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, err := record.GetData().UnmarshalNew()
|
s, err = validateSessionOrServiceAccount(record)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
s = msg.(sessionOrServiceAccount)
|
|
||||||
if err := s.Validate(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := s.(*session.Session); ok {
|
if _, ok := s.(*session.Session); ok {
|
||||||
a.accessTracker.TrackSessionAccess(sessionID)
|
a.accessTracker.TrackSessionAccess(sessionID)
|
||||||
|
@ -91,6 +98,18 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount(
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateSessionOrServiceAccount(record *databroker.Record) (sessionOrServiceAccount, error) {
|
||||||
|
msg, err := record.GetData().UnmarshalNew()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s := msg.(sessionOrServiceAccount)
|
||||||
|
if err := s.Validate(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Authorize) getDataBrokerUser(
|
func (a *Authorize) getDataBrokerUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userID string,
|
userID string,
|
||||||
|
@ -98,7 +117,7 @@ func (a *Authorize) getDataBrokerUser(
|
||||||
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerUser")
|
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerUser")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0)
|
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
@ -42,11 +43,15 @@ func Test_getDataBrokerRecord(t *testing.T) {
|
||||||
tcq := storage.NewTracingQuerier(cq)
|
tcq := storage.NewTracingQuerier(cq)
|
||||||
qctx := storage.WithQuerier(ctx, tcq)
|
qctx := storage.WithQuerier(ctx, tcq)
|
||||||
|
|
||||||
s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
invalidate := func(record *databroker.Record) bool {
|
||||||
|
return record.GetVersion() < tc.queryVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), invalidate)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.NotNil(t, s)
|
assert.NotNil(t, s)
|
||||||
|
|
||||||
s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), invalidate)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.NotNil(t, s)
|
assert.NotNil(t, s)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue