mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 02:46:30 +02:00
authorize: fix device synchronization (#3482)
This commit is contained in:
parent
bc078f8bd2
commit
fe61a74e1b
3 changed files with 86 additions and 6 deletions
|
@ -4,8 +4,10 @@ import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,20 +15,69 @@ type sessionOrServiceAccount interface {
|
||||||
GetUserId() string
|
GetUserId() string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authorize) getDataBrokerSessionOrServiceAccount(ctx context.Context, sessionID string) (s sessionOrServiceAccount, err error) {
|
func getDataBrokerRecord(
|
||||||
|
ctx context.Context,
|
||||||
|
recordType string,
|
||||||
|
recordID string,
|
||||||
|
lowestRecordVersion uint64,
|
||||||
|
) (*databroker.Record, error) {
|
||||||
|
q := storage.GetQuerier(ctx)
|
||||||
|
|
||||||
|
req := &databroker.QueryRequest{
|
||||||
|
Type: recordType,
|
||||||
|
Limit: 1,
|
||||||
|
}
|
||||||
|
req.SetFilterByIDOrIndex(recordID)
|
||||||
|
|
||||||
|
res, err := q.Query(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(res.GetRecords()) == 0 {
|
||||||
|
return nil, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the current record version is less than the lowest we'll accept, invalidate the cache
|
||||||
|
if res.GetRecords()[0].GetVersion() < lowestRecordVersion {
|
||||||
|
q.InvalidateCache(ctx, req)
|
||||||
|
} else {
|
||||||
|
return res.GetRecords()[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// retry with an up to date cache
|
||||||
|
res, err = q.Query(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(res.GetRecords()) == 0 {
|
||||||
|
return nil, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.GetRecords()[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authorize) getDataBrokerSessionOrServiceAccount(
|
||||||
|
ctx context.Context,
|
||||||
|
sessionID string,
|
||||||
|
dataBrokerRecordVersion uint64,
|
||||||
|
) (s sessionOrServiceAccount, err error) {
|
||||||
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
|
ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
client := a.state.Load().dataBrokerClient
|
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion)
|
||||||
|
|
||||||
s, err = session.Get(ctx, client, sessionID)
|
|
||||||
if storage.IsNotFound(err) {
|
if storage.IsNotFound(err) {
|
||||||
s, err = user.GetServiceAccount(ctx, client, sessionID)
|
record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
msg, err := record.GetData().UnmarshalNew()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s = msg.(sessionOrServiceAccount)
|
||||||
|
|
||||||
if _, ok := s.(*session.Session); ok {
|
if _, ok := s.(*session.Session); ok {
|
||||||
a.accessTracker.TrackSessionAccess(sessionID)
|
a.accessTracker.TrackSessionAccess(sessionID)
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,7 +58,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
||||||
var s sessionOrServiceAccount
|
var s sessionOrServiceAccount
|
||||||
var u *user.User
|
var u *user.User
|
||||||
if sessionState != nil {
|
if sessionState != nil {
|
||||||
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID)
|
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(ctx).Err(err).Msg("clearing session due to missing session or service account")
|
log.Warn(ctx).Err(err).Msg("clearing session due to missing session or service account")
|
||||||
sessionState = nil
|
sessionState = nil
|
||||||
|
|
|
@ -2,6 +2,7 @@ package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
@ -19,12 +20,15 @@ import (
|
||||||
|
|
||||||
// A Querier is a read-only subset of the client methods
|
// A Querier is a read-only subset of the client methods
|
||||||
type Querier interface {
|
type Querier interface {
|
||||||
|
InvalidateCache(ctx context.Context, in *databroker.QueryRequest)
|
||||||
Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error)
|
Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// nilQuerier always returns NotFound.
|
// nilQuerier always returns NotFound.
|
||||||
type nilQuerier struct{}
|
type nilQuerier struct{}
|
||||||
|
|
||||||
|
func (nilQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {}
|
||||||
|
|
||||||
func (nilQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
func (nilQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||||
return nil, status.Error(codes.NotFound, "not found")
|
return nil, status.Error(codes.NotFound, "not found")
|
||||||
}
|
}
|
||||||
|
@ -63,11 +67,18 @@ func NewStaticQuerier(msgs ...proto.Message) Querier {
|
||||||
if hasID, ok := msg.(interface{ GetId() string }); ok {
|
if hasID, ok := msg.(interface{ GetId() string }); ok {
|
||||||
record.Id = hasID.GetId()
|
record.Id = hasID.GetId()
|
||||||
}
|
}
|
||||||
|
if hasVersion, ok := msg.(interface{ GetVersion() string }); ok {
|
||||||
|
if v, err := strconv.ParseUint(hasVersion.GetVersion(), 10, 64); err == nil {
|
||||||
|
record.Version = v
|
||||||
|
}
|
||||||
|
}
|
||||||
getter.records = append(getter.records, record)
|
getter.records = append(getter.records, record)
|
||||||
}
|
}
|
||||||
return getter
|
return getter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *staticQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {}
|
||||||
|
|
||||||
// Query queries for records.
|
// Query queries for records.
|
||||||
func (q *staticQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
func (q *staticQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||||
expr, err := FilterExpressionFromStruct(in.GetFilter())
|
expr, err := FilterExpressionFromStruct(in.GetFilter())
|
||||||
|
@ -116,6 +127,8 @@ func NewQuerier(client databroker.DataBrokerServiceClient) Querier {
|
||||||
return &clientQuerier{client: client}
|
return &clientQuerier{client: client}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *clientQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {}
|
||||||
|
|
||||||
// Query queries for records.
|
// Query queries for records.
|
||||||
func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||||
return q.client.Query(ctx, in, opts...)
|
return q.client.Query(ctx, in, opts...)
|
||||||
|
@ -145,6 +158,11 @@ func NewTracingQuerier(q Querier) *TracingQuerier {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InvalidateCache invalidates the cache.
|
||||||
|
func (q *TracingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {
|
||||||
|
q.underlying.InvalidateCache(ctx, in)
|
||||||
|
}
|
||||||
|
|
||||||
// Query queries for records.
|
// Query queries for records.
|
||||||
func (q *TracingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
func (q *TracingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||||
res, err := q.underlying.Query(ctx, in, opts...)
|
res, err := q.underlying.Query(ctx, in, opts...)
|
||||||
|
@ -182,6 +200,17 @@ func NewCachingQuerier(q Querier, cache Cache) Querier {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *cachingQuerier) InvalidateCache(ctx context.Context, in *databroker.QueryRequest) {
|
||||||
|
key, err := (&proto.MarshalOptions{
|
||||||
|
Deterministic: true,
|
||||||
|
}).Marshal(in)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
q.cache.Invalidate(key)
|
||||||
|
q.q.InvalidateCache(ctx, in)
|
||||||
|
}
|
||||||
|
|
||||||
func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||||
key, err := (&proto.MarshalOptions{
|
key, err := (&proto.MarshalOptions{
|
||||||
Deterministic: true,
|
Deterministic: true,
|
||||||
|
|
Loading…
Add table
Reference in a new issue