mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-24 20:18:13 +02:00
authorize: fix device synchronization (#3482)
This commit is contained in:
parent
bc078f8bd2
commit
fe61a74e1b
3 changed files with 86 additions and 6 deletions
|
@ -4,8 +4,10 @@ import (
|
|||
"context"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
|
@ -13,20 +15,69 @@ type sessionOrServiceAccount interface {
|
|||
GetUserId() string
|
||||
}
|
||||
|
||||
func (a *Authorize) getDataBrokerSessionOrServiceAccount(ctx context.Context, sessionID string) (s sessionOrServiceAccount, err error) {
|
||||
func getDataBrokerRecord(
|
||||
ctx context.Context,
|
||||
recordType string,
|
||||
recordID string,
|
||||
lowestRecordVersion uint64,
|
||||
) (*databroker.Record, error) {
|
||||
q := storage.GetQuerier(ctx)
|
||||
|
||||
req := &databroker.QueryRequest{
|
||||
Type: recordType,
|
||||
Limit: 1,
|
||||
}
|
||||
req.SetFilterByIDOrIndex(recordID)
|
||||
|
||||
res, err := q.Query(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res.GetRecords()) == 0 {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
// if the current record version is less than the lowest we'll accept, invalidate the cache
|
||||
if res.GetRecords()[0].GetVersion() < lowestRecordVersion {
|
||||
q.InvalidateCache(ctx, req)
|
||||
} else {
|
||||
return res.GetRecords()[0], nil
|
||||
}
|
||||
|
||||
// retry with an up to date cache
|
||||
res, err = q.Query(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res.GetRecords()) == 0 {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return res.GetRecords()[0], nil
|
||||
}
|
||||
|
||||
func (a *Authorize) getDataBrokerSessionOrServiceAccount(
|
||||
ctx context.Context,
|
||||
sessionID string,
|
||||
dataBrokerRecordVersion uint64,
|
||||
) (s sessionOrServiceAccount, err error) {
|
||||
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
|
||||
defer span.End()
|
||||
|
||||
client := a.state.Load().dataBrokerClient
|
||||
|
||||
s, err = session.Get(ctx, client, sessionID)
|
||||
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion)
|
||||
if storage.IsNotFound(err) {
|
||||
s, err = user.GetServiceAccount(ctx, client, sessionID)
|
||||
record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := record.GetData().UnmarshalNew()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s = msg.(sessionOrServiceAccount)
|
||||
|
||||
if _, ok := s.(*session.Session); ok {
|
||||
a.accessTracker.TrackSessionAccess(sessionID)
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
var s sessionOrServiceAccount
|
||||
var u *user.User
|
||||
if sessionState != nil {
|
||||
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID)
|
||||
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
|
||||
if err != nil {
|
||||
log.Warn(ctx).Err(err).Msg("clearing session due to missing session or service account")
|
||||
sessionState = nil
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue