From bc263e3ee58705b1bc7b3c6cab2370d719d12b45 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Thu, 20 Mar 2025 09:50:22 -0600 Subject: [PATCH] proxy: use querier cache for user info (#5532) --- authorize/authorize.go | 5 +- authorize/databroker.go | 50 +------------ authorize/databroker_test.go | 37 ---------- authorize/grpc.go | 14 +--- pkg/grpc/databroker/databroker.go | 30 -------- pkg/storage/cache.go | 3 + pkg/storage/querier.go | 112 ++++++++++++++++++++++++++++++ pkg/storage/querier_test.go | 101 +++++++++++++++++++++++++++ proxy/data.go | 25 +++---- proxy/data_test.go | 8 ++- proxy/proxy.go | 16 +++++ proxy/state.go | 14 ++-- 12 files changed, 259 insertions(+), 156 deletions(-) create mode 100644 pkg/storage/querier_test.go diff --git a/authorize/authorize.go b/authorize/authorize.go index 82ab173a3..f6ed77525 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -6,7 +6,6 @@ import ( "context" "fmt" "slices" - "time" "github.com/rs/zerolog" oteltrace "go.opentelemetry.io/otel/trace" @@ -31,7 +30,6 @@ type Authorize struct { store *store.Store currentConfig *atomicutil.Value[*config.Config] accessTracker *AccessTracker - globalCache storage.Cache groupsCacheWarmer *cacheWarmer tracerProvider oteltrace.TracerProvider @@ -45,7 +43,6 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) { a := &Authorize{ currentConfig: atomicutil.NewValue(&config.Config{Options: new(config.Options)}), store: store.New(), - globalCache: storage.NewGlobalCache(time.Minute), tracerProvider: tracerProvider, tracer: tracer, } @@ -57,7 +54,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) { } a.state = atomicutil.NewValue(state) - a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, a.globalCache, directory.GroupRecordType) + a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, directory.GroupRecordType) return a, nil } diff --git a/authorize/databroker.go b/authorize/databroker.go index 2c59e4c30..1e474e792 100644 --- a/authorize/databroker.go +++ b/authorize/databroker.go @@ -3,9 +3,6 @@ package authorize import ( "context" - "google.golang.org/grpc" - - "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/grpcutil" @@ -18,47 +15,6 @@ type sessionOrServiceAccount interface { Validate() error } -func getDataBrokerRecord( - ctx context.Context, - recordType string, - recordID string, - lowestRecordVersion uint64, -) (*databroker.Record, error) { - q := storage.GetQuerier(ctx) - - req := &databroker.QueryRequest{ - Type: recordType, - Limit: 1, - } - req.SetFilterByIDOrIndex(recordID) - - res, err := q.Query(ctx, req, grpc.WaitForReady(true)) - if err != nil { - return nil, err - } - if len(res.GetRecords()) == 0 { - return nil, storage.ErrNotFound - } - - // if the current record version is less than the lowest we'll accept, invalidate the cache - if res.GetRecords()[0].GetVersion() < lowestRecordVersion { - q.InvalidateCache(ctx, req) - } else { - return res.GetRecords()[0], nil - } - - // retry with an up to date cache - res, err = q.Query(ctx, req) - if err != nil { - return nil, err - } - if len(res.GetRecords()) == 0 { - return nil, storage.ErrNotFound - } - - return res.GetRecords()[0], nil -} - func (a *Authorize) getDataBrokerSessionOrServiceAccount( ctx context.Context, sessionID string, @@ -67,9 +23,9 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount( ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerSessionOrServiceAccount") defer span.End() - record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion) + record, err := storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion) if storage.IsNotFound(err) { - record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion) + record, err = storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion) } if err != nil { return nil, err @@ -100,7 +56,7 @@ func (a *Authorize) getDataBrokerUser( ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerUser") defer span.End() - record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0) + record, err := storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0) if err != nil { return nil, err } diff --git a/authorize/databroker_test.go b/authorize/databroker_test.go index eefb5987b..f8f47bef7 100644 --- a/authorize/databroker_test.go +++ b/authorize/databroker_test.go @@ -2,7 +2,6 @@ package authorize import ( "context" - "fmt" "testing" "time" @@ -12,45 +11,9 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/pkg/grpc/session" - "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/storage" ) -func Test_getDataBrokerRecord(t *testing.T) { - t.Parallel() - - ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) - t.Cleanup(clearTimeout) - - for _, tc := range []struct { - name string - recordVersion, queryVersion uint64 - underlyingQueryCount, cachedQueryCount int - }{ - {"cached", 1, 1, 1, 2}, - {"invalidated", 1, 2, 3, 4}, - } { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)} - - sq := storage.NewStaticQuerier(s1) - cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute)) - qctx := storage.WithQuerier(ctx, cq) - - s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) - assert.NoError(t, err) - assert.NotNil(t, s) - - s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) - assert.NoError(t, err) - assert.NotNil(t, s) - }) - } -} - func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) { t.Parallel() diff --git a/authorize/grpc.go b/authorize/grpc.go index e5a1980b5..94c172b30 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -36,7 +36,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe querier := storage.NewCachingQuerier( storage.NewQuerier(a.state.Load().dataBrokerClient), - a.globalCache, + storage.GlobalCache, ) ctx = storage.WithQuerier(ctx, querier) @@ -98,7 +98,7 @@ func (a *Authorize) loadSession( // attempt to create a session from an incoming idp token s, err = config.NewIncomingIDPTokenSessionCreator( func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) { - return getDataBrokerRecord(ctx, recordType, recordID, 0) + return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0) }, func(ctx context.Context, records []*databroker.Record) error { _, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{ @@ -107,15 +107,7 @@ func (a *Authorize) loadSession( if err != nil { return err } - // invalidate cache - for _, record := range records { - q := &databroker.QueryRequest{ - Type: record.GetType(), - Limit: 1, - } - q.SetFilterByIDOrIndex(record.GetId()) - storage.GetQuerier(ctx).InvalidateCache(ctx, q) - } + storage.InvalidateCacheForDataBrokerRecords(ctx, records...) return nil }, ).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq) diff --git a/pkg/grpc/databroker/databroker.go b/pkg/grpc/databroker/databroker.go index 26841ca7f..d3933dbfa 100644 --- a/pkg/grpc/databroker/databroker.go +++ b/pkg/grpc/databroker/databroker.go @@ -3,14 +3,12 @@ package databroker import ( "context" - "encoding/json" "errors" "fmt" "io" "google.golang.org/grpc/codes" status "google.golang.org/grpc/status" - "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" structpb "google.golang.org/protobuf/types/known/structpb" @@ -53,34 +51,6 @@ func Get(ctx context.Context, client DataBrokerServiceClient, object recordObjec return res.GetRecord().GetData().UnmarshalTo(object) } -// GetViaJSON gets a record from the databroker, marshals it to JSON, and then unmarshals it to the given type. -func GetViaJSON[T any](ctx context.Context, client DataBrokerServiceClient, recordType, recordID string) (*T, error) { - res, err := client.Get(ctx, &GetRequest{ - Type: recordType, - Id: recordID, - }) - if err != nil { - return nil, err - } - - msg, err := res.GetRecord().GetData().UnmarshalNew() - if err != nil { - return nil, err - } - - bs, err := protojson.Marshal(msg) - if err != nil { - return nil, err - } - - var obj T - err = json.Unmarshal(bs, &obj) - if err != nil { - return nil, err - } - return &obj, nil -} - // Put puts a record into the databroker. func Put(ctx context.Context, client DataBrokerServiceClient, objects ...recordObject) (*PutResponse, error) { records := make([]*Record, len(objects)) diff --git a/pkg/storage/cache.go b/pkg/storage/cache.go index 1f0e391b4..3d8d563e6 100644 --- a/pkg/storage/cache.go +++ b/pkg/storage/cache.go @@ -107,3 +107,6 @@ func (cache *globalCache) set(expiry time.Time, key, value []byte) { cache.fastcache.Set(key, item) cache.mu.Unlock() } + +// GlobalCache is a global cache with a TTL of one minute. +var GlobalCache = NewGlobalCache(time.Minute) diff --git a/pkg/storage/querier.go b/pkg/storage/querier.go index c70247e5c..ddda78769 100644 --- a/pkg/storage/querier.go +++ b/pkg/storage/querier.go @@ -15,6 +15,7 @@ import ( "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/protoutil" ) @@ -222,3 +223,114 @@ func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) { Deterministic: true, }).Marshal(res) } + +// GetDataBrokerRecord uses a querier to get a databroker record. +func GetDataBrokerRecord( + ctx context.Context, + recordType string, + recordID string, + lowestRecordVersion uint64, +) (*databroker.Record, error) { + q := GetQuerier(ctx) + + req := &databroker.QueryRequest{ + Type: recordType, + Limit: 1, + } + req.SetFilterByIDOrIndex(recordID) + + res, err := q.Query(ctx, req, grpc.WaitForReady(true)) + if err != nil { + return nil, err + } + if len(res.GetRecords()) == 0 { + return nil, ErrNotFound + } + + // if the current record version is less than the lowest we'll accept, invalidate the cache + if res.GetRecords()[0].GetVersion() < lowestRecordVersion { + q.InvalidateCache(ctx, req) + } else { + return res.GetRecords()[0], nil + } + + // retry with an up to date cache + res, err = q.Query(ctx, req) + if err != nil { + return nil, err + } + if len(res.GetRecords()) == 0 { + return nil, ErrNotFound + } + + return res.GetRecords()[0], nil +} + +// GetDataBrokerMessage gets a databroker record and converts it into the message type. +func GetDataBrokerMessage[T any, TMessage interface { + *T + proto.Message +}]( + ctx context.Context, + recordID string, + lowestRecordVersion uint64, +) (TMessage, error) { + var msg T + + record, err := GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(TMessage(&msg)), recordID, lowestRecordVersion) + if err != nil { + return nil, err + } + + err = record.GetData().UnmarshalTo(TMessage(&msg)) + if err != nil { + return nil, err + } + + return TMessage(&msg), nil +} + +// GetDataBrokerObjectViaJSON gets a databroker record and converts it into the object type by going through protojson. +func GetDataBrokerObjectViaJSON[T any]( + ctx context.Context, + recordType string, + recordID string, + lowestRecordVersion uint64, +) (*T, error) { + record, err := GetDataBrokerRecord(ctx, recordType, recordID, lowestRecordVersion) + if err != nil { + return nil, err + } + + msg, err := record.GetData().UnmarshalNew() + if err != nil { + return nil, err + } + + bs, err := protojson.Marshal(msg) + if err != nil { + return nil, err + } + + var obj T + err = json.Unmarshal(bs, &obj) + if err != nil { + return nil, err + } + return &obj, nil +} + +// InvalidateCacheForDataBrokerRecords invalidates the cache of the querier for the databroker records. +func InvalidateCacheForDataBrokerRecords( + ctx context.Context, + records ...*databroker.Record, +) { + for _, record := range records { + q := &databroker.QueryRequest{ + Type: record.GetType(), + Limit: 1, + } + q.SetFilterByIDOrIndex(record.GetId()) + GetQuerier(ctx).InvalidateCache(ctx, q) + } +} diff --git a/pkg/storage/querier_test.go b/pkg/storage/querier_test.go new file mode 100644 index 000000000..904f261cd --- /dev/null +++ b/pkg/storage/querier_test.go @@ -0,0 +1,101 @@ +package storage_test + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/pomerium/datasource/pkg/directory" + "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpcutil" + "github.com/pomerium/pomerium/pkg/storage" +) + +func TestGetDataBrokerRecord(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + t.Cleanup(clearTimeout) + + for _, tc := range []struct { + name string + recordVersion, queryVersion uint64 + underlyingQueryCount, cachedQueryCount int + }{ + {"cached", 1, 1, 1, 2}, + {"invalidated", 1, 2, 3, 4}, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)} + + sq := storage.NewStaticQuerier(s1) + cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute)) + qctx := storage.WithQuerier(ctx, cq) + + s, err := storage.GetDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) + assert.NoError(t, err) + assert.NotNil(t, s) + + s, err = storage.GetDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) + assert.NoError(t, err) + assert.NotNil(t, s) + }) + } +} + +func TestGetDataBrokerMessage(t *testing.T) { + t.Parallel() + + ctx := testutil.GetContext(t, time.Minute) + + s1 := &session.Session{Id: "s1"} + sq := storage.NewStaticQuerier(s1) + cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute)) + qctx := storage.WithQuerier(ctx, cq) + + s2, err := storage.GetDataBrokerMessage[session.Session](qctx, "s1", 0) + assert.NoError(t, err) + assert.Empty(t, cmp.Diff(s1, s2, protocmp.Transform())) + + _, err = storage.GetDataBrokerMessage[session.Session](qctx, "s2", 0) + assert.ErrorIs(t, err, storage.ErrNotFound) +} + +func TestGetDataBrokerObjectViaJSON(t *testing.T) { + t.Parallel() + + ctx := testutil.GetContext(t, time.Minute) + + du1 := &directory.User{ + ID: "u1", + Email: "u1@example.com", + DisplayName: "User 1!", + } + sq := storage.NewStaticQuerier(newDirectoryUserRecord(du1)) + cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute)) + qctx := storage.WithQuerier(ctx, cq) + + du2, err := storage.GetDataBrokerObjectViaJSON[directory.User](qctx, directory.UserRecordType, "u1", 0) + assert.NoError(t, err) + assert.Empty(t, cmp.Diff(du1, du2, protocmp.Transform())) +} + +func newDirectoryUserRecord(directoryUser *directory.User) *databroker.Record { + m := map[string]any{} + bs, _ := json.Marshal(directoryUser) + _ = json.Unmarshal(bs, &m) + s, _ := structpb.NewStruct(m) + return storage.NewStaticRecord(directory.UserRecordType, s) +} diff --git a/proxy/data.go b/proxy/data.go index 3a63c3d3d..3fad0f412 100644 --- a/proxy/data.go +++ b/proxy/data.go @@ -9,28 +9,24 @@ import ( "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/handlers/webauthn" "github.com/pomerium/pomerium/internal/urlutil" - "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/webauthnutil" ) func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Session, isImpersonated bool, err error) { - client := p.state.Load().dataBrokerClient - isImpersonated = false - s, err = session.Get(ctx, client, sessionID) + s, err = storage.GetDataBrokerMessage[session.Session](ctx, sessionID, 0) if s.GetImpersonateSessionId() != "" { - s, err = session.Get(ctx, client, s.GetImpersonateSessionId()) + s, err = storage.GetDataBrokerMessage[session.Session](ctx, s.GetImpersonateSessionId(), 0) isImpersonated = true } - return s, isImpersonated, err } func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error) { - client := p.state.Load().dataBrokerClient - return user.Get(ctx, client, userID) + return storage.GetDataBrokerMessage[user.User](ctx, userID, 0) } func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData { @@ -72,21 +68,16 @@ func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData { } func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.UserInfoData) { - client := p.state.Load().dataBrokerClient - - res, _ := client.Get(ctx, &databroker.GetRequest{ - Type: "type.googleapis.com/pomerium.config.Config", - Id: "dashboard-settings", - }) - data.IsEnterprise = res.GetRecord() != nil + record, _ := storage.GetDataBrokerRecord(ctx, "type.googleapis.com/pomerium.config.Config", "dashboard-settings", 0) + data.IsEnterprise = record != nil if !data.IsEnterprise { return } - data.DirectoryUser, _ = databroker.GetViaJSON[directory.User](ctx, client, directory.UserRecordType, data.Session.GetUserId()) + data.DirectoryUser, _ = storage.GetDataBrokerObjectViaJSON[directory.User](ctx, directory.UserRecordType, data.Session.GetUserId(), 0) if data.DirectoryUser != nil { for _, groupID := range data.DirectoryUser.GroupIDs { - directoryGroup, _ := databroker.GetViaJSON[directory.Group](ctx, client, directory.GroupRecordType, groupID) + directoryGroup, _ := storage.GetDataBrokerObjectViaJSON[directory.Group](ctx, directory.GroupRecordType, groupID, 0) if directoryGroup != nil { data.DirectoryGroups = append(data.DirectoryGroups, directoryGroup) } diff --git a/proxy/data_test.go b/proxy/data_test.go index 8f9042c12..2ff111ab0 100644 --- a/proxy/data_test.go +++ b/proxy/data_test.go @@ -25,6 +25,7 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/protoutil" + "github.com/pomerium/pomerium/pkg/storage" ) func Test_getUserInfoData(t *testing.T) { @@ -65,6 +66,7 @@ func Test_getUserInfoData(t *testing.T) { proxy, err := New(ctx, &config.Config{Options: opts}) require.NoError(t, err) proxy.state.Load().dataBrokerClient = client + ctx = storage.WithQuerier(ctx, storage.NewQuerier(client)) require.NoError(t, databrokerpb.PutMulti(ctx, client, makeRecord(&session.Session{ @@ -81,7 +83,7 @@ func Test_getUserInfoData(t *testing.T) { "group_ids": []any{"G1", "G2", "G3"}, }))) - r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) + r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/.pomerium/", nil) r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{ ID: "S1", })) @@ -89,7 +91,9 @@ func Test_getUserInfoData(t *testing.T) { assert.Equal(t, "S1", data.Session.Id) assert.Equal(t, "U1", data.User.Id) assert.True(t, data.IsEnterprise) - assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs) + if assert.NotNil(t, data.DirectoryUser) { + assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs) + } }) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 3926d3589..d48977e25 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -21,6 +21,7 @@ import ( "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/proxy/portal" ) @@ -124,6 +125,8 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error { r.StrictSlash(true) // dashboard handlers are registered to all routes r = p.registerDashboardHandlers(r, opts) + // attach the querier to the context + r.Use(p.querierMiddleware) r.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(p.tracerProvider))) p.currentRouter.Store(r) @@ -133,3 +136,16 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error { func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { p.currentRouter.Load().ServeHTTP(w, r) } + +func (p *Proxy) querierMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx = storage.WithQuerier(ctx, storage.NewCachingQuerier( + storage.NewQuerier(p.state.Load().dataBrokerClient), + storage.GlobalCache, + )) + r = r.WithContext(ctx) + + h.ServeHTTP(w, r) + }) +} diff --git a/proxy/state.go b/proxy/state.go index dab2563d4..110459293 100644 --- a/proxy/state.go +++ b/proxy/state.go @@ -13,6 +13,7 @@ import ( "github.com/pomerium/pomerium/internal/authenticateflow" "github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/storage" ) var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) @@ -87,19 +88,16 @@ func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.Trace state.incomingIDPTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator( func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) { - res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{ - Type: recordType, - Id: recordID, - }) - if err != nil { - return nil, err - } - return res.GetRecord(), nil + return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0) }, func(ctx context.Context, records []*databroker.Record) error { _, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{ Records: records, }) + if err != nil { + return err + } + storage.InvalidateCacheForDataBrokerRecords(ctx, records...) return err }, )