mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-03 03:12:50 +02:00
proxy: use querier cache for user info (#5532)
This commit is contained in:
parent
08623ef346
commit
bc263e3ee5
12 changed files with 259 additions and 156 deletions
|
@ -9,28 +9,24 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
"github.com/pomerium/pomerium/internal/handlers/webauthn"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/pomerium/pomerium/pkg/webauthnutil"
|
||||
)
|
||||
|
||||
func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Session, isImpersonated bool, err error) {
|
||||
client := p.state.Load().dataBrokerClient
|
||||
|
||||
isImpersonated = false
|
||||
s, err = session.Get(ctx, client, sessionID)
|
||||
s, err = storage.GetDataBrokerMessage[session.Session](ctx, sessionID, 0)
|
||||
if s.GetImpersonateSessionId() != "" {
|
||||
s, err = session.Get(ctx, client, s.GetImpersonateSessionId())
|
||||
s, err = storage.GetDataBrokerMessage[session.Session](ctx, s.GetImpersonateSessionId(), 0)
|
||||
isImpersonated = true
|
||||
}
|
||||
|
||||
return s, isImpersonated, err
|
||||
}
|
||||
|
||||
func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error) {
|
||||
client := p.state.Load().dataBrokerClient
|
||||
return user.Get(ctx, client, userID)
|
||||
return storage.GetDataBrokerMessage[user.User](ctx, userID, 0)
|
||||
}
|
||||
|
||||
func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
||||
|
@ -72,21 +68,16 @@ func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
|||
}
|
||||
|
||||
func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.UserInfoData) {
|
||||
client := p.state.Load().dataBrokerClient
|
||||
|
||||
res, _ := client.Get(ctx, &databroker.GetRequest{
|
||||
Type: "type.googleapis.com/pomerium.config.Config",
|
||||
Id: "dashboard-settings",
|
||||
})
|
||||
data.IsEnterprise = res.GetRecord() != nil
|
||||
record, _ := storage.GetDataBrokerRecord(ctx, "type.googleapis.com/pomerium.config.Config", "dashboard-settings", 0)
|
||||
data.IsEnterprise = record != nil
|
||||
if !data.IsEnterprise {
|
||||
return
|
||||
}
|
||||
|
||||
data.DirectoryUser, _ = databroker.GetViaJSON[directory.User](ctx, client, directory.UserRecordType, data.Session.GetUserId())
|
||||
data.DirectoryUser, _ = storage.GetDataBrokerObjectViaJSON[directory.User](ctx, directory.UserRecordType, data.Session.GetUserId(), 0)
|
||||
if data.DirectoryUser != nil {
|
||||
for _, groupID := range data.DirectoryUser.GroupIDs {
|
||||
directoryGroup, _ := databroker.GetViaJSON[directory.Group](ctx, client, directory.GroupRecordType, groupID)
|
||||
directoryGroup, _ := storage.GetDataBrokerObjectViaJSON[directory.Group](ctx, directory.GroupRecordType, groupID, 0)
|
||||
if directoryGroup != nil {
|
||||
data.DirectoryGroups = append(data.DirectoryGroups, directoryGroup)
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
func Test_getUserInfoData(t *testing.T) {
|
||||
|
@ -65,6 +66,7 @@ func Test_getUserInfoData(t *testing.T) {
|
|||
proxy, err := New(ctx, &config.Config{Options: opts})
|
||||
require.NoError(t, err)
|
||||
proxy.state.Load().dataBrokerClient = client
|
||||
ctx = storage.WithQuerier(ctx, storage.NewQuerier(client))
|
||||
|
||||
require.NoError(t, databrokerpb.PutMulti(ctx, client,
|
||||
makeRecord(&session.Session{
|
||||
|
@ -81,7 +83,7 @@ func Test_getUserInfoData(t *testing.T) {
|
|||
"group_ids": []any{"G1", "G2", "G3"},
|
||||
})))
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
|
||||
r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/.pomerium/", nil)
|
||||
r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{
|
||||
ID: "S1",
|
||||
}))
|
||||
|
@ -89,7 +91,9 @@ func Test_getUserInfoData(t *testing.T) {
|
|||
assert.Equal(t, "S1", data.Session.Id)
|
||||
assert.Equal(t, "U1", data.User.Id)
|
||||
assert.True(t, data.IsEnterprise)
|
||||
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
|
||||
if assert.NotNil(t, data.DirectoryUser) {
|
||||
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/pomerium/pomerium/proxy/portal"
|
||||
)
|
||||
|
||||
|
@ -124,6 +125,8 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error {
|
|||
r.StrictSlash(true)
|
||||
// dashboard handlers are registered to all routes
|
||||
r = p.registerDashboardHandlers(r, opts)
|
||||
// attach the querier to the context
|
||||
r.Use(p.querierMiddleware)
|
||||
r.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(p.tracerProvider)))
|
||||
|
||||
p.currentRouter.Store(r)
|
||||
|
@ -133,3 +136,16 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error {
|
|||
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
p.currentRouter.Load().ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (p *Proxy) querierMiddleware(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx = storage.WithQuerier(ctx, storage.NewCachingQuerier(
|
||||
storage.NewQuerier(p.state.Load().dataBrokerClient),
|
||||
storage.GlobalCache,
|
||||
))
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||
"github.com/pomerium/pomerium/pkg/grpc"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||
|
@ -87,19 +88,16 @@ func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.Trace
|
|||
|
||||
state.incomingIDPTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator(
|
||||
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
||||
Type: recordType,
|
||||
Id: recordID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.GetRecord(), nil
|
||||
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||
},
|
||||
func(ctx context.Context, records []*databroker.Record) error {
|
||||
_, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
Records: records,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
storage.InvalidateCacheForDataBrokerRecords(ctx, records...)
|
||||
return err
|
||||
},
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue