proxy: use querier cache for user info (#5532)

This commit is contained in:
Caleb Doxsey 2025-03-20 09:50:22 -06:00 committed by GitHub
parent 08623ef346
commit bc263e3ee5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 259 additions and 156 deletions

View file

@ -6,7 +6,6 @@ import (
"context"
"fmt"
"slices"
"time"
"github.com/rs/zerolog"
oteltrace "go.opentelemetry.io/otel/trace"
@ -31,7 +30,6 @@ type Authorize struct {
store *store.Store
currentConfig *atomicutil.Value[*config.Config]
accessTracker *AccessTracker
globalCache storage.Cache
groupsCacheWarmer *cacheWarmer
tracerProvider oteltrace.TracerProvider
@ -45,7 +43,6 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
a := &Authorize{
currentConfig: atomicutil.NewValue(&config.Config{Options: new(config.Options)}),
store: store.New(),
globalCache: storage.NewGlobalCache(time.Minute),
tracerProvider: tracerProvider,
tracer: tracer,
}
@ -57,7 +54,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
}
a.state = atomicutil.NewValue(state)
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, a.globalCache, directory.GroupRecordType)
a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, directory.GroupRecordType)
return a, nil
}

View file

@ -3,9 +3,6 @@ package authorize
import (
"context"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
@ -18,47 +15,6 @@ type sessionOrServiceAccount interface {
Validate() error
}
// getDataBrokerRecord fetches a single databroker record of recordType using
// the querier stored in ctx.
//
// The record is matched with SetFilterByIDOrIndex(recordID). If the first
// result has a version lower than lowestRecordVersion, the querier's cache is
// invalidated for the request and the query is issued once more; the result of
// that second query is returned without another version check.
// storage.ErrNotFound is returned when a query yields no records.
func getDataBrokerRecord(
	ctx context.Context,
	recordType string,
	recordID string,
	lowestRecordVersion uint64,
) (*databroker.Record, error) {
	querier := storage.GetQuerier(ctx)

	query := &databroker.QueryRequest{
		Type:  recordType,
		Limit: 1,
	}
	query.SetFilterByIDOrIndex(recordID)

	response, err := querier.Query(ctx, query, grpc.WaitForReady(true))
	if err != nil {
		return nil, err
	}
	records := response.GetRecords()
	if len(records) == 0 {
		return nil, storage.ErrNotFound
	}

	// the record is at least as new as the caller requires: no refresh needed
	if records[0].GetVersion() >= lowestRecordVersion {
		return records[0], nil
	}

	// the record is older than the lowest version we'll accept, so invalidate
	// the cache and retry with an up to date cache
	querier.InvalidateCache(ctx, query)
	response, err = querier.Query(ctx, query)
	if err != nil {
		return nil, err
	}
	records = response.GetRecords()
	if len(records) == 0 {
		return nil, storage.ErrNotFound
	}
	return records[0], nil
}
func (a *Authorize) getDataBrokerSessionOrServiceAccount(
ctx context.Context,
sessionID string,
@ -67,9 +23,9 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount(
ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerSessionOrServiceAccount")
defer span.End()
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion)
record, err := storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion)
if storage.IsNotFound(err) {
record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion)
record, err = storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion)
}
if err != nil {
return nil, err
@ -100,7 +56,7 @@ func (a *Authorize) getDataBrokerUser(
ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerUser")
defer span.End()
record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0)
record, err := storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0)
if err != nil {
return nil, err
}

View file

@ -2,7 +2,6 @@ package authorize
import (
"context"
"fmt"
"testing"
"time"
@ -12,45 +11,9 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/storage"
)
// Test_getDataBrokerRecord calls getDataBrokerRecord twice through a caching
// querier layered over a static querier and checks that both calls succeed and
// return a record, for a query version equal to and greater than the stored
// record version.
//
// NOTE(review): underlyingQueryCount and cachedQueryCount are declared in the
// table but never asserted below — confirm whether an assertion is missing.
func Test_getDataBrokerRecord(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	t.Cleanup(cancel)

	type testCase struct {
		name                                   string
		recordVersion, queryVersion            uint64
		underlyingQueryCount, cachedQueryCount int
	}
	testCases := []testCase{
		{"cached", 1, 1, 1, 2},
		{"invalidated", 1, 2, 3, 4},
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			stored := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)}
			querier := storage.NewCachingQuerier(
				storage.NewStaticQuerier(stored),
				storage.NewGlobalCache(time.Minute),
			)
			qctx := storage.WithQuerier(ctx, querier)

			// look the record up twice through the same caching querier
			for i := 0; i < 2; i++ {
				record, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(stored), stored.GetId(), tc.queryVersion)
				assert.NoError(t, err)
				assert.NotNil(t, record)
			}
		})
	}
}
func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) {
t.Parallel()

View file

@ -36,7 +36,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
querier := storage.NewCachingQuerier(
storage.NewQuerier(a.state.Load().dataBrokerClient),
a.globalCache,
storage.GlobalCache,
)
ctx = storage.WithQuerier(ctx, querier)
@ -98,7 +98,7 @@ func (a *Authorize) loadSession(
// attempt to create a session from an incoming idp token
s, err = config.NewIncomingIDPTokenSessionCreator(
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
return getDataBrokerRecord(ctx, recordType, recordID, 0)
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
},
func(ctx context.Context, records []*databroker.Record) error {
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
@ -107,15 +107,7 @@ func (a *Authorize) loadSession(
if err != nil {
return err
}
// invalidate cache
for _, record := range records {
q := &databroker.QueryRequest{
Type: record.GetType(),
Limit: 1,
}
q.SetFilterByIDOrIndex(record.GetId())
storage.GetQuerier(ctx).InvalidateCache(ctx, q)
}
storage.InvalidateCacheForDataBrokerRecords(ctx, records...)
return nil
},
).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq)

View file

@ -3,14 +3,12 @@ package databroker
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
structpb "google.golang.org/protobuf/types/known/structpb"
@ -53,34 +51,6 @@ func Get(ctx context.Context, client DataBrokerServiceClient, object recordObjec
return res.GetRecord().GetData().UnmarshalTo(object)
}
// GetViaJSON fetches the record identified by recordType and recordID from the
// databroker, unmarshals its data into a new protobuf message, encodes that
// message with protojson, and decodes the resulting JSON into a value of type T.
// Any error from the Get call or from a conversion step is returned as-is.
func GetViaJSON[T any](ctx context.Context, client DataBrokerServiceClient, recordType, recordID string) (*T, error) {
	request := &GetRequest{
		Type: recordType,
		Id:   recordID,
	}
	response, err := client.Get(ctx, request)
	if err != nil {
		return nil, err
	}

	message, err := response.GetRecord().GetData().UnmarshalNew()
	if err != nil {
		return nil, err
	}

	encoded, err := protojson.Marshal(message)
	if err != nil {
		return nil, err
	}

	result := new(T)
	if err := json.Unmarshal(encoded, result); err != nil {
		return nil, err
	}
	return result, nil
}
// Put puts a record into the databroker.
func Put(ctx context.Context, client DataBrokerServiceClient, objects ...recordObject) (*PutResponse, error) {
records := make([]*Record, len(objects))

View file

@ -107,3 +107,6 @@ func (cache *globalCache) set(expiry time.Time, key, value []byte) {
cache.fastcache.Set(key, item)
cache.mu.Unlock()
}
// GlobalCache is a global cache with a TTL of one minute.
//
// It is a package-level value created once via NewGlobalCache(time.Minute), so
// every caller that passes GlobalCache to a caching querier shares the same
// cache instance.
var GlobalCache = NewGlobalCache(time.Minute)

View file

@ -15,6 +15,7 @@ import (
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/protoutil"
)
@ -222,3 +223,114 @@ func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) {
Deterministic: true,
}).Marshal(res)
}
// GetDataBrokerRecord uses the querier stored in ctx to fetch a single
// databroker record of recordType.
//
// The record is matched with SetFilterByIDOrIndex(recordID). If the first
// result has a version lower than lowestRecordVersion, the querier's cache is
// invalidated for the request and the query is issued once more; the result of
// that second query is returned without another version check.
// ErrNotFound is returned when a query yields no records.
func GetDataBrokerRecord(
	ctx context.Context,
	recordType string,
	recordID string,
	lowestRecordVersion uint64,
) (*databroker.Record, error) {
	querier := GetQuerier(ctx)

	query := &databroker.QueryRequest{
		Type:  recordType,
		Limit: 1,
	}
	query.SetFilterByIDOrIndex(recordID)

	response, err := querier.Query(ctx, query, grpc.WaitForReady(true))
	if err != nil {
		return nil, err
	}
	records := response.GetRecords()
	if len(records) == 0 {
		return nil, ErrNotFound
	}

	// the record is at least as new as the caller requires: no refresh needed
	if records[0].GetVersion() >= lowestRecordVersion {
		return records[0], nil
	}

	// the record is older than the lowest version we'll accept, so invalidate
	// the cache and retry with an up to date cache
	querier.InvalidateCache(ctx, query)
	response, err = querier.Query(ctx, query)
	if err != nil {
		return nil, err
	}
	records = response.GetRecords()
	if len(records) == 0 {
		return nil, ErrNotFound
	}
	return records[0], nil
}
// GetDataBrokerMessage looks up a databroker record via GetDataBrokerRecord and
// unmarshals its data into a freshly allocated message of type T.
//
// The record type is derived from the message type with grpcutil.GetTypeURL.
// recordID and lowestRecordVersion are forwarded unchanged to
// GetDataBrokerRecord. On any lookup or unmarshal error a nil message is
// returned together with the error.
func GetDataBrokerMessage[T any, TMessage interface {
	*T
	proto.Message
}](
	ctx context.Context,
	recordID string,
	lowestRecordVersion uint64,
) (TMessage, error) {
	dst := TMessage(new(T))

	record, err := GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(dst), recordID, lowestRecordVersion)
	if err != nil {
		return nil, err
	}

	if err := record.GetData().UnmarshalTo(dst); err != nil {
		return nil, err
	}
	return dst, nil
}
// GetDataBrokerObjectViaJSON looks up a databroker record via
// GetDataBrokerRecord and converts it into a value of type T by going through
// protojson: the record data is unmarshaled into a new protobuf message, that
// message is encoded with protojson, and the JSON is decoded into T with
// encoding/json. Any error from the lookup or a conversion step is returned
// as-is with a nil result.
func GetDataBrokerObjectViaJSON[T any](
	ctx context.Context,
	recordType string,
	recordID string,
	lowestRecordVersion uint64,
) (*T, error) {
	record, err := GetDataBrokerRecord(ctx, recordType, recordID, lowestRecordVersion)
	if err != nil {
		return nil, err
	}

	message, err := record.GetData().UnmarshalNew()
	if err != nil {
		return nil, err
	}

	encoded, err := protojson.Marshal(message)
	if err != nil {
		return nil, err
	}

	result := new(T)
	if err := json.Unmarshal(encoded, result); err != nil {
		return nil, err
	}
	return result, nil
}
// InvalidateCacheForDataBrokerRecords invalidates the cache of the querier
// stored in ctx for each of the given databroker records.
//
// For every record a query request is built from the record's type and ID
// (Limit 1, filtered with SetFilterByIDOrIndex) — the same request shape that
// GetDataBrokerRecord issues — and passed to the querier's InvalidateCache.
// Calling it with no records is a no-op.
func InvalidateCacheForDataBrokerRecords(
	ctx context.Context,
	records ...*databroker.Record,
) {
	if len(records) == 0 {
		return
	}

	// the querier is looked up from ctx alone, so resolve it once instead of
	// once per record
	q := GetQuerier(ctx)
	for _, record := range records {
		req := &databroker.QueryRequest{
			Type:  record.GetType(),
			Limit: 1,
		}
		req.SetFilterByIDOrIndex(record.GetId())
		q.InvalidateCache(ctx, req)
	}
}

101
pkg/storage/querier_test.go Normal file
View file

@ -0,0 +1,101 @@
package storage_test
import (
"context"
"encoding/json"
"fmt"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/datasource/pkg/directory"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/storage"
)
// TestGetDataBrokerRecord calls storage.GetDataBrokerRecord twice through a
// caching querier layered over a static querier and checks that both calls
// succeed and return a record, for a query version equal to and greater than
// the stored record version.
//
// NOTE(review): underlyingQueryCount and cachedQueryCount are declared in the
// table but never asserted below — confirm whether an assertion is missing.
func TestGetDataBrokerRecord(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	t.Cleanup(cancel)

	type testCase struct {
		name                                   string
		recordVersion, queryVersion            uint64
		underlyingQueryCount, cachedQueryCount int
	}
	testCases := []testCase{
		{"cached", 1, 1, 1, 2},
		{"invalidated", 1, 2, 3, 4},
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			stored := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)}
			querier := storage.NewCachingQuerier(
				storage.NewStaticQuerier(stored),
				storage.NewGlobalCache(time.Minute),
			)
			qctx := storage.WithQuerier(ctx, querier)

			// look the record up twice through the same caching querier
			for i := 0; i < 2; i++ {
				record, err := storage.GetDataBrokerRecord(qctx, grpcutil.GetTypeURL(stored), stored.GetId(), tc.queryVersion)
				assert.NoError(t, err)
				assert.NotNil(t, record)
			}
		})
	}
}
func TestGetDataBrokerMessage(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, time.Minute)
s1 := &session.Session{Id: "s1"}
sq := storage.NewStaticQuerier(s1)
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
qctx := storage.WithQuerier(ctx, cq)
s2, err := storage.GetDataBrokerMessage[session.Session](qctx, "s1", 0)
assert.NoError(t, err)
assert.Empty(t, cmp.Diff(s1, s2, protocmp.Transform()))
_, err = storage.GetDataBrokerMessage[session.Session](qctx, "s2", 0)
assert.ErrorIs(t, err, storage.ErrNotFound)
}
func TestGetDataBrokerObjectViaJSON(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, time.Minute)
du1 := &directory.User{
ID: "u1",
Email: "u1@example.com",
DisplayName: "User 1!",
}
sq := storage.NewStaticQuerier(newDirectoryUserRecord(du1))
cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute))
qctx := storage.WithQuerier(ctx, cq)
du2, err := storage.GetDataBrokerObjectViaJSON[directory.User](qctx, directory.UserRecordType, "u1", 0)
assert.NoError(t, err)
assert.Empty(t, cmp.Diff(du1, du2, protocmp.Transform()))
}
// newDirectoryUserRecord builds a static databroker record of type
// directory.UserRecordType from directoryUser. The user is round-tripped
// through encoding/json into a generic map so that it can be wrapped in a
// structpb.Struct.
//
// It panics if any conversion step fails, so that a broken fixture stops the
// test immediately instead of silently producing an empty or partial record.
func newDirectoryUserRecord(directoryUser *directory.User) *databroker.Record {
	bs, err := json.Marshal(directoryUser)
	if err != nil {
		panic(fmt.Errorf("marshaling directory user: %w", err))
	}

	fields := map[string]any{}
	if err := json.Unmarshal(bs, &fields); err != nil {
		panic(fmt.Errorf("unmarshaling directory user into map: %w", err))
	}

	s, err := structpb.NewStruct(fields)
	if err != nil {
		panic(fmt.Errorf("building struct for directory user: %w", err))
	}
	return storage.NewStaticRecord(directory.UserRecordType, s)
}

View file

@ -9,28 +9,24 @@ import (
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/internal/handlers/webauthn"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/webauthnutil"
)
func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Session, isImpersonated bool, err error) {
client := p.state.Load().dataBrokerClient
isImpersonated = false
s, err = session.Get(ctx, client, sessionID)
s, err = storage.GetDataBrokerMessage[session.Session](ctx, sessionID, 0)
if s.GetImpersonateSessionId() != "" {
s, err = session.Get(ctx, client, s.GetImpersonateSessionId())
s, err = storage.GetDataBrokerMessage[session.Session](ctx, s.GetImpersonateSessionId(), 0)
isImpersonated = true
}
return s, isImpersonated, err
}
func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error) {
client := p.state.Load().dataBrokerClient
return user.Get(ctx, client, userID)
return storage.GetDataBrokerMessage[user.User](ctx, userID, 0)
}
func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
@ -72,21 +68,16 @@ func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData {
}
func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.UserInfoData) {
client := p.state.Load().dataBrokerClient
res, _ := client.Get(ctx, &databroker.GetRequest{
Type: "type.googleapis.com/pomerium.config.Config",
Id: "dashboard-settings",
})
data.IsEnterprise = res.GetRecord() != nil
record, _ := storage.GetDataBrokerRecord(ctx, "type.googleapis.com/pomerium.config.Config", "dashboard-settings", 0)
data.IsEnterprise = record != nil
if !data.IsEnterprise {
return
}
data.DirectoryUser, _ = databroker.GetViaJSON[directory.User](ctx, client, directory.UserRecordType, data.Session.GetUserId())
data.DirectoryUser, _ = storage.GetDataBrokerObjectViaJSON[directory.User](ctx, directory.UserRecordType, data.Session.GetUserId(), 0)
if data.DirectoryUser != nil {
for _, groupID := range data.DirectoryUser.GroupIDs {
directoryGroup, _ := databroker.GetViaJSON[directory.Group](ctx, client, directory.GroupRecordType, groupID)
directoryGroup, _ := storage.GetDataBrokerObjectViaJSON[directory.Group](ctx, directory.GroupRecordType, groupID, 0)
if directoryGroup != nil {
data.DirectoryGroups = append(data.DirectoryGroups, directoryGroup)
}

View file

@ -25,6 +25,7 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage"
)
func Test_getUserInfoData(t *testing.T) {
@ -65,6 +66,7 @@ func Test_getUserInfoData(t *testing.T) {
proxy, err := New(ctx, &config.Config{Options: opts})
require.NoError(t, err)
proxy.state.Load().dataBrokerClient = client
ctx = storage.WithQuerier(ctx, storage.NewQuerier(client))
require.NoError(t, databrokerpb.PutMulti(ctx, client,
makeRecord(&session.Session{
@ -81,7 +83,7 @@ func Test_getUserInfoData(t *testing.T) {
"group_ids": []any{"G1", "G2", "G3"},
})))
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/.pomerium/", nil)
r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{
ID: "S1",
}))
@ -89,7 +91,9 @@ func Test_getUserInfoData(t *testing.T) {
assert.Equal(t, "S1", data.Session.Id)
assert.Equal(t, "U1", data.User.Id)
assert.True(t, data.IsEnterprise)
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
if assert.NotNil(t, data.DirectoryUser) {
assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs)
}
})
}

View file

@ -21,6 +21,7 @@ import (
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/proxy/portal"
)
@ -124,6 +125,8 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error {
r.StrictSlash(true)
// dashboard handlers are registered to all routes
r = p.registerDashboardHandlers(r, opts)
// attach the querier to the context
r.Use(p.querierMiddleware)
r.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(p.tracerProvider)))
p.currentRouter.Store(r)
@ -133,3 +136,16 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error {
// ServeHTTP handles the request by delegating to the router currently stored
// in p.currentRouter.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	router := p.currentRouter.Load()
	router.ServeHTTP(w, r)
}
// querierMiddleware wraps h so that every request's context carries a querier.
//
// For each request a caching querier is built over a querier for the
// databroker client from the proxy's current state, backed by
// storage.GlobalCache, and attached to the request context with
// storage.WithQuerier before h is invoked.
func (p *Proxy) querierMiddleware(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// the state is loaded per request rather than once when the
		// middleware is constructed
		querier := storage.NewCachingQuerier(
			storage.NewQuerier(p.state.Load().dataBrokerClient),
			storage.GlobalCache,
		)
		h.ServeHTTP(w, r.WithContext(storage.WithQuerier(r.Context(), querier)))
	})
}

View file

@ -13,6 +13,7 @@ import (
"github.com/pomerium/pomerium/internal/authenticateflow"
"github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
@ -87,19 +88,16 @@ func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.Trace
state.incomingIDPTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator(
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{
Type: recordType,
Id: recordID,
})
if err != nil {
return nil, err
}
return res.GetRecord(), nil
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
},
func(ctx context.Context, records []*databroker.Record) error {
_, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{
Records: records,
})
if err != nil {
return err
}
storage.InvalidateCacheForDataBrokerRecords(ctx, records...)
return err
},
)