proxy: use querier cache for user info (#5532)

This commit is contained in:
Caleb Doxsey 2025-03-20 09:50:22 -06:00 committed by GitHub
parent 08623ef346
commit bc263e3ee5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 259 additions and 156 deletions

View file

@ -9,28 +9,24 @@ import (
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/internal/handlers/webauthn"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/webauthnutil"
)
func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Session, isImpersonated bool, err error) {
client := p.state.Load().dataBrokerClient
isImpersonated = false
s, err = session.Get(ctx, client, sessionID)
s, err = storage.GetDataBrokerMessage[session.Session](ctx, sessionID, 0)
if s.GetImpersonateSessionId() != "" {
s, err = session.Get(ctx, client, s.GetImpersonateSessionId())
s, err = storage.GetDataBrokerMessage[session.Session](ctx, s.GetImpersonateSessionId(), 0)
isImpersonated = true
}
return s, isImpersonated, err
}
func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error) {
client := p.state.Load().dataBrokerClient
return user.Get(ctx, client, userID)
return storage.GetDataBrokerMessage[user.User](ctx, userID, 0)
}
func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
@ -72,21 +68,16 @@ func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
}
func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.UserInfoData) {
client := p.state.Load().dataBrokerClient
res, _ := client.Get(ctx, &databroker.GetRequest{
Type: "type.googleapis.com/pomerium.config.Config",
Id: "dashboard-settings",
})
data.IsEnterprise = res.GetRecord() != nil
record, _ := storage.GetDataBrokerRecord(ctx, "type.googleapis.com/pomerium.config.Config", "dashboard-settings", 0)
data.IsEnterprise = record != nil
if !data.IsEnterprise {
return
}
data.DirectoryUser, _ = databroker.GetViaJSON[directory.User](ctx, client, directory.UserRecordType, data.Session.GetUserId())
data.DirectoryUser, _ = storage.GetDataBrokerObjectViaJSON[directory.User](ctx, directory.UserRecordType, data.Session.GetUserId(), 0)
if data.DirectoryUser != nil {
for _, groupID := range data.DirectoryUser.GroupIDs {
directoryGroup, _ := databroker.GetViaJSON[directory.Group](ctx, client, directory.GroupRecordType, groupID)
directoryGroup, _ := storage.GetDataBrokerObjectViaJSON[directory.Group](ctx, directory.GroupRecordType, groupID, 0)
if directoryGroup != nil {
data.DirectoryGroups = append(data.DirectoryGroups, directoryGroup)
}

View file

@ -25,6 +25,7 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage"
)
func Test_getUserInfoData(t *testing.T) {
@ -65,6 +66,7 @@ func Test_getUserInfoData(t *testing.T) {
proxy, err := New(ctx, &config.Config{Options: opts})
require.NoError(t, err)
proxy.state.Load().dataBrokerClient = client
ctx = storage.WithQuerier(ctx, storage.NewQuerier(client))
require.NoError(t, databrokerpb.PutMulti(ctx, client,
makeRecord(&session.Session{
@ -81,7 +83,7 @@ func Test_getUserInfoData(t *testing.T) {
"group_ids": []any{"G1", "G2", "G3"},
})))
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/.pomerium/", nil)
r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{
ID: "S1",
}))
@ -89,7 +91,9 @@ func Test_getUserInfoData(t *testing.T) {
assert.Equal(t, "S1", data.Session.Id)
assert.Equal(t, "U1", data.User.Id)
assert.True(t, data.IsEnterprise)
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
if assert.NotNil(t, data.DirectoryUser) {
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
}
})
}

View file

@ -21,6 +21,7 @@ import (
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/proxy/portal"
)
@ -124,6 +125,8 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error {
r.StrictSlash(true)
// dashboard handlers are registered to all routes
r = p.registerDashboardHandlers(r, opts)
// attach the querier to the context
r.Use(p.querierMiddleware)
r.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(p.tracerProvider)))
p.currentRouter.Store(r)
@ -133,3 +136,16 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error {
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
p.currentRouter.Load().ServeHTTP(w, r)
}
func (p *Proxy) querierMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = storage.WithQuerier(ctx, storage.NewCachingQuerier(
storage.NewQuerier(p.state.Load().dataBrokerClient),
storage.GlobalCache,
))
r = r.WithContext(ctx)
h.ServeHTTP(w, r)
})
}

View file

@ -13,6 +13,7 @@ import (
"github.com/pomerium/pomerium/internal/authenticateflow"
"github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
@ -87,19 +88,16 @@ func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.Trace
state.incomingIDPTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator(
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
Type: recordType,
Id: recordID,
})
if err != nil {
return nil, err
}
return res.GetRecord(), nil
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
},
func(ctx context.Context, records []*databroker.Record) error {
_, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{
Records: records,
})
if err != nil {
return err
}
storage.InvalidateCacheForDataBrokerRecords(ctx, records...)
return err
},
)