mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 09:56:31 +02:00
proxy: use querier cache for user info (#5532)
This commit is contained in:
parent
08623ef346
commit
bc263e3ee5
12 changed files with 259 additions and 156 deletions
|
@ -6,7 +6,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
|
@ -31,7 +30,6 @@ type Authorize struct {
|
|||
store *store.Store
|
||||
currentConfig *atomicutil.Value[*config.Config]
|
||||
accessTracker *AccessTracker
|
||||
globalCache storage.Cache
|
||||
groupsCacheWarmer *cacheWarmer
|
||||
|
||||
tracerProvider oteltrace.TracerProvider
|
||||
|
@ -45,7 +43,6 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
|||
a := &Authorize{
|
||||
currentConfig: atomicutil.NewValue(&config.Config{Options: new(config.Options)}),
|
||||
store: store.New(),
|
||||
globalCache: storage.NewGlobalCache(time.Minute),
|
||||
tracerProvider: tracerProvider,
|
||||
tracer: tracer,
|
||||
}
|
||||
|
@ -57,7 +54,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
|||
}
|
||||
a.state = atomicutil.NewValue(state)
|
||||
|
||||
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, a.globalCache, directory.GroupRecordType)
|
||||
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, directory.GroupRecordType)
|
||||
return a, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -3,9 +3,6 @@ package authorize
|
|||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
|
@ -18,47 +15,6 @@ type sessionOrServiceAccount interface {
|
|||
Validate() error
|
||||
}
|
||||
|
||||
func getDataBrokerRecord(
|
||||
ctx context.Context,
|
||||
recordType string,
|
||||
recordID string,
|
||||
lowestRecordVersion uint64,
|
||||
) (*databroker.Record, error) {
|
||||
q := storage.GetQuerier(ctx)
|
||||
|
||||
req := &databroker.QueryRequest{
|
||||
Type: recordType,
|
||||
Limit: 1,
|
||||
}
|
||||
req.SetFilterByIDOrIndex(recordID)
|
||||
|
||||
res, err := q.Query(ctx, req, grpc.WaitForReady(true))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res.GetRecords()) == 0 {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
// if the current record version is less than the lowest we'll accept, invalidate the cache
|
||||
if res.GetRecords()[0].GetVersion() < lowestRecordVersion {
|
||||
q.InvalidateCache(ctx, req)
|
||||
} else {
|
||||
return res.GetRecords()[0], nil
|
||||
}
|
||||
|
||||
// retry with an up to date cache
|
||||
res, err = q.Query(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res.GetRecords()) == 0 {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return res.GetRecords()[0], nil
|
||||
}
|
||||
|
||||
func (a *Authorize) getDataBrokerSessionOrServiceAccount(
|
||||
ctx context.Context,
|
||||
sessionID string,
|
||||
|
@ -67,9 +23,9 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount(
|
|||
ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
|
||||
defer span.End()
|
||||
|
||||
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion)
|
||||
record, err := storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion)
|
||||
if storage.IsNotFound(err) {
|
||||
record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion)
|
||||
record, err = storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -100,7 +56,7 @@ func (a *Authorize) getDataBrokerUser(
|
|||
ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerUser")
|
||||
defer span.End()
|
||||
|
||||
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0)
|
||||
record, err := storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package authorize
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -12,45 +11,9 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
func Test_getDataBrokerRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
t.Cleanup(clearTimeout)
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
recordVersion, queryVersion uint64
|
||||
underlyingQueryCount, cachedQueryCount int
|
||||
}{
|
||||
{"cached", 1, 1, 1, 2},
|
||||
{"invalidated", 1, 2, 3, 4},
|
||||
} {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)}
|
||||
|
||||
sq := storage.NewStaticQuerier(s1)
|
||||
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
||||
qctx := storage.WithQuerier(ctx, cq)
|
||||
|
||||
s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
|
||||
s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
|
||||
querier := storage.NewCachingQuerier(
|
||||
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
||||
a.globalCache,
|
||||
storage.GlobalCache,
|
||||
)
|
||||
ctx = storage.WithQuerier(ctx, querier)
|
||||
|
||||
|
@ -98,7 +98,7 @@ func (a *Authorize) loadSession(
|
|||
// attempt to create a session from an incoming idp token
|
||||
s, err = config.NewIncomingIDPTokenSessionCreator(
|
||||
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||
return getDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||
},
|
||||
func(ctx context.Context, records []*databroker.Record) error {
|
||||
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
|
@ -107,15 +107,7 @@ func (a *Authorize) loadSession(
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// invalidate cache
|
||||
for _, record := range records {
|
||||
q := &databroker.QueryRequest{
|
||||
Type: record.GetType(),
|
||||
Limit: 1,
|
||||
}
|
||||
q.SetFilterByIDOrIndex(record.GetId())
|
||||
storage.GetQuerier(ctx).InvalidateCache(ctx, q)
|
||||
}
|
||||
storage.InvalidateCacheForDataBrokerRecords(ctx, records...)
|
||||
return nil
|
||||
},
|
||||
).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq)
|
||||
|
|
|
@ -3,14 +3,12 @@ package databroker
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
structpb "google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
|
@ -53,34 +51,6 @@ func Get(ctx context.Context, client DataBrokerServiceClient, object recordObjec
|
|||
return res.GetRecord().GetData().UnmarshalTo(object)
|
||||
}
|
||||
|
||||
// GetViaJSON gets a record from the databroker, marshals it to JSON, and then unmarshals it to the given type.
|
||||
func GetViaJSON[T any](ctx context.Context, client DataBrokerServiceClient, recordType, recordID string) (*T, error) {
|
||||
res, err := client.Get(ctx, &GetRequest{
|
||||
Type: recordType,
|
||||
Id: recordID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := res.GetRecord().GetData().UnmarshalNew()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bs, err := protojson.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var obj T
|
||||
err = json.Unmarshal(bs, &obj)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &obj, nil
|
||||
}
|
||||
|
||||
// Put puts a record into the databroker.
|
||||
func Put(ctx context.Context, client DataBrokerServiceClient, objects ...recordObject) (*PutResponse, error) {
|
||||
records := make([]*Record, len(objects))
|
||||
|
|
|
@ -107,3 +107,6 @@ func (cache *globalCache) set(expiry time.Time, key, value []byte) {
|
|||
cache.fastcache.Set(key, item)
|
||||
cache.mu.Unlock()
|
||||
}
|
||||
|
||||
// GlobalCache is a global cache with a TTL of one minute.
|
||||
var GlobalCache = NewGlobalCache(time.Minute)
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
|
@ -222,3 +223,114 @@ func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) {
|
|||
Deterministic: true,
|
||||
}).Marshal(res)
|
||||
}
|
||||
|
||||
// GetDataBrokerRecord uses a querier to get a databroker record.
|
||||
func GetDataBrokerRecord(
|
||||
ctx context.Context,
|
||||
recordType string,
|
||||
recordID string,
|
||||
lowestRecordVersion uint64,
|
||||
) (*databroker.Record, error) {
|
||||
q := GetQuerier(ctx)
|
||||
|
||||
req := &databroker.QueryRequest{
|
||||
Type: recordType,
|
||||
Limit: 1,
|
||||
}
|
||||
req.SetFilterByIDOrIndex(recordID)
|
||||
|
||||
res, err := q.Query(ctx, req, grpc.WaitForReady(true))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res.GetRecords()) == 0 {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
// if the current record version is less than the lowest we'll accept, invalidate the cache
|
||||
if res.GetRecords()[0].GetVersion() < lowestRecordVersion {
|
||||
q.InvalidateCache(ctx, req)
|
||||
} else {
|
||||
return res.GetRecords()[0], nil
|
||||
}
|
||||
|
||||
// retry with an up to date cache
|
||||
res, err = q.Query(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res.GetRecords()) == 0 {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return res.GetRecords()[0], nil
|
||||
}
|
||||
|
||||
// GetDataBrokerMessage gets a databroker record and converts it into the message type.
|
||||
func GetDataBrokerMessage[T any, TMessage interface {
|
||||
*T
|
||||
proto.Message
|
||||
}](
|
||||
ctx context.Context,
|
||||
recordID string,
|
||||
lowestRecordVersion uint64,
|
||||
) (TMessage, error) {
|
||||
var msg T
|
||||
|
||||
record, err := GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(TMessage(&msg)), recordID, lowestRecordVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = record.GetData().UnmarshalTo(TMessage(&msg))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return TMessage(&msg), nil
|
||||
}
|
||||
|
||||
// GetDataBrokerObjectViaJSON gets a databroker record and converts it into the object type by going through protojson.
|
||||
func GetDataBrokerObjectViaJSON[T any](
|
||||
ctx context.Context,
|
||||
recordType string,
|
||||
recordID string,
|
||||
lowestRecordVersion uint64,
|
||||
) (*T, error) {
|
||||
record, err := GetDataBrokerRecord(ctx, recordType, recordID, lowestRecordVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := record.GetData().UnmarshalNew()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bs, err := protojson.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var obj T
|
||||
err = json.Unmarshal(bs, &obj)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &obj, nil
|
||||
}
|
||||
|
||||
// InvalidateCacheForDataBrokerRecords invalidates the cache of the querier for the databroker records.
|
||||
func InvalidateCacheForDataBrokerRecords(
|
||||
ctx context.Context,
|
||||
records ...*databroker.Record,
|
||||
) {
|
||||
for _, record := range records {
|
||||
q := &databroker.QueryRequest{
|
||||
Type: record.GetType(),
|
||||
Limit: 1,
|
||||
}
|
||||
q.SetFilterByIDOrIndex(record.GetId())
|
||||
GetQuerier(ctx).InvalidateCache(ctx, q)
|
||||
}
|
||||
}
|
||||
|
|
101
pkg/storage/querier_test.go
Normal file
101
pkg/storage/querier_test.go
Normal file
|
@ -0,0 +1,101 @@
|
|||
package storage_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/testing/protocmp"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"github.com/pomerium/datasource/pkg/directory"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
func TestGetDataBrokerRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
t.Cleanup(clearTimeout)
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
recordVersion, queryVersion uint64
|
||||
underlyingQueryCount, cachedQueryCount int
|
||||
}{
|
||||
{"cached", 1, 1, 1, 2},
|
||||
{"invalidated", 1, 2, 3, 4},
|
||||
} {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)}
|
||||
|
||||
sq := storage.NewStaticQuerier(s1)
|
||||
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
||||
qctx := storage.WithQuerier(ctx, cq)
|
||||
|
||||
s, err := storage.GetDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
|
||||
s, err = storage.GetDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDataBrokerMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.GetContext(t, time.Minute)
|
||||
|
||||
s1 := &session.Session{Id: "s1"}
|
||||
sq := storage.NewStaticQuerier(s1)
|
||||
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
||||
qctx := storage.WithQuerier(ctx, cq)
|
||||
|
||||
s2, err := storage.GetDataBrokerMessage[session.Session](qctx, "s1", 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, cmp.Diff(s1, s2, protocmp.Transform()))
|
||||
|
||||
_, err = storage.GetDataBrokerMessage[session.Session](qctx, "s2", 0)
|
||||
assert.ErrorIs(t, err, storage.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestGetDataBrokerObjectViaJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.GetContext(t, time.Minute)
|
||||
|
||||
du1 := &directory.User{
|
||||
ID: "u1",
|
||||
Email: "u1@example.com",
|
||||
DisplayName: "User 1!",
|
||||
}
|
||||
sq := storage.NewStaticQuerier(newDirectoryUserRecord(du1))
|
||||
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
|
||||
qctx := storage.WithQuerier(ctx, cq)
|
||||
|
||||
du2, err := storage.GetDataBrokerObjectViaJSON[directory.User](qctx, directory.UserRecordType, "u1", 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, cmp.Diff(du1, du2, protocmp.Transform()))
|
||||
}
|
||||
|
||||
func newDirectoryUserRecord(directoryUser *directory.User) *databroker.Record {
|
||||
m := map[string]any{}
|
||||
bs, _ := json.Marshal(directoryUser)
|
||||
_ = json.Unmarshal(bs, &m)
|
||||
s, _ := structpb.NewStruct(m)
|
||||
return storage.NewStaticRecord(directory.UserRecordType, s)
|
||||
}
|
|
@ -9,28 +9,24 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
"github.com/pomerium/pomerium/internal/handlers/webauthn"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/pomerium/pomerium/pkg/webauthnutil"
|
||||
)
|
||||
|
||||
func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Session, isImpersonated bool, err error) {
|
||||
client := p.state.Load().dataBrokerClient
|
||||
|
||||
isImpersonated = false
|
||||
s, err = session.Get(ctx, client, sessionID)
|
||||
s, err = storage.GetDataBrokerMessage[session.Session](ctx, sessionID, 0)
|
||||
if s.GetImpersonateSessionId() != "" {
|
||||
s, err = session.Get(ctx, client, s.GetImpersonateSessionId())
|
||||
s, err = storage.GetDataBrokerMessage[session.Session](ctx, s.GetImpersonateSessionId(), 0)
|
||||
isImpersonated = true
|
||||
}
|
||||
|
||||
return s, isImpersonated, err
|
||||
}
|
||||
|
||||
func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error) {
|
||||
client := p.state.Load().dataBrokerClient
|
||||
return user.Get(ctx, client, userID)
|
||||
return storage.GetDataBrokerMessage[user.User](ctx, userID, 0)
|
||||
}
|
||||
|
||||
func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
||||
|
@ -72,21 +68,16 @@ func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
|||
}
|
||||
|
||||
func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.UserInfoData) {
|
||||
client := p.state.Load().dataBrokerClient
|
||||
|
||||
res, _ := client.Get(ctx, &databroker.GetRequest{
|
||||
Type: "type.googleapis.com/pomerium.config.Config",
|
||||
Id: "dashboard-settings",
|
||||
})
|
||||
data.IsEnterprise = res.GetRecord() != nil
|
||||
record, _ := storage.GetDataBrokerRecord(ctx, "type.googleapis.com/pomerium.config.Config", "dashboard-settings", 0)
|
||||
data.IsEnterprise = record != nil
|
||||
if !data.IsEnterprise {
|
||||
return
|
||||
}
|
||||
|
||||
data.DirectoryUser, _ = databroker.GetViaJSON[directory.User](ctx, client, directory.UserRecordType, data.Session.GetUserId())
|
||||
data.DirectoryUser, _ = storage.GetDataBrokerObjectViaJSON[directory.User](ctx, directory.UserRecordType, data.Session.GetUserId(), 0)
|
||||
if data.DirectoryUser != nil {
|
||||
for _, groupID := range data.DirectoryUser.GroupIDs {
|
||||
directoryGroup, _ := databroker.GetViaJSON[directory.Group](ctx, client, directory.GroupRecordType, groupID)
|
||||
directoryGroup, _ := storage.GetDataBrokerObjectViaJSON[directory.Group](ctx, directory.GroupRecordType, groupID, 0)
|
||||
if directoryGroup != nil {
|
||||
data.DirectoryGroups = append(data.DirectoryGroups, directoryGroup)
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
func Test_getUserInfoData(t *testing.T) {
|
||||
|
@ -65,6 +66,7 @@ func Test_getUserInfoData(t *testing.T) {
|
|||
proxy, err := New(ctx, &config.Config{Options: opts})
|
||||
require.NoError(t, err)
|
||||
proxy.state.Load().dataBrokerClient = client
|
||||
ctx = storage.WithQuerier(ctx, storage.NewQuerier(client))
|
||||
|
||||
require.NoError(t, databrokerpb.PutMulti(ctx, client,
|
||||
makeRecord(&session.Session{
|
||||
|
@ -81,7 +83,7 @@ func Test_getUserInfoData(t *testing.T) {
|
|||
"group_ids": []any{"G1", "G2", "G3"},
|
||||
})))
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
|
||||
r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/.pomerium/", nil)
|
||||
r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{
|
||||
ID: "S1",
|
||||
}))
|
||||
|
@ -89,7 +91,9 @@ func Test_getUserInfoData(t *testing.T) {
|
|||
assert.Equal(t, "S1", data.Session.Id)
|
||||
assert.Equal(t, "U1", data.User.Id)
|
||||
assert.True(t, data.IsEnterprise)
|
||||
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
|
||||
if assert.NotNil(t, data.DirectoryUser) {
|
||||
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/pomerium/pomerium/proxy/portal"
|
||||
)
|
||||
|
||||
|
@ -124,6 +125,8 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error {
|
|||
r.StrictSlash(true)
|
||||
// dashboard handlers are registered to all routes
|
||||
r = p.registerDashboardHandlers(r, opts)
|
||||
// attach the querier to the context
|
||||
r.Use(p.querierMiddleware)
|
||||
r.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(p.tracerProvider)))
|
||||
|
||||
p.currentRouter.Store(r)
|
||||
|
@ -133,3 +136,16 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error {
|
|||
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
p.currentRouter.Load().ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (p *Proxy) querierMiddleware(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx = storage.WithQuerier(ctx, storage.NewCachingQuerier(
|
||||
storage.NewQuerier(p.state.Load().dataBrokerClient),
|
||||
storage.GlobalCache,
|
||||
))
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||
"github.com/pomerium/pomerium/pkg/grpc"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||
|
@ -87,19 +88,16 @@ func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.Trace
|
|||
|
||||
state.incomingIDPTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator(
|
||||
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
||||
Type: recordType,
|
||||
Id: recordID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.GetRecord(), nil
|
||||
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||
},
|
||||
func(ctx context.Context, records []*databroker.Record) error {
|
||||
_, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
Records: records,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
storage.InvalidateCacheForDataBrokerRecords(ctx, records...)
|
||||
return err
|
||||
},
|
||||
)
|
||||
|
|
Loading…
Add table
Reference in a new issue