storage: add fallback querier

This commit is contained in:
Caleb Doxsey 2025-04-10 11:20:31 -06:00
parent 3891293fa7
commit 04a5506d1b
3 changed files with 83 additions and 1 deletions

View file

@ -3,6 +3,7 @@ package storage
import (
"context"
"encoding/json"
"errors"
grpc "google.golang.org/grpc"
"google.golang.org/grpc/codes"
@ -14,6 +15,9 @@ import (
"github.com/pomerium/pomerium/pkg/grpcutil"
)
// ErrUnavailable indicates that a querier is not available.
var ErrUnavailable = errors.New("unavailable")
// A Querier is a read-only subset of the client methods
type Querier interface {
InvalidateCache(ctx context.Context, in *databroker.QueryRequest)
@ -26,7 +30,7 @@ type nilQuerier struct{}
func (nilQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRequest) {}
func (nilQuerier) Query(_ context.Context, _ *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
return nil, status.Error(codes.NotFound, "not found")
return nil, errors.Join(ErrUnavailable, status.Error(codes.NotFound, "not found"))
}
type querierKey struct{}

View file

@ -0,0 +1,42 @@
package storage
import (
"context"
"errors"
grpc "google.golang.org/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
type fallbackQuerier []Querier
// NewFallbackQuerier creates a new fallback-querier. The first call to Query that
// does not return an error will be used.
func NewFallbackQuerier(queriers ...Querier) Querier {
return fallbackQuerier(queriers)
}
// InvalidateCache invalidates the cache of all the queriers.
func (q fallbackQuerier) InvalidateCache(ctx context.Context, req *databroker.QueryRequest) {
for _, qq := range q {
qq.InvalidateCache(ctx, req)
}
}
// Query returns the first querier's results that doesn't result in an error.
func (q fallbackQuerier) Query(ctx context.Context, req *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) {
if len(q) == 0 {
return nil, ErrUnavailable
}
var merr error
for _, qq := range q {
res, err := qq.Query(ctx, req, opts...)
if err == nil {
return res, nil
}
merr = errors.Join(merr, err)
}
return nil, merr
}

View file

@ -0,0 +1,36 @@
package storage_test
import (
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/testing/protocmp"
"github.com/pomerium/pomerium/internal/testutil"
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
func TestFallbackQuerier(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, time.Minute)
q1 := storage.GetQuerier(ctx) // nil querier
q2 := storage.NewStaticQuerier(&databrokerpb.Record{
Type: "t1",
Id: "r1",
Version: 1,
})
res, err := storage.NewFallbackQuerier(q1, q2).Query(ctx, &databrokerpb.QueryRequest{
Type: "t1",
Limit: 1,
})
assert.NoError(t, err, "should fallback")
assert.Empty(t, cmp.Diff(&databrokerpb.QueryResponse{
Records: []*databrokerpb.Record{{Type: "t1", Id: "r1", Version: 1}},
TotalCount: 1,
RecordVersion: 1,
}, res, protocmp.Transform()))
}