This commit is contained in:
Caleb Doxsey 2025-04-07 11:40:11 -06:00
parent 2ad3493c80
commit 3617c67e41
5 changed files with 78 additions and 60 deletions

View file

@ -34,11 +34,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check")
defer span.End()
querier := storage.NewCachingQuerier(
storage.NewQuerier(a.state.Load().dataBrokerClient),
storage.GlobalCache,
)
ctx = storage.WithQuerier(ctx, querier)
ctx = a.withQuerierForCheckRequest(ctx)
state := a.state.Load()
@ -172,6 +168,14 @@ func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy {
return nil
}
func (a *Authorize) withQuerierForCheckRequest(ctx context.Context) context.Context {
querier := storage.NewCachingQuerier(
storage.NewQuerier(a.state.Load().dataBrokerClient),
storage.GlobalCache,
)
return storage.WithQuerier(ctx, querier)
}
func getHTTPRequestFromCheckRequest(req *envoy_service_auth_v3.CheckRequest) *http.Request {
hattrs := req.GetAttributes().GetRequest().GetHttp()
u := getCheckRequestURL(req)

View file

@ -25,6 +25,10 @@ var (
// RuntimeFlagAddExtraMetricsLabels enables adding extra labels to metrics (host and installation id)
RuntimeFlagAddExtraMetricsLabels = runtimeFlag("add_extra_metrics_labels", true)
// RuntimeFlagAuthorizeUseSyncedData enables synced data for querying the databroker for
// certain types of data.
RuntimeFlagAuthorizeUseSyncedData = runtimeFlag("authorize_use_synced_data", false)
)
// RuntimeFlag is a runtime flag that can flip on/off certain features

View file

@ -0,0 +1,48 @@
package storage
import (
"context"
"errors"
grpc "google.golang.org/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
type fallbackQuerier []Querier
// NewFallbackQuerier creates a new FallbackQuerier.
func NewFallbackQuerier(queriers ...Querier) Querier {
return fallbackQuerier(queriers)
}
// InvalidateCache invalidates the cache of all the queriers.
func (q fallbackQuerier) InvalidateCache(ctx context.Context, req *databroker.QueryRequest) {
for _, qq := range q {
qq.InvalidateCache(ctx, req)
}
}
// Query returns the first querier's results that doesn't result in an error.
func (q fallbackQuerier) Query(ctx context.Context, req *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
if len(q) == 0 {
return nil, ErrUnavailable
}
var merr error
for _, qq := range q {
res, err := qq.Query(ctx, req, opts...)
if err != nil {
return res, nil
}
merr = errors.Join(merr, err)
}
return nil, merr
}
// Stop stops all the queriers.
func (q fallbackQuerier) Stop() {
for _, qq := range q {
qq.Stop()
}
}

View file

@ -17,10 +17,12 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// ErrUnavailable is an error indicating that
var ErrUnavailable = fmt.Errorf("unavailable")
type syncQuerier struct {
client databroker.DataBrokerServiceClient
recordType string
fallback Querier
cancel context.CancelFunc
serverVersion uint64
@ -36,12 +38,10 @@ type syncQuerier struct {
func NewSyncQuerier(
client databroker.DataBrokerServiceClient,
recordType string,
fallback Querier,
) Querier {
q := &syncQuerier{
client: client,
recordType: recordType,
fallback: fallback,
records: NewRecordCollection(),
}
@ -52,29 +52,15 @@ func NewSyncQuerier(
return q
}
func (q *syncQuerier) InvalidateCache(
ctx context.Context,
req *databroker.QueryRequest,
) {
q.mu.RLock()
ready := q.ready
q.mu.RUnlock()
// only invalidate the fallback querier if we aren't ready yet
if ready {
q.fallback.InvalidateCache(ctx, req)
}
func (q *syncQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRequest) {
// do nothing
}
func (q *syncQuerier) Query(
ctx context.Context,
req *databroker.QueryRequest,
opts ...grpc.CallOption,
) (*databroker.QueryResponse, error) {
func (q *syncQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
q.mu.RLock()
if !q.ready || req.GetType() != q.recordType {
q.mu.RUnlock()
return q.fallback.Query(ctx, req, opts...)
return nil, ErrUnavailable
}
defer q.mu.RUnlock()
return QueryRecordCollections(map[string]RecordCollection{

View file

@ -29,39 +29,27 @@ func TestSyncQuerier(t *testing.T) {
client := databrokerpb.NewDataBrokerServiceClient(cc)
q1r1 := &databrokerpb.Record{
Type: "t1",
Id: "r1",
Data: protoutil.ToAny("q1"),
}
q1r2 := &databrokerpb.Record{
Type: "t2",
Id: "r2",
Data: protoutil.ToAny("q2"),
}
q1 := storage.NewStaticQuerier(q1r1, q1r2)
q2r1 := &databrokerpb.Record{
r1 := &databrokerpb.Record{
Type: "t1",
Id: "r1",
Data: protoutil.ToAny("q2"),
}
_, err := client.Put(ctx, &databrokerpb.PutRequest{
Records: []*databrokerpb.Record{q2r1},
Records: []*databrokerpb.Record{r1},
})
require.NoError(t, err)
q2r2 := &databrokerpb.Record{
r2 := &databrokerpb.Record{
Type: "t1",
Id: "r2",
Data: protoutil.ToAny("q2"),
}
q2 := storage.NewSyncQuerier(client, "t1", q1)
t.Cleanup(q2.Stop)
q := storage.NewSyncQuerier(client, "t1")
t.Cleanup(q.Stop)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
res, err := q2.Query(ctx, &databrokerpb.QueryRequest{
res, err := q.Query(ctx, &databrokerpb.QueryRequest{
Type: "t1",
Filter: newStruct(t, map[string]any{
"id": "r1",
@ -69,29 +57,17 @@ func TestSyncQuerier(t *testing.T) {
Limit: 1,
})
if assert.NoError(c, err) && assert.Len(c, res.Records, 1) {
assert.Empty(c, cmp.Diff(q2r1.Data, res.Records[0].Data, protocmp.Transform()))
assert.Empty(c, cmp.Diff(r1.Data, res.Records[0].Data, protocmp.Transform()))
}
}, time.Second*10, time.Millisecond*50, "should sync records")
res, err := q2.Query(ctx, &databrokerpb.QueryRequest{
Type: "t2",
Filter: newStruct(t, map[string]any{
"id": "r2",
}),
Limit: 1,
})
if assert.NoError(t, err) && assert.Len(t, res.Records, 1) {
assert.Empty(t, cmp.Diff(q1r2.Data, res.Records[0].Data, protocmp.Transform()),
"should use fallback querier for other record types")
}
_, err = client.Put(ctx, &databrokerpb.PutRequest{
Records: []*databrokerpb.Record{q2r2},
Records: []*databrokerpb.Record{r2},
})
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
res, err := q2.Query(ctx, &databrokerpb.QueryRequest{
res, err := q.Query(ctx, &databrokerpb.QueryRequest{
Type: "t1",
Filter: newStruct(t, map[string]any{
"id": "r2",
@ -99,7 +75,7 @@ func TestSyncQuerier(t *testing.T) {
Limit: 1,
})
if assert.NoError(c, err) && assert.Len(c, res.Records, 1) {
assert.Empty(c, cmp.Diff(q2r2.Data, res.Records[0].Data, protocmp.Transform()))
assert.Empty(c, cmp.Diff(r2.Data, res.Records[0].Data, protocmp.Transform()))
}
}, time.Second*10, time.Millisecond*50, "should pick up changes")
}