This commit is contained in:
Caleb Doxsey 2025-04-07 11:40:11 -06:00
parent 2ad3493c80
commit 3617c67e41
5 changed files with 78 additions and 60 deletions

View file

@ -34,11 +34,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check") ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check")
defer span.End() defer span.End()
querier := storage.NewCachingQuerier( ctx = a.withQuerierForCheckRequest(ctx)
storage.NewQuerier(a.state.Load().dataBrokerClient),
storage.GlobalCache,
)
ctx = storage.WithQuerier(ctx, querier)
state := a.state.Load() state := a.state.Load()
@ -172,6 +168,14 @@ func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy {
return nil return nil
} }
func (a *Authorize) withQuerierForCheckRequest(ctx context.Context) context.Context {
querier := storage.NewCachingQuerier(
storage.NewQuerier(a.state.Load().dataBrokerClient),
storage.GlobalCache,
)
return storage.WithQuerier(ctx, querier)
}
func getHTTPRequestFromCheckRequest(req *envoy_service_auth_v3.CheckRequest) *http.Request { func getHTTPRequestFromCheckRequest(req *envoy_service_auth_v3.CheckRequest) *http.Request {
hattrs := req.GetAttributes().GetRequest().GetHttp() hattrs := req.GetAttributes().GetRequest().GetHttp()
u := getCheckRequestURL(req) u := getCheckRequestURL(req)

View file

@ -25,6 +25,10 @@ var (
// RuntimeFlagAddExtraMetricsLabels enables adding extra labels to metrics (host and installation id) // RuntimeFlagAddExtraMetricsLabels enables adding extra labels to metrics (host and installation id)
RuntimeFlagAddExtraMetricsLabels = runtimeFlag("add_extra_metrics_labels", true) RuntimeFlagAddExtraMetricsLabels = runtimeFlag("add_extra_metrics_labels", true)
// RuntimeFlagAuthorizeUseSyncedData enables synced data for querying the databroker for
// certain types of data.
RuntimeFlagAuthorizeUseSyncedData = runtimeFlag("authorize_use_synced_data", false)
) )
// RuntimeFlag is a runtime flag that can flip on/off certain features // RuntimeFlag is a runtime flag that can flip on/off certain features

View file

@ -0,0 +1,48 @@
package storage
import (
"context"
"errors"
grpc "google.golang.org/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
type fallbackQuerier []Querier
// NewFallbackQuerier creates a new FallbackQuerier.
func NewFallbackQuerier(queriers ...Querier) Querier {
return fallbackQuerier(queriers)
}
// InvalidateCache invalidates the cache of all the queriers.
func (q fallbackQuerier) InvalidateCache(ctx context.Context, req *databroker.QueryRequest) {
for _, qq := range q {
qq.InvalidateCache(ctx, req)
}
}
// Query returns the first querier's results that doesn't result in an error.
func (q fallbackQuerier) Query(ctx context.Context, req *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
if len(q) == 0 {
return nil, ErrUnavailable
}
var merr error
for _, qq := range q {
res, err := qq.Query(ctx, req, opts...)
if err != nil {
return res, nil
}
merr = errors.Join(merr, err)
}
return nil, merr
}
// Stop stops all the queriers.
func (q fallbackQuerier) Stop() {
for _, qq := range q {
qq.Stop()
}
}

View file

@ -17,10 +17,12 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
) )
// ErrUnavailable is an error indicating that
var ErrUnavailable = fmt.Errorf("unavailable")
type syncQuerier struct { type syncQuerier struct {
client databroker.DataBrokerServiceClient client databroker.DataBrokerServiceClient
recordType string recordType string
fallback Querier
cancel context.CancelFunc cancel context.CancelFunc
serverVersion uint64 serverVersion uint64
@ -36,12 +38,10 @@ type syncQuerier struct {
func NewSyncQuerier( func NewSyncQuerier(
client databroker.DataBrokerServiceClient, client databroker.DataBrokerServiceClient,
recordType string, recordType string,
fallback Querier,
) Querier { ) Querier {
q := &syncQuerier{ q := &syncQuerier{
client: client, client: client,
recordType: recordType, recordType: recordType,
fallback: fallback,
records: NewRecordCollection(), records: NewRecordCollection(),
} }
@ -52,29 +52,15 @@ func NewSyncQuerier(
return q return q
} }
func (q *syncQuerier) InvalidateCache( func (q *syncQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRequest) {
ctx context.Context, // do nothing
req *databroker.QueryRequest,
) {
q.mu.RLock()
ready := q.ready
q.mu.RUnlock()
// only invalidate the fallback querier if we aren't ready yet
if ready {
q.fallback.InvalidateCache(ctx, req)
}
} }
func (q *syncQuerier) Query( func (q *syncQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
ctx context.Context,
req *databroker.QueryRequest,
opts ...grpc.CallOption,
) (*databroker.QueryResponse, error) {
q.mu.RLock() q.mu.RLock()
if !q.ready || req.GetType() != q.recordType { if !q.ready || req.GetType() != q.recordType {
q.mu.RUnlock() q.mu.RUnlock()
return q.fallback.Query(ctx, req, opts...) return nil, ErrUnavailable
} }
defer q.mu.RUnlock() defer q.mu.RUnlock()
return QueryRecordCollections(map[string]RecordCollection{ return QueryRecordCollections(map[string]RecordCollection{

View file

@ -29,39 +29,27 @@ func TestSyncQuerier(t *testing.T) {
client := databrokerpb.NewDataBrokerServiceClient(cc) client := databrokerpb.NewDataBrokerServiceClient(cc)
q1r1 := &databrokerpb.Record{ r1 := &databrokerpb.Record{
Type: "t1",
Id: "r1",
Data: protoutil.ToAny("q1"),
}
q1r2 := &databrokerpb.Record{
Type: "t2",
Id: "r2",
Data: protoutil.ToAny("q2"),
}
q1 := storage.NewStaticQuerier(q1r1, q1r2)
q2r1 := &databrokerpb.Record{
Type: "t1", Type: "t1",
Id: "r1", Id: "r1",
Data: protoutil.ToAny("q2"), Data: protoutil.ToAny("q2"),
} }
_, err := client.Put(ctx, &databrokerpb.PutRequest{ _, err := client.Put(ctx, &databrokerpb.PutRequest{
Records: []*databrokerpb.Record{q2r1}, Records: []*databrokerpb.Record{r1},
}) })
require.NoError(t, err) require.NoError(t, err)
q2r2 := &databrokerpb.Record{ r2 := &databrokerpb.Record{
Type: "t1", Type: "t1",
Id: "r2", Id: "r2",
Data: protoutil.ToAny("q2"), Data: protoutil.ToAny("q2"),
} }
q2 := storage.NewSyncQuerier(client, "t1", q1) q := storage.NewSyncQuerier(client, "t1")
t.Cleanup(q2.Stop) t.Cleanup(q.Stop)
assert.EventuallyWithT(t, func(c *assert.CollectT) { assert.EventuallyWithT(t, func(c *assert.CollectT) {
res, err := q2.Query(ctx, &databrokerpb.QueryRequest{ res, err := q.Query(ctx, &databrokerpb.QueryRequest{
Type: "t1", Type: "t1",
Filter: newStruct(t, map[string]any{ Filter: newStruct(t, map[string]any{
"id": "r1", "id": "r1",
@ -69,29 +57,17 @@ func TestSyncQuerier(t *testing.T) {
Limit: 1, Limit: 1,
}) })
if assert.NoError(c, err) && assert.Len(c, res.Records, 1) { if assert.NoError(c, err) && assert.Len(c, res.Records, 1) {
assert.Empty(c, cmp.Diff(q2r1.Data, res.Records[0].Data, protocmp.Transform())) assert.Empty(c, cmp.Diff(r1.Data, res.Records[0].Data, protocmp.Transform()))
} }
}, time.Second*10, time.Millisecond*50, "should sync records") }, time.Second*10, time.Millisecond*50, "should sync records")
res, err := q2.Query(ctx, &databrokerpb.QueryRequest{
Type: "t2",
Filter: newStruct(t, map[string]any{
"id": "r2",
}),
Limit: 1,
})
if assert.NoError(t, err) && assert.Len(t, res.Records, 1) {
assert.Empty(t, cmp.Diff(q1r2.Data, res.Records[0].Data, protocmp.Transform()),
"should use fallback querier for other record types")
}
_, err = client.Put(ctx, &databrokerpb.PutRequest{ _, err = client.Put(ctx, &databrokerpb.PutRequest{
Records: []*databrokerpb.Record{q2r2}, Records: []*databrokerpb.Record{r2},
}) })
require.NoError(t, err) require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) { assert.EventuallyWithT(t, func(c *assert.CollectT) {
res, err := q2.Query(ctx, &databrokerpb.QueryRequest{ res, err := q.Query(ctx, &databrokerpb.QueryRequest{
Type: "t1", Type: "t1",
Filter: newStruct(t, map[string]any{ Filter: newStruct(t, map[string]any{
"id": "r2", "id": "r2",
@ -99,7 +75,7 @@ func TestSyncQuerier(t *testing.T) {
Limit: 1, Limit: 1,
}) })
if assert.NoError(c, err) && assert.Len(c, res.Records, 1) { if assert.NoError(c, err) && assert.Len(c, res.Records, 1) {
assert.Empty(c, cmp.Diff(q2r2.Data, res.Records[0].Data, protocmp.Transform())) assert.Empty(c, cmp.Diff(r2.Data, res.Records[0].Data, protocmp.Transform()))
} }
}, time.Second*10, time.Millisecond*50, "should pick up changes") }, time.Second*10, time.Millisecond*50, "should pick up changes")
} }