diff --git a/pkg/storage/querier.go b/pkg/storage/querier.go index 108e60f47..131e1046e 100644 --- a/pkg/storage/querier.go +++ b/pkg/storage/querier.go @@ -3,6 +3,7 @@ package storage import ( "context" "encoding/json" + "errors" grpc "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -14,6 +15,9 @@ import ( "github.com/pomerium/pomerium/pkg/grpcutil" ) +// ErrUnavailable indicates that a querier is not available. +var ErrUnavailable = errors.New("unavailable") + // A Querier is a read-only subset of the client methods type Querier interface { InvalidateCache(ctx context.Context, in *databroker.QueryRequest) @@ -26,7 +30,7 @@ type nilQuerier struct{} func (nilQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRequest) {} func (nilQuerier) Query(_ context.Context, _ *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) { - return nil, status.Error(codes.NotFound, "not found") + return nil, errors.Join(ErrUnavailable, status.Error(codes.NotFound, "not found")) } type querierKey struct{} diff --git a/pkg/storage/querier_fallback.go b/pkg/storage/querier_fallback.go new file mode 100644 index 000000000..46ca06fbd --- /dev/null +++ b/pkg/storage/querier_fallback.go @@ -0,0 +1,42 @@ +package storage + +import ( + "context" + "errors" + + grpc "google.golang.org/grpc" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +type fallbackQuerier []Querier + +// NewFallbackQuerier creates a new fallback-querier. The first call to Query that +// does not return an error will be used. +func NewFallbackQuerier(queriers ...Querier) Querier { + return fallbackQuerier(queriers) +} + +// InvalidateCache invalidates the cache of all the queriers. +func (q fallbackQuerier) InvalidateCache(ctx context.Context, req *databroker.QueryRequest) { + for _, qq := range q { + qq.InvalidateCache(ctx, req) + } +} + +// Query returns the first querier's results that doesn't result in an error. +func (q fallbackQuerier) Query(ctx context.Context, req *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { + if len(q) == 0 { + return nil, ErrUnavailable + } + + var merr error + for _, qq := range q { + res, err := qq.Query(ctx, req, opts...) + if err == nil { + return res, nil + } + merr = errors.Join(merr, err) + } + return nil, merr +} diff --git a/pkg/storage/querier_fallback_test.go b/pkg/storage/querier_fallback_test.go new file mode 100644 index 000000000..a7eb0f5e9 --- /dev/null +++ b/pkg/storage/querier_fallback_test.go @@ -0,0 +1,36 @@ +package storage_test + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/testing/protocmp" + + "github.com/pomerium/pomerium/internal/testutil" + databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/storage" +) + +func TestFallbackQuerier(t *testing.T) { + t.Parallel() + + ctx := testutil.GetContext(t, time.Minute) + q1 := storage.GetQuerier(ctx) // nil querier + q2 := storage.NewStaticQuerier(&databrokerpb.Record{ + Type: "t1", + Id: "r1", + Version: 1, + }) + res, err := storage.NewFallbackQuerier(q1, q2).Query(ctx, &databrokerpb.QueryRequest{ + Type: "t1", + Limit: 1, + }) + assert.NoError(t, err, "should fallback") + assert.Empty(t, cmp.Diff(&databrokerpb.QueryResponse{ + Records: []*databrokerpb.Record{{Type: "t1", Id: "r1", Version: 1}}, + TotalCount: 1, + RecordVersion: 1, + }, res, protocmp.Transform())) +}