mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-14 07:18:21 +02:00
proxy: use querier cache for user info (#5532)
This commit is contained in:
parent
08623ef346
commit
bc263e3ee5
12 changed files with 259 additions and 156 deletions
|
@ -6,7 +6,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
|
@ -31,7 +30,6 @@ type Authorize struct {
|
||||||
store *store.Store
|
store *store.Store
|
||||||
currentConfig *atomicutil.Value[*config.Config]
|
currentConfig *atomicutil.Value[*config.Config]
|
||||||
accessTracker *AccessTracker
|
accessTracker *AccessTracker
|
||||||
globalCache storage.Cache
|
|
||||||
groupsCacheWarmer *cacheWarmer
|
groupsCacheWarmer *cacheWarmer
|
||||||
|
|
||||||
tracerProvider oteltrace.TracerProvider
|
tracerProvider oteltrace.TracerProvider
|
||||||
|
@ -45,7 +43,6 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
||||||
a := &Authorize{
|
a := &Authorize{
|
||||||
currentConfig: atomicutil.NewValue(&config.Config{Options: new(config.Options)}),
|
currentConfig: atomicutil.NewValue(&config.Config{Options: new(config.Options)}),
|
||||||
store: store.New(),
|
store: store.New(),
|
||||||
globalCache: storage.NewGlobalCache(time.Minute),
|
|
||||||
tracerProvider: tracerProvider,
|
tracerProvider: tracerProvider,
|
||||||
tracer: tracer,
|
tracer: tracer,
|
||||||
}
|
}
|
||||||
|
@ -57,7 +54,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
||||||
}
|
}
|
||||||
a.state = atomicutil.NewValue(state)
|
a.state = atomicutil.NewValue(state)
|
||||||
|
|
||||||
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, a.globalCache, directory.GroupRecordType)
|
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, directory.GroupRecordType)
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,9 +3,6 @@ package authorize
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
|
@ -18,47 +15,6 @@ type sessionOrServiceAccount interface {
|
||||||
Validate() error
|
Validate() error
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDataBrokerRecord(
|
|
||||||
ctx context.Context,
|
|
||||||
recordType string,
|
|
||||||
recordID string,
|
|
||||||
lowestRecordVersion uint64,
|
|
||||||
) (*databroker.Record, error) {
|
|
||||||
q := storage.GetQuerier(ctx)
|
|
||||||
|
|
||||||
req := &databroker.QueryRequest{
|
|
||||||
Type: recordType,
|
|
||||||
Limit: 1,
|
|
||||||
}
|
|
||||||
req.SetFilterByIDOrIndex(recordID)
|
|
||||||
|
|
||||||
res, err := q.Query(ctx, req, grpc.WaitForReady(true))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(res.GetRecords()) == 0 {
|
|
||||||
return nil, storage.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// if the current record version is less than the lowest we'll accept, invalidate the cache
|
|
||||||
if res.GetRecords()[0].GetVersion() < lowestRecordVersion {
|
|
||||||
q.InvalidateCache(ctx, req)
|
|
||||||
} else {
|
|
||||||
return res.GetRecords()[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// retry with an up to date cache
|
|
||||||
res, err = q.Query(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(res.GetRecords()) == 0 {
|
|
||||||
return nil, storage.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return res.GetRecords()[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Authorize) getDataBrokerSessionOrServiceAccount(
|
func (a *Authorize) getDataBrokerSessionOrServiceAccount(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
sessionID string,
|
sessionID string,
|
||||||
|
@ -67,9 +23,9 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount(
|
||||||
ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
|
ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion)
|
record, err := storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion)
|
||||||
if storage.IsNotFound(err) {
|
if storage.IsNotFound(err) {
|
||||||
record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion)
|
record, err = storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -100,7 +56,7 @@ func (a *Authorize) getDataBrokerUser(
|
||||||
ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerUser")
|
ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerUser")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0)
|
record, err := storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ package authorize
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -12,45 +11,9 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_getDataBrokerRecord(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
|
||||||
t.Cleanup(clearTimeout)
|
|
||||||
|
|
||||||
for _, tc := range []struct {
|
|
||||||
name string
|
|
||||||
recordVersion, queryVersion uint64
|
|
||||||
underlyingQueryCount, cachedQueryCount int
|
|
||||||
}{
|
|
||||||
{"cached", 1, 1, 1, 2},
|
|
||||||
{"invalidated", 1, 2, 3, 4},
|
|
||||||
} {
|
|
||||||
tc := tc
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)}
|
|
||||||
|
|
||||||
sq := storage.NewStaticQuerier(s1)
|
|
||||||
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
|
||||||
qctx := storage.WithQuerier(ctx, cq)
|
|
||||||
|
|
||||||
s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, s)
|
|
||||||
|
|
||||||
s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, s)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) {
|
func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
||||||
|
|
||||||
querier := storage.NewCachingQuerier(
|
querier := storage.NewCachingQuerier(
|
||||||
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
||||||
a.globalCache,
|
storage.GlobalCache,
|
||||||
)
|
)
|
||||||
ctx = storage.WithQuerier(ctx, querier)
|
ctx = storage.WithQuerier(ctx, querier)
|
||||||
|
|
||||||
|
@ -98,7 +98,7 @@ func (a *Authorize) loadSession(
|
||||||
// attempt to create a session from an incoming idp token
|
// attempt to create a session from an incoming idp token
|
||||||
s, err = config.NewIncomingIDPTokenSessionCreator(
|
s, err = config.NewIncomingIDPTokenSessionCreator(
|
||||||
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||||
return getDataBrokerRecord(ctx, recordType, recordID, 0)
|
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||||
},
|
},
|
||||||
func(ctx context.Context, records []*databroker.Record) error {
|
func(ctx context.Context, records []*databroker.Record) error {
|
||||||
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||||
|
@ -107,15 +107,7 @@ func (a *Authorize) loadSession(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// invalidate cache
|
storage.InvalidateCacheForDataBrokerRecords(ctx, records...)
|
||||||
for _, record := range records {
|
|
||||||
q := &databroker.QueryRequest{
|
|
||||||
Type: record.GetType(),
|
|
||||||
Limit: 1,
|
|
||||||
}
|
|
||||||
q.SetFilterByIDOrIndex(record.GetId())
|
|
||||||
storage.GetQuerier(ctx).InvalidateCache(ctx, q)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq)
|
).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq)
|
||||||
|
|
|
@ -3,14 +3,12 @@ package databroker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
status "google.golang.org/grpc/status"
|
status "google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
structpb "google.golang.org/protobuf/types/known/structpb"
|
structpb "google.golang.org/protobuf/types/known/structpb"
|
||||||
|
|
||||||
|
@ -53,34 +51,6 @@ func Get(ctx context.Context, client DataBrokerServiceClient, object recordObjec
|
||||||
return res.GetRecord().GetData().UnmarshalTo(object)
|
return res.GetRecord().GetData().UnmarshalTo(object)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetViaJSON gets a record from the databroker, marshals it to JSON, and then unmarshals it to the given type.
|
|
||||||
func GetViaJSON[T any](ctx context.Context, client DataBrokerServiceClient, recordType, recordID string) (*T, error) {
|
|
||||||
res, err := client.Get(ctx, &GetRequest{
|
|
||||||
Type: recordType,
|
|
||||||
Id: recordID,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, err := res.GetRecord().GetData().UnmarshalNew()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
bs, err := protojson.Marshal(msg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var obj T
|
|
||||||
err = json.Unmarshal(bs, &obj)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &obj, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put puts a record into the databroker.
|
// Put puts a record into the databroker.
|
||||||
func Put(ctx context.Context, client DataBrokerServiceClient, objects ...recordObject) (*PutResponse, error) {
|
func Put(ctx context.Context, client DataBrokerServiceClient, objects ...recordObject) (*PutResponse, error) {
|
||||||
records := make([]*Record, len(objects))
|
records := make([]*Record, len(objects))
|
||||||
|
|
|
@ -107,3 +107,6 @@ func (cache *globalCache) set(expiry time.Time, key, value []byte) {
|
||||||
cache.fastcache.Set(key, item)
|
cache.fastcache.Set(key, item)
|
||||||
cache.mu.Unlock()
|
cache.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GlobalCache is a global cache with a TTL of one minute.
|
||||||
|
var GlobalCache = NewGlobalCache(time.Minute)
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -222,3 +223,114 @@ func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) {
|
||||||
Deterministic: true,
|
Deterministic: true,
|
||||||
}).Marshal(res)
|
}).Marshal(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDataBrokerRecord uses a querier to get a databroker record.
|
||||||
|
func GetDataBrokerRecord(
|
||||||
|
ctx context.Context,
|
||||||
|
recordType string,
|
||||||
|
recordID string,
|
||||||
|
lowestRecordVersion uint64,
|
||||||
|
) (*databroker.Record, error) {
|
||||||
|
q := GetQuerier(ctx)
|
||||||
|
|
||||||
|
req := &databroker.QueryRequest{
|
||||||
|
Type: recordType,
|
||||||
|
Limit: 1,
|
||||||
|
}
|
||||||
|
req.SetFilterByIDOrIndex(recordID)
|
||||||
|
|
||||||
|
res, err := q.Query(ctx, req, grpc.WaitForReady(true))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(res.GetRecords()) == 0 {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the current record version is less than the lowest we'll accept, invalidate the cache
|
||||||
|
if res.GetRecords()[0].GetVersion() < lowestRecordVersion {
|
||||||
|
q.InvalidateCache(ctx, req)
|
||||||
|
} else {
|
||||||
|
return res.GetRecords()[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// retry with an up to date cache
|
||||||
|
res, err = q.Query(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(res.GetRecords()) == 0 {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.GetRecords()[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDataBrokerMessage gets a databroker record and converts it into the message type.
|
||||||
|
func GetDataBrokerMessage[T any, TMessage interface {
|
||||||
|
*T
|
||||||
|
proto.Message
|
||||||
|
}](
|
||||||
|
ctx context.Context,
|
||||||
|
recordID string,
|
||||||
|
lowestRecordVersion uint64,
|
||||||
|
) (TMessage, error) {
|
||||||
|
var msg T
|
||||||
|
|
||||||
|
record, err := GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(TMessage(&msg)), recordID, lowestRecordVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = record.GetData().UnmarshalTo(TMessage(&msg))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return TMessage(&msg), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDataBrokerObjectViaJSON gets a databroker record and converts it into the object type by going through protojson.
|
||||||
|
func GetDataBrokerObjectViaJSON[T any](
|
||||||
|
ctx context.Context,
|
||||||
|
recordType string,
|
||||||
|
recordID string,
|
||||||
|
lowestRecordVersion uint64,
|
||||||
|
) (*T, error) {
|
||||||
|
record, err := GetDataBrokerRecord(ctx, recordType, recordID, lowestRecordVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := record.GetData().UnmarshalNew()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
bs, err := protojson.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var obj T
|
||||||
|
err = json.Unmarshal(bs, &obj)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &obj, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateCacheForDataBrokerRecords invalidates the cache of the querier for the databroker records.
|
||||||
|
func InvalidateCacheForDataBrokerRecords(
|
||||||
|
ctx context.Context,
|
||||||
|
records ...*databroker.Record,
|
||||||
|
) {
|
||||||
|
for _, record := range records {
|
||||||
|
q := &databroker.QueryRequest{
|
||||||
|
Type: record.GetType(),
|
||||||
|
Limit: 1,
|
||||||
|
}
|
||||||
|
q.SetFilterByIDOrIndex(record.GetId())
|
||||||
|
GetQuerier(ctx).InvalidateCache(ctx, q)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
101
pkg/storage/querier_test.go
Normal file
101
pkg/storage/querier_test.go
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
package storage_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"google.golang.org/protobuf/testing/protocmp"
|
||||||
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
|
||||||
|
"github.com/pomerium/datasource/pkg/directory"
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetDataBrokerRecord(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
t.Cleanup(clearTimeout)
|
||||||
|
|
||||||
|
for _, tc := range []struct {
|
||||||
|
name string
|
||||||
|
recordVersion, queryVersion uint64
|
||||||
|
underlyingQueryCount, cachedQueryCount int
|
||||||
|
}{
|
||||||
|
{"cached", 1, 1, 1, 2},
|
||||||
|
{"invalidated", 1, 2, 3, 4},
|
||||||
|
} {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)}
|
||||||
|
|
||||||
|
sq := storage.NewStaticQuerier(s1)
|
||||||
|
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
||||||
|
qctx := storage.WithQuerier(ctx, cq)
|
||||||
|
|
||||||
|
s, err := storage.GetDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, s)
|
||||||
|
|
||||||
|
s, err = storage.GetDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, s)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDataBrokerMessage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.GetContext(t, time.Minute)
|
||||||
|
|
||||||
|
s1 := &session.Session{Id: "s1"}
|
||||||
|
sq := storage.NewStaticQuerier(s1)
|
||||||
|
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
||||||
|
qctx := storage.WithQuerier(ctx, cq)
|
||||||
|
|
||||||
|
s2, err := storage.GetDataBrokerMessage[session.Session](qctx, "s1", 0)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, cmp.Diff(s1, s2, protocmp.Transform()))
|
||||||
|
|
||||||
|
_, err = storage.GetDataBrokerMessage[session.Session](qctx, "s2", 0)
|
||||||
|
assert.ErrorIs(t, err, storage.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDataBrokerObjectViaJSON(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.GetContext(t, time.Minute)
|
||||||
|
|
||||||
|
du1 := &directory.User{
|
||||||
|
ID: "u1",
|
||||||
|
Email: "u1@example.com",
|
||||||
|
DisplayName: "User 1!",
|
||||||
|
}
|
||||||
|
sq := storage.NewStaticQuerier(newDirectoryUserRecord(du1))
|
||||||
|
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
||||||
|
qctx := storage.WithQuerier(ctx, cq)
|
||||||
|
|
||||||
|
du2, err := storage.GetDataBrokerObjectViaJSON[directory.User](qctx, directory.UserRecordType, "u1", 0)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, cmp.Diff(du1, du2, protocmp.Transform()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDirectoryUserRecord(directoryUser *directory.User) *databroker.Record {
|
||||||
|
m := map[string]any{}
|
||||||
|
bs, _ := json.Marshal(directoryUser)
|
||||||
|
_ = json.Unmarshal(bs, &m)
|
||||||
|
s, _ := structpb.NewStruct(m)
|
||||||
|
return storage.NewStaticRecord(directory.UserRecordType, s)
|
||||||
|
}
|
|
@ -9,28 +9,24 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/handlers"
|
"github.com/pomerium/pomerium/internal/handlers"
|
||||||
"github.com/pomerium/pomerium/internal/handlers/webauthn"
|
"github.com/pomerium/pomerium/internal/handlers/webauthn"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
"github.com/pomerium/pomerium/pkg/webauthnutil"
|
"github.com/pomerium/pomerium/pkg/webauthnutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Session, isImpersonated bool, err error) {
|
func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Session, isImpersonated bool, err error) {
|
||||||
client := p.state.Load().dataBrokerClient
|
|
||||||
|
|
||||||
isImpersonated = false
|
isImpersonated = false
|
||||||
s, err = session.Get(ctx, client, sessionID)
|
s, err = storage.GetDataBrokerMessage[session.Session](ctx, sessionID, 0)
|
||||||
if s.GetImpersonateSessionId() != "" {
|
if s.GetImpersonateSessionId() != "" {
|
||||||
s, err = session.Get(ctx, client, s.GetImpersonateSessionId())
|
s, err = storage.GetDataBrokerMessage[session.Session](ctx, s.GetImpersonateSessionId(), 0)
|
||||||
isImpersonated = true
|
isImpersonated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
return s, isImpersonated, err
|
return s, isImpersonated, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error) {
|
func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error) {
|
||||||
client := p.state.Load().dataBrokerClient
|
return storage.GetDataBrokerMessage[user.User](ctx, userID, 0)
|
||||||
return user.Get(ctx, client, userID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
||||||
|
@ -72,21 +68,16 @@ func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.UserInfoData) {
|
func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.UserInfoData) {
|
||||||
client := p.state.Load().dataBrokerClient
|
record, _ := storage.GetDataBrokerRecord(ctx, "type.googleapis.com/pomerium.config.Config", "dashboard-settings", 0)
|
||||||
|
data.IsEnterprise = record != nil
|
||||||
res, _ := client.Get(ctx, &databroker.GetRequest{
|
|
||||||
Type: "type.googleapis.com/pomerium.config.Config",
|
|
||||||
Id: "dashboard-settings",
|
|
||||||
})
|
|
||||||
data.IsEnterprise = res.GetRecord() != nil
|
|
||||||
if !data.IsEnterprise {
|
if !data.IsEnterprise {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data.DirectoryUser, _ = databroker.GetViaJSON[directory.User](ctx, client, directory.UserRecordType, data.Session.GetUserId())
|
data.DirectoryUser, _ = storage.GetDataBrokerObjectViaJSON[directory.User](ctx, directory.UserRecordType, data.Session.GetUserId(), 0)
|
||||||
if data.DirectoryUser != nil {
|
if data.DirectoryUser != nil {
|
||||||
for _, groupID := range data.DirectoryUser.GroupIDs {
|
for _, groupID := range data.DirectoryUser.GroupIDs {
|
||||||
directoryGroup, _ := databroker.GetViaJSON[directory.Group](ctx, client, directory.GroupRecordType, groupID)
|
directoryGroup, _ := storage.GetDataBrokerObjectViaJSON[directory.Group](ctx, directory.GroupRecordType, groupID, 0)
|
||||||
if directoryGroup != nil {
|
if directoryGroup != nil {
|
||||||
data.DirectoryGroups = append(data.DirectoryGroups, directoryGroup)
|
data.DirectoryGroups = append(data.DirectoryGroups, directoryGroup)
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_getUserInfoData(t *testing.T) {
|
func Test_getUserInfoData(t *testing.T) {
|
||||||
|
@ -65,6 +66,7 @@ func Test_getUserInfoData(t *testing.T) {
|
||||||
proxy, err := New(ctx, &config.Config{Options: opts})
|
proxy, err := New(ctx, &config.Config{Options: opts})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
proxy.state.Load().dataBrokerClient = client
|
proxy.state.Load().dataBrokerClient = client
|
||||||
|
ctx = storage.WithQuerier(ctx, storage.NewQuerier(client))
|
||||||
|
|
||||||
require.NoError(t, databrokerpb.PutMulti(ctx, client,
|
require.NoError(t, databrokerpb.PutMulti(ctx, client,
|
||||||
makeRecord(&session.Session{
|
makeRecord(&session.Session{
|
||||||
|
@ -81,7 +83,7 @@ func Test_getUserInfoData(t *testing.T) {
|
||||||
"group_ids": []any{"G1", "G2", "G3"},
|
"group_ids": []any{"G1", "G2", "G3"},
|
||||||
})))
|
})))
|
||||||
|
|
||||||
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
|
r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/.pomerium/", nil)
|
||||||
r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{
|
r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{
|
||||||
ID: "S1",
|
ID: "S1",
|
||||||
}))
|
}))
|
||||||
|
@ -89,7 +91,9 @@ func Test_getUserInfoData(t *testing.T) {
|
||||||
assert.Equal(t, "S1", data.Session.Id)
|
assert.Equal(t, "S1", data.Session.Id)
|
||||||
assert.Equal(t, "U1", data.User.Id)
|
assert.Equal(t, "U1", data.User.Id)
|
||||||
assert.True(t, data.IsEnterprise)
|
assert.True(t, data.IsEnterprise)
|
||||||
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
|
if assert.NotNil(t, data.DirectoryUser) {
|
||||||
|
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
"github.com/pomerium/pomerium/proxy/portal"
|
"github.com/pomerium/pomerium/proxy/portal"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -124,6 +125,8 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error {
|
||||||
r.StrictSlash(true)
|
r.StrictSlash(true)
|
||||||
// dashboard handlers are registered to all routes
|
// dashboard handlers are registered to all routes
|
||||||
r = p.registerDashboardHandlers(r, opts)
|
r = p.registerDashboardHandlers(r, opts)
|
||||||
|
// attach the querier to the context
|
||||||
|
r.Use(p.querierMiddleware)
|
||||||
r.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(p.tracerProvider)))
|
r.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(p.tracerProvider)))
|
||||||
|
|
||||||
p.currentRouter.Store(r)
|
p.currentRouter.Store(r)
|
||||||
|
@ -133,3 +136,16 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error {
|
||||||
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
p.currentRouter.Load().ServeHTTP(w, r)
|
p.currentRouter.Load().ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Proxy) querierMiddleware(h http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = storage.WithQuerier(ctx, storage.NewCachingQuerier(
|
||||||
|
storage.NewQuerier(p.state.Load().dataBrokerClient),
|
||||||
|
storage.GlobalCache,
|
||||||
|
))
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/authenticateflow"
|
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc"
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||||
|
@ -87,19 +88,16 @@ func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.Trace
|
||||||
|
|
||||||
state.incomingIDPTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator(
|
state.incomingIDPTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator(
|
||||||
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||||
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||||
Type: recordType,
|
|
||||||
Id: recordID,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return res.GetRecord(), nil
|
|
||||||
},
|
},
|
||||||
func(ctx context.Context, records []*databroker.Record) error {
|
func(ctx context.Context, records []*databroker.Record) error {
|
||||||
_, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
_, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||||
Records: records,
|
Records: records,
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
storage.InvalidateCacheForDataBrokerRecords(ctx, records...)
|
||||||
return err
|
return err
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue