mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 00:40:25 +02:00
wip
This commit is contained in:
parent
2ad3493c80
commit
3617c67e41
5 changed files with 78 additions and 60 deletions
|
@ -34,11 +34,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
ctx, span := a.tracer.Start(ctx, "authorize.grpc.Check")
|
||||
defer span.End()
|
||||
|
||||
querier := storage.NewCachingQuerier(
|
||||
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
||||
storage.GlobalCache,
|
||||
)
|
||||
ctx = storage.WithQuerier(ctx, querier)
|
||||
ctx = a.withQuerierForCheckRequest(ctx)
|
||||
|
||||
state := a.state.Load()
|
||||
|
||||
|
@ -172,6 +168,14 @@ func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *Authorize) withQuerierForCheckRequest(ctx context.Context) context.Context {
|
||||
querier := storage.NewCachingQuerier(
|
||||
storage.NewQuerier(a.state.Load().dataBrokerClient),
|
||||
storage.GlobalCache,
|
||||
)
|
||||
return storage.WithQuerier(ctx, querier)
|
||||
}
|
||||
|
||||
func getHTTPRequestFromCheckRequest(req *envoy_service_auth_v3.CheckRequest) *http.Request {
|
||||
hattrs := req.GetAttributes().GetRequest().GetHttp()
|
||||
u := getCheckRequestURL(req)
|
||||
|
|
|
@ -25,6 +25,10 @@ var (
|
|||
|
||||
// RuntimeFlagAddExtraMetricsLabels enables adding extra labels to metrics (host and installation id)
|
||||
RuntimeFlagAddExtraMetricsLabels = runtimeFlag("add_extra_metrics_labels", true)
|
||||
|
||||
// RuntimeFlagAuthorizeUseSyncedData enables synced data for querying the databroker for
|
||||
// certain types of data.
|
||||
RuntimeFlagAuthorizeUseSyncedData = runtimeFlag("authorize_use_synced_data", false)
|
||||
)
|
||||
|
||||
// RuntimeFlag is a runtime flag that can flip on/off certain features
|
||||
|
|
48
pkg/storage/querier_fallback.go
Normal file
48
pkg/storage/querier_fallback.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
grpc "google.golang.org/grpc"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
type fallbackQuerier []Querier
|
||||
|
||||
// NewFallbackQuerier creates a new FallbackQuerier.
|
||||
func NewFallbackQuerier(queriers ...Querier) Querier {
|
||||
return fallbackQuerier(queriers)
|
||||
}
|
||||
|
||||
// InvalidateCache invalidates the cache of all the queriers.
|
||||
func (q fallbackQuerier) InvalidateCache(ctx context.Context, req *databroker.QueryRequest) {
|
||||
for _, qq := range q {
|
||||
qq.InvalidateCache(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
// Query returns the first querier's results that doesn't result in an error.
|
||||
func (q fallbackQuerier) Query(ctx context.Context, req *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||
if len(q) == 0 {
|
||||
return nil, ErrUnavailable
|
||||
}
|
||||
|
||||
var merr error
|
||||
for _, qq := range q {
|
||||
res, err := qq.Query(ctx, req, opts...)
|
||||
if err != nil {
|
||||
return res, nil
|
||||
}
|
||||
merr = errors.Join(merr, err)
|
||||
}
|
||||
return nil, merr
|
||||
}
|
||||
|
||||
// Stop stops all the queriers.
|
||||
func (q fallbackQuerier) Stop() {
|
||||
for _, qq := range q {
|
||||
qq.Stop()
|
||||
}
|
||||
}
|
|
@ -17,10 +17,12 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
// ErrUnavailable is an error indicating that
|
||||
var ErrUnavailable = fmt.Errorf("unavailable")
|
||||
|
||||
type syncQuerier struct {
|
||||
client databroker.DataBrokerServiceClient
|
||||
recordType string
|
||||
fallback Querier
|
||||
|
||||
cancel context.CancelFunc
|
||||
serverVersion uint64
|
||||
|
@ -36,12 +38,10 @@ type syncQuerier struct {
|
|||
func NewSyncQuerier(
|
||||
client databroker.DataBrokerServiceClient,
|
||||
recordType string,
|
||||
fallback Querier,
|
||||
) Querier {
|
||||
q := &syncQuerier{
|
||||
client: client,
|
||||
recordType: recordType,
|
||||
fallback: fallback,
|
||||
records: NewRecordCollection(),
|
||||
}
|
||||
|
||||
|
@ -52,29 +52,15 @@ func NewSyncQuerier(
|
|||
return q
|
||||
}
|
||||
|
||||
func (q *syncQuerier) InvalidateCache(
|
||||
ctx context.Context,
|
||||
req *databroker.QueryRequest,
|
||||
) {
|
||||
q.mu.RLock()
|
||||
ready := q.ready
|
||||
q.mu.RUnlock()
|
||||
|
||||
// only invalidate the fallback querier if we aren't ready yet
|
||||
if ready {
|
||||
q.fallback.InvalidateCache(ctx, req)
|
||||
}
|
||||
func (q *syncQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRequest) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
func (q *syncQuerier) Query(
|
||||
ctx context.Context,
|
||||
req *databroker.QueryRequest,
|
||||
opts ...grpc.CallOption,
|
||||
) (*databroker.QueryResponse, error) {
|
||||
func (q *syncQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||
q.mu.RLock()
|
||||
if !q.ready || req.GetType() != q.recordType {
|
||||
q.mu.RUnlock()
|
||||
return q.fallback.Query(ctx, req, opts...)
|
||||
return nil, ErrUnavailable
|
||||
}
|
||||
defer q.mu.RUnlock()
|
||||
return QueryRecordCollections(map[string]RecordCollection{
|
||||
|
|
|
@ -29,39 +29,27 @@ func TestSyncQuerier(t *testing.T) {
|
|||
|
||||
client := databrokerpb.NewDataBrokerServiceClient(cc)
|
||||
|
||||
q1r1 := &databrokerpb.Record{
|
||||
Type: "t1",
|
||||
Id: "r1",
|
||||
Data: protoutil.ToAny("q1"),
|
||||
}
|
||||
q1r2 := &databrokerpb.Record{
|
||||
Type: "t2",
|
||||
Id: "r2",
|
||||
Data: protoutil.ToAny("q2"),
|
||||
}
|
||||
q1 := storage.NewStaticQuerier(q1r1, q1r2)
|
||||
|
||||
q2r1 := &databrokerpb.Record{
|
||||
r1 := &databrokerpb.Record{
|
||||
Type: "t1",
|
||||
Id: "r1",
|
||||
Data: protoutil.ToAny("q2"),
|
||||
}
|
||||
_, err := client.Put(ctx, &databrokerpb.PutRequest{
|
||||
Records: []*databrokerpb.Record{q2r1},
|
||||
Records: []*databrokerpb.Record{r1},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
q2r2 := &databrokerpb.Record{
|
||||
r2 := &databrokerpb.Record{
|
||||
Type: "t1",
|
||||
Id: "r2",
|
||||
Data: protoutil.ToAny("q2"),
|
||||
}
|
||||
|
||||
q2 := storage.NewSyncQuerier(client, "t1", q1)
|
||||
t.Cleanup(q2.Stop)
|
||||
q := storage.NewSyncQuerier(client, "t1")
|
||||
t.Cleanup(q.Stop)
|
||||
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
res, err := q2.Query(ctx, &databrokerpb.QueryRequest{
|
||||
res, err := q.Query(ctx, &databrokerpb.QueryRequest{
|
||||
Type: "t1",
|
||||
Filter: newStruct(t, map[string]any{
|
||||
"id": "r1",
|
||||
|
@ -69,29 +57,17 @@ func TestSyncQuerier(t *testing.T) {
|
|||
Limit: 1,
|
||||
})
|
||||
if assert.NoError(c, err) && assert.Len(c, res.Records, 1) {
|
||||
assert.Empty(c, cmp.Diff(q2r1.Data, res.Records[0].Data, protocmp.Transform()))
|
||||
assert.Empty(c, cmp.Diff(r1.Data, res.Records[0].Data, protocmp.Transform()))
|
||||
}
|
||||
}, time.Second*10, time.Millisecond*50, "should sync records")
|
||||
|
||||
res, err := q2.Query(ctx, &databrokerpb.QueryRequest{
|
||||
Type: "t2",
|
||||
Filter: newStruct(t, map[string]any{
|
||||
"id": "r2",
|
||||
}),
|
||||
Limit: 1,
|
||||
})
|
||||
if assert.NoError(t, err) && assert.Len(t, res.Records, 1) {
|
||||
assert.Empty(t, cmp.Diff(q1r2.Data, res.Records[0].Data, protocmp.Transform()),
|
||||
"should use fallback querier for other record types")
|
||||
}
|
||||
|
||||
_, err = client.Put(ctx, &databrokerpb.PutRequest{
|
||||
Records: []*databrokerpb.Record{q2r2},
|
||||
Records: []*databrokerpb.Record{r2},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
res, err := q2.Query(ctx, &databrokerpb.QueryRequest{
|
||||
res, err := q.Query(ctx, &databrokerpb.QueryRequest{
|
||||
Type: "t1",
|
||||
Filter: newStruct(t, map[string]any{
|
||||
"id": "r2",
|
||||
|
@ -99,7 +75,7 @@ func TestSyncQuerier(t *testing.T) {
|
|||
Limit: 1,
|
||||
})
|
||||
if assert.NoError(c, err) && assert.Len(c, res.Records, 1) {
|
||||
assert.Empty(c, cmp.Diff(q2r2.Data, res.Records[0].Data, protocmp.Transform()))
|
||||
assert.Empty(c, cmp.Diff(r2.Data, res.Records[0].Data, protocmp.Transform()))
|
||||
}
|
||||
}, time.Second*10, time.Millisecond*50, "should pick up changes")
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue