mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-05 10:58:11 +02:00
wip
This commit is contained in:
parent
676c00ac97
commit
44f776a223
6 changed files with 356 additions and 6 deletions
|
@ -23,6 +23,7 @@ import (
|
||||||
type Querier interface {
|
type Querier interface {
|
||||||
InvalidateCache(ctx context.Context, in *databroker.QueryRequest)
|
InvalidateCache(ctx context.Context, in *databroker.QueryRequest)
|
||||||
Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error)
|
Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error)
|
||||||
|
Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
// nilQuerier always returns NotFound.
|
// nilQuerier always returns NotFound.
|
||||||
|
@ -34,6 +35,8 @@ func (nilQuerier) Query(_ context.Context, _ *databroker.QueryRequest, _ ...grpc
|
||||||
return nil, status.Error(codes.NotFound, "not found")
|
return nil, status.Error(codes.NotFound, "not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (nilQuerier) Stop() {}
|
||||||
|
|
||||||
type querierKey struct{}
|
type querierKey struct{}
|
||||||
|
|
||||||
// GetQuerier gets the databroker Querier from the context.
|
// GetQuerier gets the databroker Querier from the context.
|
||||||
|
@ -117,6 +120,8 @@ func (q *staticQuerier) Query(_ context.Context, req *databroker.QueryRequest, _
|
||||||
return QueryRecordCollections(q.records, req)
|
return QueryRecordCollections(q.records, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *staticQuerier) Stop() {}
|
||||||
|
|
||||||
type clientQuerier struct {
|
type clientQuerier struct {
|
||||||
client databroker.DataBrokerServiceClient
|
client databroker.DataBrokerServiceClient
|
||||||
}
|
}
|
||||||
|
@ -133,6 +138,8 @@ func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
|
||||||
return q.client.Query(ctx, in, opts...)
|
return q.client.Query(ctx, in, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *clientQuerier) Stop() {}
|
||||||
|
|
||||||
type cachingQuerier struct {
|
type cachingQuerier struct {
|
||||||
q Querier
|
q Querier
|
||||||
cache Cache
|
cache Cache
|
||||||
|
@ -182,6 +189,10 @@ func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest,
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *cachingQuerier) Stop() {
|
||||||
|
q.cache.InvalidateAll()
|
||||||
|
}
|
||||||
|
|
||||||
// MarshalQueryRequest marshales the query request.
|
// MarshalQueryRequest marshales the query request.
|
||||||
func MarshalQueryRequest(req *databroker.QueryRequest) ([]byte, error) {
|
func MarshalQueryRequest(req *databroker.QueryRequest) ([]byte, error) {
|
||||||
return (&proto.MarshalOptions{
|
return (&proto.MarshalOptions{
|
||||||
|
|
176
pkg/storage/querier_sync.go
Normal file
176
pkg/storage/querier_sync.go
Normal file
|
@ -0,0 +1,176 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
grpc "google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
)
|
||||||
|
|
||||||
|
type syncQuerier struct {
|
||||||
|
client databroker.DataBrokerServiceClient
|
||||||
|
recordType string
|
||||||
|
fallback Querier
|
||||||
|
|
||||||
|
cancel context.CancelFunc
|
||||||
|
serverVersion uint64
|
||||||
|
latestRecordVersion uint64
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
ready bool
|
||||||
|
records RecordCollection
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSyncQuerier creates a new Querier backed by an in-memory record collection
|
||||||
|
// filled via sync calls to the databroker.
|
||||||
|
func NewSyncQuerier(
|
||||||
|
client databroker.DataBrokerServiceClient,
|
||||||
|
recordType string,
|
||||||
|
fallback Querier,
|
||||||
|
) Querier {
|
||||||
|
q := &syncQuerier{
|
||||||
|
client: client,
|
||||||
|
recordType: recordType,
|
||||||
|
fallback: fallback,
|
||||||
|
records: NewRecordCollection(),
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
q.cancel = cancel
|
||||||
|
go q.run(ctx)
|
||||||
|
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *syncQuerier) InvalidateCache(
|
||||||
|
ctx context.Context,
|
||||||
|
req *databroker.QueryRequest,
|
||||||
|
) {
|
||||||
|
q.mu.RLock()
|
||||||
|
ready := q.ready
|
||||||
|
q.mu.RUnlock()
|
||||||
|
|
||||||
|
// only invalidate the fallback querier if we aren't ready yet
|
||||||
|
if ready {
|
||||||
|
q.fallback.InvalidateCache(ctx, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *syncQuerier) Query(
|
||||||
|
ctx context.Context,
|
||||||
|
req *databroker.QueryRequest,
|
||||||
|
opts ...grpc.CallOption,
|
||||||
|
) (*databroker.QueryResponse, error) {
|
||||||
|
q.mu.RLock()
|
||||||
|
if !q.ready || req.GetType() != q.recordType {
|
||||||
|
q.mu.RUnlock()
|
||||||
|
return q.fallback.Query(ctx, req, opts...)
|
||||||
|
}
|
||||||
|
defer q.mu.RUnlock()
|
||||||
|
return QueryRecordCollections(map[string]RecordCollection{
|
||||||
|
q.recordType: q.records,
|
||||||
|
}, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *syncQuerier) Stop() {
|
||||||
|
q.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *syncQuerier) run(ctx context.Context) {
|
||||||
|
bo := backoff.WithContext(backoff.NewExponentialBackOff(backoff.WithMaxElapsedTime(0)), ctx)
|
||||||
|
_ = backoff.RetryNotify(func() error {
|
||||||
|
if q.serverVersion == 0 {
|
||||||
|
err := q.syncLatest(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return q.sync(ctx)
|
||||||
|
}, bo, func(err error, d time.Duration) {
|
||||||
|
log.Ctx(ctx).Error().
|
||||||
|
Err(err).
|
||||||
|
Dur("delay", d).
|
||||||
|
Msg("storage/sync-querier: error syncing records")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *syncQuerier) syncLatest(ctx context.Context) error {
|
||||||
|
stream, err := q.client.SyncLatest(ctx, &databroker.SyncLatestRequest{
|
||||||
|
Type: q.recordType,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error starting sync latest stream: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
q.mu.Lock()
|
||||||
|
q.ready = false
|
||||||
|
q.records.Clear()
|
||||||
|
q.mu.Unlock()
|
||||||
|
|
||||||
|
for {
|
||||||
|
res, err := stream.Recv()
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("error receiving sync latest message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch res := res.Response.(type) {
|
||||||
|
case *databroker.SyncLatestResponse_Record:
|
||||||
|
q.mu.Lock()
|
||||||
|
q.records.Put(res.Record)
|
||||||
|
q.mu.Unlock()
|
||||||
|
case *databroker.SyncLatestResponse_Versions:
|
||||||
|
q.serverVersion = res.Versions.ServerVersion
|
||||||
|
q.latestRecordVersion = res.Versions.LatestRecordVersion
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown message type from sync latest: %T", res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
q.mu.Lock()
|
||||||
|
q.ready = true
|
||||||
|
q.mu.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *syncQuerier) sync(ctx context.Context) error {
|
||||||
|
stream, err := q.client.Sync(ctx, &databroker.SyncRequest{
|
||||||
|
ServerVersion: q.serverVersion,
|
||||||
|
RecordVersion: q.latestRecordVersion,
|
||||||
|
Type: q.recordType,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error starting sync stream: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
res, err := stream.Recv()
|
||||||
|
if status.Code(err) == codes.Aborted {
|
||||||
|
// this indicates the server version changed, so we need to reset
|
||||||
|
q.serverVersion = 0
|
||||||
|
q.latestRecordVersion = 0
|
||||||
|
return fmt.Errorf("stream was aborted due to mismatched server versions: %w", err)
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("error receiving sync message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
q.latestRecordVersion = max(q.latestRecordVersion, res.Record.Version)
|
||||||
|
|
||||||
|
q.mu.Lock()
|
||||||
|
q.records.Put(res.Record)
|
||||||
|
q.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
105
pkg/storage/querier_sync_test.go
Normal file
105
pkg/storage/querier_sync_test.go
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
package storage_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.opentelemetry.io/otel/trace/noop"
|
||||||
|
grpc "google.golang.org/grpc"
|
||||||
|
"google.golang.org/protobuf/testing/protocmp"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/databroker"
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
|
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSyncQuerier(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.GetContext(t, 10*time.Minute)
|
||||||
|
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
|
||||||
|
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx, noop.NewTracerProvider()))
|
||||||
|
})
|
||||||
|
t.Cleanup(func() { cc.Close() })
|
||||||
|
|
||||||
|
client := databrokerpb.NewDataBrokerServiceClient(cc)
|
||||||
|
|
||||||
|
q1r1 := &databrokerpb.Record{
|
||||||
|
Type: "t1",
|
||||||
|
Id: "r1",
|
||||||
|
Data: protoutil.ToAny("q1"),
|
||||||
|
}
|
||||||
|
q1r2 := &databrokerpb.Record{
|
||||||
|
Type: "t2",
|
||||||
|
Id: "r2",
|
||||||
|
Data: protoutil.ToAny("q2"),
|
||||||
|
}
|
||||||
|
q1 := storage.NewStaticQuerier(q1r1, q1r2)
|
||||||
|
|
||||||
|
q2r1 := &databrokerpb.Record{
|
||||||
|
Type: "t1",
|
||||||
|
Id: "r1",
|
||||||
|
Data: protoutil.ToAny("q2"),
|
||||||
|
}
|
||||||
|
_, err := client.Put(ctx, &databrokerpb.PutRequest{
|
||||||
|
Records: []*databrokerpb.Record{q2r1},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
q2r2 := &databrokerpb.Record{
|
||||||
|
Type: "t1",
|
||||||
|
Id: "r2",
|
||||||
|
Data: protoutil.ToAny("q2"),
|
||||||
|
}
|
||||||
|
|
||||||
|
q2 := storage.NewSyncQuerier(client, "t1", q1)
|
||||||
|
t.Cleanup(q2.Stop)
|
||||||
|
|
||||||
|
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||||
|
res, err := q2.Query(ctx, &databrokerpb.QueryRequest{
|
||||||
|
Type: "t1",
|
||||||
|
Filter: newStruct(t, map[string]any{
|
||||||
|
"id": "r1",
|
||||||
|
}),
|
||||||
|
Limit: 1,
|
||||||
|
})
|
||||||
|
if assert.NoError(c, err) && assert.Len(c, res.Records, 1) {
|
||||||
|
assert.Empty(c, cmp.Diff(q2r1.Data, res.Records[0].Data, protocmp.Transform()))
|
||||||
|
}
|
||||||
|
}, time.Second*10, time.Millisecond*50, "should sync records")
|
||||||
|
|
||||||
|
res, err := q2.Query(ctx, &databrokerpb.QueryRequest{
|
||||||
|
Type: "t2",
|
||||||
|
Filter: newStruct(t, map[string]any{
|
||||||
|
"id": "r2",
|
||||||
|
}),
|
||||||
|
Limit: 1,
|
||||||
|
})
|
||||||
|
if assert.NoError(t, err) && assert.Len(t, res.Records, 1) {
|
||||||
|
assert.Empty(t, cmp.Diff(q1r2.Data, res.Records[0].Data, protocmp.Transform()),
|
||||||
|
"should use fallback querier for other record types")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.Put(ctx, &databrokerpb.PutRequest{
|
||||||
|
Records: []*databrokerpb.Record{q2r2},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||||
|
res, err := q2.Query(ctx, &databrokerpb.QueryRequest{
|
||||||
|
Type: "t1",
|
||||||
|
Filter: newStruct(t, map[string]any{
|
||||||
|
"id": "r2",
|
||||||
|
}),
|
||||||
|
Limit: 1,
|
||||||
|
})
|
||||||
|
if assert.NoError(c, err) && assert.Len(c, res.Records, 1) {
|
||||||
|
assert.Empty(c, cmp.Diff(q2r2.Data, res.Records[0].Data, protocmp.Transform()))
|
||||||
|
}
|
||||||
|
}, time.Second*10, time.Millisecond*50, "should pick up changes")
|
||||||
|
}
|
45
pkg/storage/querier_typed.go
Normal file
45
pkg/storage/querier_typed.go
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
grpc "google.golang.org/grpc"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
)
|
||||||
|
|
||||||
|
type typedQuerier struct {
|
||||||
|
defaultQuerier Querier
|
||||||
|
queriersByType map[string]Querier
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTypedQuerier creates a new Querier that dispatches to other queries based on the type.
|
||||||
|
func NewTypedQuerier(defaultQuerier Querier, queriersByType map[string]Querier) Querier {
|
||||||
|
return &typedQuerier{
|
||||||
|
defaultQuerier: defaultQuerier,
|
||||||
|
queriersByType: queriersByType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *typedQuerier) InvalidateCache(ctx context.Context, req *databroker.QueryRequest) {
|
||||||
|
qq, ok := q.queriersByType[req.Type]
|
||||||
|
if !ok {
|
||||||
|
qq = q.defaultQuerier
|
||||||
|
}
|
||||||
|
qq.InvalidateCache(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *typedQuerier) Query(ctx context.Context, req *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||||
|
qq, ok := q.queriersByType[req.Type]
|
||||||
|
if !ok {
|
||||||
|
qq = q.defaultQuerier
|
||||||
|
}
|
||||||
|
return qq.Query(ctx, req, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *typedQuerier) Stop() {
|
||||||
|
q.defaultQuerier.Stop()
|
||||||
|
for _, qq := range q.queriersByType {
|
||||||
|
qq.Stop()
|
||||||
|
}
|
||||||
|
}
|
|
@ -20,6 +20,8 @@ import (
|
||||||
type RecordCollection interface {
|
type RecordCollection interface {
|
||||||
// All returns all of the databroker records as a slice. The slice is in insertion order.
|
// All returns all of the databroker records as a slice. The slice is in insertion order.
|
||||||
All() []*databroker.Record
|
All() []*databroker.Record
|
||||||
|
// Clear removes all the records from the collection.
|
||||||
|
Clear()
|
||||||
// Get returns a record based on the record id.
|
// Get returns a record based on the record id.
|
||||||
Get(recordID string) (*databroker.Record, bool)
|
Get(recordID string) (*databroker.Record, bool)
|
||||||
// Len returns the number of records stored in the collection.
|
// Len returns the number of records stored in the collection.
|
||||||
|
@ -65,6 +67,12 @@ func (c *recordCollection) All() []*databroker.Record {
|
||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *recordCollection) Clear() {
|
||||||
|
c.cidrIndex = bart.Table[[]string]{}
|
||||||
|
clear(c.records)
|
||||||
|
c.insertionOrder = list.New()
|
||||||
|
}
|
||||||
|
|
||||||
func (c *recordCollection) Get(recordID string) (*databroker.Record, bool) {
|
func (c *recordCollection) Get(recordID string) (*databroker.Record, bool) {
|
||||||
node, ok := c.records[recordID]
|
node, ok := c.records[recordID]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|
|
@ -21,7 +21,7 @@ func TestRecordCollection(t *testing.T) {
|
||||||
|
|
||||||
r1 := &databroker.Record{
|
r1 := &databroker.Record{
|
||||||
Id: "r1",
|
Id: "r1",
|
||||||
Data: newStructData(t, map[string]any{
|
Data: newStructAny(t, map[string]any{
|
||||||
"$index": map[string]any{
|
"$index": map[string]any{
|
||||||
"cidr": "10.0.0.0/24",
|
"cidr": "10.0.0.0/24",
|
||||||
},
|
},
|
||||||
|
@ -29,7 +29,7 @@ func TestRecordCollection(t *testing.T) {
|
||||||
}
|
}
|
||||||
r2 := &databroker.Record{
|
r2 := &databroker.Record{
|
||||||
Id: "r2",
|
Id: "r2",
|
||||||
Data: newStructData(t, map[string]any{
|
Data: newStructAny(t, map[string]any{
|
||||||
"$index": map[string]any{
|
"$index": map[string]any{
|
||||||
"cidr": "192.168.0.0/24",
|
"cidr": "192.168.0.0/24",
|
||||||
},
|
},
|
||||||
|
@ -37,7 +37,7 @@ func TestRecordCollection(t *testing.T) {
|
||||||
}
|
}
|
||||||
r3 := &databroker.Record{
|
r3 := &databroker.Record{
|
||||||
Id: "r3",
|
Id: "r3",
|
||||||
Data: newStructData(t, map[string]any{
|
Data: newStructAny(t, map[string]any{
|
||||||
"$index": map[string]any{
|
"$index": map[string]any{
|
||||||
"cidr": "10.0.0.0/16",
|
"cidr": "10.0.0.0/16",
|
||||||
},
|
},
|
||||||
|
@ -45,7 +45,7 @@ func TestRecordCollection(t *testing.T) {
|
||||||
}
|
}
|
||||||
r4 := &databroker.Record{
|
r4 := &databroker.Record{
|
||||||
Id: "r4",
|
Id: "r4",
|
||||||
Data: newStructData(t, map[string]any{
|
Data: newStructAny(t, map[string]any{
|
||||||
"$index": map[string]any{
|
"$index": map[string]any{
|
||||||
"cidr": "10.0.0.0/24",
|
"cidr": "10.0.0.0/24",
|
||||||
},
|
},
|
||||||
|
@ -123,11 +123,16 @@ func TestRecordCollection(t *testing.T) {
|
||||||
assert.Empty(t, cmp.Diff([]*databroker.Record{r3}, rs, protocmp.Transform()))
|
assert.Empty(t, cmp.Diff([]*databroker.Record{r3}, rs, protocmp.Transform()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func newStructData(t *testing.T, m map[string]any) *anypb.Any {
|
func newStruct(t *testing.T, m map[string]any) *structpb.Struct {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
s, err := structpb.NewStruct(m)
|
s, err := structpb.NewStruct(m)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStructAny(t *testing.T, m map[string]any) *anypb.Any {
|
||||||
|
t.Helper()
|
||||||
|
s, err := structpb.NewStruct(m)
|
||||||
|
require.NoError(t, err)
|
||||||
return protoutil.NewAny(s)
|
return protoutil.NewAny(s)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue