authorize: fix device synchronization (#3482)

This commit is contained in:
Caleb Doxsey 2022-07-15 17:27:06 -06:00 committed by GitHub
parent bc078f8bd2
commit fe61a74e1b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 86 additions and 6 deletions

View file

@ -4,8 +4,10 @@ import (
"context" "context"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
) )
@ -13,20 +15,69 @@ type sessionOrServiceAccount interface {
GetUserId() string GetUserId() string
} }
func (a *Authorize) getDataBrokerSessionOrServiceAccount(ctx context.Context, sessionID string) (s sessionOrServiceAccount, err error) { func getDataBrokerRecord(
ctx context.Context,
recordType string,
recordID string,
lowestRecordVersion uint64,
) (*databroker.Record, error) {
q := storage.GetQuerier(ctx)
req := &databroker.QueryRequest{
Type: recordType,
Limit: 1,
}
req.SetFilterByIDOrIndex(recordID)
res, err := q.Query(ctx, req)
if err != nil {
return nil, err
}
if len(res.GetRecords()) == 0 {
return nil, storage.ErrNotFound
}
// if the current record version is less than the lowest we'll accept, invalidate the cache
if res.GetRecords()[0].GetVersion() < lowestRecordVersion {
q.InvalidateCache(ctx, req)
} else {
return res.GetRecords()[0], nil
}
// retry with an up to date cache
res, err = q.Query(ctx, req)
if err != nil {
return nil, err
}
if len(res.GetRecords()) == 0 {
return nil, storage.ErrNotFound
}
return res.GetRecords()[0], nil
}
func (a *Authorize) getDataBrokerSessionOrServiceAccount(
ctx context.Context,
sessionID string,
dataBrokerRecordVersion uint64,
) (s sessionOrServiceAccount, err error) {
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerSessionOrServiceAccount") ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
defer span.End() defer span.End()
client := a.state.Load().dataBrokerClient record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion)
s, err = session.Get(ctx, client, sessionID)
if storage.IsNotFound(err) { if storage.IsNotFound(err) {
s, err = user.GetServiceAccount(ctx, client, sessionID) record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
msg, err := record.GetData().UnmarshalNew()
if err != nil {
return nil, err
}
s = msg.(sessionOrServiceAccount)
if _, ok := s.(*session.Session); ok { if _, ok := s.(*session.Session); ok {
a.accessTracker.TrackSessionAccess(sessionID) a.accessTracker.TrackSessionAccess(sessionID)
} }

View file

@ -58,7 +58,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
var s sessionOrServiceAccount var s sessionOrServiceAccount
var u *user.User var u *user.User
if sessionState != nil { if sessionState != nil {
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID) s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
if err != nil { if err != nil {
log.Warn(ctx).Err(err).Msg("clearing session due to missing session or service account") log.Warn(ctx).Err(err).Msg("clearing session due to missing session or service account")
sessionState = nil sessionState = nil

View file

@ -2,6 +2,7 @@ package storage
import ( import (
"context" "context"
"strconv"
"sync" "sync"
"github.com/google/uuid" "github.com/google/uuid"
@ -19,12 +20,15 @@ import (
// A Querier is a read-only subset of the client methods // A Querier is a read-only subset of the client methods
type Querier interface { type Querier interface {
InvalidateCache(ctx context.Context, in *databroker.QueryRequest)
Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error)
} }
// nilQuerier always returns NotFound. // nilQuerier always returns NotFound.
type nilQuerier struct{} type nilQuerier struct{}
func (nilQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {}
func (nilQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { func (nilQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
return nil, status.Error(codes.NotFound, "not found") return nil, status.Error(codes.NotFound, "not found")
} }
@ -63,11 +67,18 @@ func NewStaticQuerier(msgs ...proto.Message) Querier {
if hasID, ok := msg.(interface{ GetId() string }); ok { if hasID, ok := msg.(interface{ GetId() string }); ok {
record.Id = hasID.GetId() record.Id = hasID.GetId()
} }
if hasVersion, ok := msg.(interface{ GetVersion() string }); ok {
if v, err := strconv.ParseUint(hasVersion.GetVersion(), 10, 64); err == nil {
record.Version = v
}
}
getter.records = append(getter.records, record) getter.records = append(getter.records, record)
} }
return getter return getter
} }
func (q *staticQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {}
// Query queries for records. // Query queries for records.
func (q *staticQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { func (q *staticQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
expr, err := FilterExpressionFromStruct(in.GetFilter()) expr, err := FilterExpressionFromStruct(in.GetFilter())
@ -116,6 +127,8 @@ func NewQuerier(client databroker.DataBrokerServiceClient) Querier {
return &clientQuerier{client: client} return &clientQuerier{client: client}
} }
func (q *clientQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {}
// Query queries for records. // Query queries for records.
func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
return q.client.Query(ctx, in, opts...) return q.client.Query(ctx, in, opts...)
@ -145,6 +158,11 @@ func NewTracingQuerier(q Querier) *TracingQuerier {
} }
} }
// InvalidateCache invalidates the cache.
func (q *TracingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {
q.underlying.InvalidateCache(ctx, in)
}
// Query queries for records. // Query queries for records.
func (q *TracingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { func (q *TracingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
res, err := q.underlying.Query(ctx, in, opts...) res, err := q.underlying.Query(ctx, in, opts...)
@ -182,6 +200,17 @@ func NewCachingQuerier(q Querier, cache Cache) Querier {
} }
} }
func (q *cachingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {
key, err := (&proto.MarshalOptions{
Deterministic: true,
}).Marshal(in)
if err != nil {
return
}
q.cache.Invalidate(key)
q.q.InvalidateCache(ctx, in)
}
func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
key, err := (&proto.MarshalOptions{ key, err := (&proto.MarshalOptions{
Deterministic: true, Deterministic: true,