From f61e7efe738aa3e0664592adae4eaa5ac75cd538 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Wed, 1 Jun 2022 15:40:07 -0600 Subject: [PATCH] authorize: use query instead of sync for databroker data (#3377) --- authorize/authorize.go | 35 ++- authorize/check_response_test.go | 14 +- authorize/databroker.go | 48 ++++ authorize/evaluator/evaluator.go | 3 - authorize/evaluator/evaluator_test.go | 116 +--------- authorize/evaluator/headers_evaluator_test.go | 9 +- authorize/evaluator/policy_evaluator_test.go | 10 +- authorize/grpc.go | 35 ++- authorize/grpc_test.go | 2 - authorize/internal/store/index.go | 196 ---------------- authorize/internal/store/index_test.go | 74 ------ authorize/internal/store/store.go | 110 +++------ authorize/internal/store/store_test.go | 83 ------- authorize/log.go | 52 +++-- authorize/sync.go | 209 ----------------- authorize/sync_test.go | 150 ------------- go.mod | 16 +- go.sum | 7 +- pkg/grpc/audit/audit.pb.go | 38 +--- pkg/grpc/audit/audit.proto | 2 - pkg/grpc/databroker/databroker.go | 22 ++ pkg/storage/cache.go | 155 +++++++++++++ pkg/storage/cache_test.go | 73 ++++++ pkg/storage/querier.go | 210 ++++++++++++++++++ 24 files changed, 661 insertions(+), 1008 deletions(-) create mode 100644 authorize/databroker.go delete mode 100644 authorize/internal/store/index.go delete mode 100644 authorize/internal/store/index_test.go delete mode 100644 authorize/internal/store/store_test.go delete mode 100644 authorize/sync.go delete mode 100644 authorize/sync_test.go create mode 100644 pkg/storage/cache.go create mode 100644 pkg/storage/cache_test.go create mode 100644 pkg/storage/querier.go diff --git a/authorize/authorize.go b/authorize/authorize.go index 34ec82331..ea980c8b1 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -8,6 +8,8 @@ import ( "sync" "time" + "golang.org/x/sync/errgroup" + "github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" @@ -17,6 +19,7 @@ import ( "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/storage" ) // Authorize struct holds @@ -25,8 +28,7 @@ type Authorize struct { store *store.Store currentOptions *config.AtomicOptions accessTracker *AccessTracker - - dataBrokerInitialSync chan struct{} + globalCache storage.Cache // The stateLock prevents updating the evaluator store simultaneously with an evaluation. // This should provide a consistent view of the data at a given server/record version and @@ -37,9 +39,9 @@ type Authorize struct { // New validates and creates a new Authorize service from a set of config options. func New(cfg *config.Config) (*Authorize, error) { a := &Authorize{ - currentOptions: config.NewAtomicOptions(), - store: store.New(), - dataBrokerInitialSync: make(chan struct{}), + currentOptions: config.NewAtomicOptions(), + store: store.New(), + globalCache: storage.NewGlobalCache(time.Minute), } a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod) @@ -59,19 +61,16 @@ func (a *Authorize) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli // Run runs the authorize service. func (a *Authorize) Run(ctx context.Context) error { - go a.accessTracker.Run(ctx) - _ = grpc.WaitForReady(ctx, a.state.Load().dataBrokerClientConnection, time.Second*10) - return newDataBrokerSyncer(a).Run(ctx) -} - -// WaitForInitialSync blocks until the initial sync is complete. -func (a *Authorize) WaitForInitialSync(ctx context.Context) error { - select { - case <-ctx.Done(): - return ctx.Err() - case <-a.dataBrokerInitialSync: - } - return nil + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + a.accessTracker.Run(ctx) + return nil + }) + eg.Go(func() error { + _ = grpc.WaitForReady(ctx, a.state.Load().dataBrokerClientConnection, time.Second*10) + return nil + }) + return eg.Wait() } func validateOptions(o *config.Options) error { diff --git a/authorize/check_response_test.go b/authorize/check_response_test.go index 66485afe6..d868461fa 100644 --- a/authorize/check_response_test.go +++ b/authorize/check_response_test.go @@ -20,8 +20,6 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/testutil" - "github.com/pomerium/pomerium/pkg/grpc/session" - "github.com/pomerium/pomerium/pkg/grpc/user" ) func TestAuthorize_okResponse(t *testing.T) { @@ -40,17 +38,7 @@ func TestAuthorize_okResponse(t *testing.T) { encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) a.state.Load().encoder = encoder a.currentOptions.Store(opt) - a.store = store.NewFromProtos(0, - &session.Session{ - Id: "SESSION_ID", - UserId: "USER_ID", - }, - &user.User{ - Id: "USER_ID", - Name: "foo", - Email: "foo@example.com", - }, - ) + a.store = store.New() pe, err := newPolicyEvaluator(opt, a.store) require.NoError(t, err) a.state.Load().evaluator = pe diff --git a/authorize/databroker.go b/authorize/databroker.go new file mode 100644 index 000000000..928494c5c --- /dev/null +++ b/authorize/databroker.go @@ -0,0 +1,48 @@ +package authorize + +import ( + "context" + + "github.com/open-policy-agent/opa/storage" + + "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" +) + +type sessionOrServiceAccount interface { + GetUserId() string +} + +func (a *Authorize) getDataBrokerSessionOrServiceAccount(ctx context.Context, sessionID string) (s sessionOrServiceAccount, err error) { + ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerSessionOrServiceAccount") + defer span.End() + + client := a.state.Load().dataBrokerClient + + s, err = session.Get(ctx, client, sessionID) + if storage.IsNotFound(err) { + s, err = user.GetServiceAccount(ctx, client, sessionID) + } + if err != nil { + return nil, err + } + + if _, ok := s.(*session.Session); ok { + a.accessTracker.TrackSessionAccess(sessionID) + } + if _, ok := s.(*user.ServiceAccount); ok { + a.accessTracker.TrackServiceAccountAccess(sessionID) + } + return s, nil +} + +func (a *Authorize) getDataBrokerUser(ctx context.Context, userID string) (u *user.User, err error) { + ctx, span := trace.StartSpan(ctx, "authorize.getDataBrokerUser") + defer span.End() + + client := a.state.Load().dataBrokerClient + + u, err = user.Get(ctx, client, userID) + return u, err +} diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index 8b24c7d2e..7efe7ad4f 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -72,8 +72,6 @@ type Result struct { Allow RuleResult Deny RuleResult Headers http.Header - - DataBrokerServerVersion, DataBrokerRecordVersion uint64 } // An Evaluator evaluates policies. @@ -170,7 +168,6 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error) Deny: policyOutput.Deny, Headers: headersOutput.Headers, } - res.DataBrokerServerVersion, res.DataBrokerRecordVersion = e.store.GetDataBrokerVersions() return res, nil } diff --git a/authorize/evaluator/evaluator_test.go b/authorize/evaluator/evaluator_test.go index 13e3c35fa..acf3c21cb 100644 --- a/authorize/evaluator/evaluator_test.go +++ b/authorize/evaluator/evaluator_test.go @@ -2,29 +2,24 @@ package evaluator import ( "context" - "fmt" - "math" "net/http" "net/url" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/pkg/cryptutil" - "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/directory" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/policy/criteria" "github.com/pomerium/pomerium/pkg/policy/parser" - "github.com/pomerium/pomerium/pkg/protoutil" + "github.com/pomerium/pomerium/pkg/storage" ) func TestEvaluator(t *testing.T) { @@ -36,13 +31,15 @@ func TestEvaluator(t *testing.T) { require.NoError(t, err) eval := func(t *testing.T, options []Option, data []proto.Message, req *Request) (*Result, error) { - store := store.NewFromProtos(math.MaxUint64, data...) + ctx := context.Background() + ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...)) + store := store.New() store.UpdateIssuer("authenticate.example.com") store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY")) store.UpdateSigningKey(privateJWK) - e, err := New(context.Background(), store, options...) + e, err := New(ctx, store, options...) require.NoError(t, err) - return e.Evaluate(context.Background(), req) + return e.Evaluate(ctx, req) } policies := []config.Policy{ @@ -511,104 +508,3 @@ func mustParseURL(str string) *url.URL { } return u } - -func BenchmarkEvaluator_Evaluate(b *testing.B) { - store := store.New() - - policies := []config.Policy{ - { - From: "https://from.example.com", - To: config.WeightedURLs{ - {URL: *mustParseURL("https://to.example.com")}, - }, - AllowedUsers: []string{"SOME_USER"}, - }, - } - options := []Option{ - WithAuthenticateURL("https://authn.example.com"), - WithPolicies(policies), - } - - e, err := New(context.Background(), store, options...) - if !assert.NoError(b, err) { - return - } - - lastSessionID := "" - - for i := 0; i < 100000; i++ { - sessionID := uuid.New().String() - lastSessionID = sessionID - userID := uuid.New().String() - data := protoutil.NewAny(&session.Session{ - Version: fmt.Sprint(i), - Id: sessionID, - UserId: userID, - IdToken: &session.IDToken{ - Issuer: "benchmark", - Subject: userID, - IssuedAt: timestamppb.Now(), - }, - OauthToken: &session.OAuthToken{ - AccessToken: "ACCESS TOKEN", - TokenType: "Bearer", - RefreshToken: "REFRESH TOKEN", - }, - }) - store.UpdateRecord(0, &databroker.Record{ - Version: uint64(i), - Type: "type.googleapis.com/session.Session", - Id: sessionID, - Data: data, - }) - data = protoutil.NewAny(&user.User{ - Version: fmt.Sprint(i), - Id: userID, - }) - store.UpdateRecord(0, &databroker.Record{ - Version: uint64(i), - Type: "type.googleapis.com/user.User", - Id: userID, - Data: data, - }) - - data = protoutil.NewAny(&directory.User{ - Version: fmt.Sprint(i), - Id: userID, - GroupIds: []string{"1", "2", "3", "4"}, - }) - store.UpdateRecord(0, &databroker.Record{ - Version: uint64(i), - Type: data.TypeUrl, - Id: userID, - Data: data, - }) - - data = protoutil.NewAny(&directory.Group{ - Version: fmt.Sprint(i), - Id: fmt.Sprint(i), - }) - store.UpdateRecord(0, &databroker.Record{ - Version: uint64(i), - Type: data.TypeUrl, - Id: fmt.Sprint(i), - Data: data, - }) - } - - b.ResetTimer() - ctx := context.Background() - for i := 0; i < b.N; i++ { - _, _ = e.Evaluate(ctx, &Request{ - Policy: &policies[0], - HTTP: RequestHTTP{ - Method: "GET", - URL: "https://example.com/path", - Headers: map[string]string{}, - }, - Session: RequestSession{ - ID: lastSessionID, - }, - }) - } -} diff --git a/authorize/evaluator/headers_evaluator_test.go b/authorize/evaluator/headers_evaluator_test.go index 3e0899965..f406f8265 100644 --- a/authorize/evaluator/headers_evaluator_test.go +++ b/authorize/evaluator/headers_evaluator_test.go @@ -18,6 +18,7 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/directory" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/storage" ) func TestNewHeadersRequestFromPolicy(t *testing.T) { @@ -51,13 +52,15 @@ func TestHeadersEvaluator(t *testing.T) { require.NoError(t, err) eval := func(t *testing.T, data []proto.Message, input *HeadersRequest) (*HeadersResponse, error) { - store := store.NewFromProtos(math.MaxUint64, data...) + ctx := context.Background() + ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...)) + store := store.New() store.UpdateIssuer("authenticate.example.com") store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY")) store.UpdateSigningKey(privateJWK) - e, err := NewHeadersEvaluator(context.Background(), store) + e, err := NewHeadersEvaluator(ctx, store) require.NoError(t, err) - return e.Evaluate(context.Background(), input) + return e.Evaluate(ctx, input) } t.Run("groups", func(t *testing.T) { diff --git a/authorize/evaluator/policy_evaluator_test.go b/authorize/evaluator/policy_evaluator_test.go index 4787685b0..5f530ce00 100644 --- a/authorize/evaluator/policy_evaluator_test.go +++ b/authorize/evaluator/policy_evaluator_test.go @@ -2,7 +2,6 @@ package evaluator import ( "context" - "math" "strings" "testing" @@ -18,6 +17,7 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/policy" "github.com/pomerium/pomerium/pkg/policy/criteria" + "github.com/pomerium/pomerium/pkg/storage" ) func TestPolicyEvaluator(t *testing.T) { @@ -29,13 +29,15 @@ func TestPolicyEvaluator(t *testing.T) { require.NoError(t, err) eval := func(t *testing.T, policy *config.Policy, data []proto.Message, input *PolicyRequest) (*PolicyResponse, error) { - store := store.NewFromProtos(math.MaxUint64, data...) + ctx := context.Background() + ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...)) + store := store.New() store.UpdateIssuer("authenticate.example.com") store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY")) store.UpdateSigningKey(privateJWK) - e, err := NewPolicyEvaluator(context.Background(), store, policy) + e, err := NewPolicyEvaluator(ctx, store, policy) require.NoError(t, err) - return e.Evaluate(context.Background(), input) + return e.Evaluate(ctx, input) } p1 := &config.Policy{ diff --git a/authorize/grpc.go b/authorize/grpc.go index a4db23c01..a668675bc 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -16,6 +16,8 @@ import ( "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/storage" ) // Check implements the envoy auth server gRPC endpoint. @@ -23,10 +25,16 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe ctx, span := trace.StartSpan(ctx, "authorize.grpc.Check") defer span.End() - // wait for the initial sync to complete so that data is available for evaluation - if err := a.WaitForInitialSync(ctx); err != nil { - return nil, err - } + querier := storage.NewTracingQuerier( + storage.NewCachingQuerier( + storage.NewCachingQuerier( + storage.NewQuerier(a.state.Load().dataBrokerClient), + a.globalCache, + ), + storage.NewLocalCache(), + ), + ) + ctx = storage.WithQuerier(ctx, querier) state := a.state.Load() @@ -48,10 +56,21 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe rawJWT, _ := loadRawSession(hreq, a.currentOptions.Load(), state.encoder) sessionState, _ := loadSession(state.encoder, rawJWT) - s, u, err := a.forceSync(ctx, sessionState) - if err != nil { - log.Warn(ctx).Err(err).Msg("clearing session due to force sync failed") - sessionState = nil + var s sessionOrServiceAccount + var u *user.User + if sessionState != nil { + s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID) + if err != nil { + log.Warn(ctx).Err(err).Msg("clearing session due to force sync failed") + sessionState = nil + } + } + if s != nil { + u, err = a.getDataBrokerUser(ctx, s.GetUserId()) + if err != nil { + log.Warn(ctx).Err(err).Msg("clearing session due to force sync failed") + sessionState = nil + } } req, err := a.getEvaluatorRequestFromCheckRequest(in, sessionState) diff --git a/authorize/grpc_test.go b/authorize/grpc_test.go index 3845ae827..6ffb6258d 100644 --- a/authorize/grpc_test.go +++ b/authorize/grpc_test.go @@ -337,8 +337,6 @@ func TestAuthorize_Check(t *testing.T) { } a.currentOptions.Store(&config.Options{ForwardAuthURLString: "https://forward-auth.example.com"}) - close(a.dataBrokerInitialSync) - cmpOpts := []cmp.Option{ cmpopts.IgnoreUnexported(envoy_service_auth_v3.CheckResponse{}), cmpopts.IgnoreUnexported(status.Status{}), diff --git a/authorize/internal/store/index.go b/authorize/internal/store/index.go deleted file mode 100644 index 60c3e29ba..000000000 --- a/authorize/internal/store/index.go +++ /dev/null @@ -1,196 +0,0 @@ -package store - -import ( - "sync" - - "github.com/kentik/patricia" - "github.com/kentik/patricia/string_tree" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/structpb" -) - -const ( - indexField = "$index" - cidrField = "cidr" -) - -type index struct { - mu sync.RWMutex - byType map[string]*recordIndex -} - -func newIndex() *index { - idx := new(index) - idx.clear() - return idx -} - -func (idx *index) clear() { - idx.mu.Lock() - defer idx.mu.Unlock() - idx.byType = map[string]*recordIndex{} -} - -func (idx *index) delete(typeURL, id string) { - idx.mu.Lock() - defer idx.mu.Unlock() - - ridx, ok := idx.byType[typeURL] - if !ok { - return - } - ridx.delete(id) - - if len(ridx.byID) == 0 { - delete(idx.byType, typeURL) - } -} - -func (idx *index) find(typeURL, id string) proto.Message { - idx.mu.RLock() - defer idx.mu.RUnlock() - - ridx, ok := idx.byType[typeURL] - if !ok { - return nil - } - return ridx.find(id) -} - -func (idx *index) get(typeURL, id string) proto.Message { - idx.mu.RLock() - defer idx.mu.RUnlock() - - ridx, ok := idx.byType[typeURL] - if !ok { - return nil - } - return ridx.get(id) -} - -func (idx *index) set(typeURL, id string, msg proto.Message) { - idx.mu.Lock() - defer idx.mu.Unlock() - - ridx, ok := idx.byType[typeURL] - if !ok { - ridx = newRecordIndex() - idx.byType[typeURL] = ridx - } - ridx.set(id, msg) -} - -// a recordIndex indexes records for of a specific type -type recordIndex struct { - byID map[string]proto.Message - byCIDRV4 *string_tree.TreeV4 - byCIDRV6 *string_tree.TreeV6 -} - -// newRecordIndex creates a new record index. -func newRecordIndex() *recordIndex { - return &recordIndex{ - byID: map[string]proto.Message{}, - byCIDRV4: string_tree.NewTreeV4(), - byCIDRV6: string_tree.NewTreeV6(), - } -} - -func (idx *recordIndex) delete(id string) { - r, ok := idx.byID[id] - if !ok { - return - } - - delete(idx.byID, id) - - addr4, addr6 := getIndexCIDR(r) - if addr4 != nil { - idx.byCIDRV4.Delete(*addr4, func(payload, val string) bool { - return payload == val - }, id) - } - if addr6 != nil { - idx.byCIDRV6.Delete(*addr6, func(payload, val string) bool { - return payload == val - }, id) - } -} - -func (idx *recordIndex) find(idOrString string) proto.Message { - r, ok := idx.byID[idOrString] - if ok { - return r - } - - addrv4, addrv6, _ := patricia.ParseIPFromString(idOrString) - if addrv4 != nil { - found, id := idx.byCIDRV4.FindDeepestTag(*addrv4) - if found { - return idx.byID[id] - } - } - if addrv6 != nil { - found, id := idx.byCIDRV6.FindDeepestTag(*addrv6) - if found { - return idx.byID[id] - } - } - - return nil -} - -func (idx *recordIndex) get(id string) proto.Message { - return idx.byID[id] -} - -func (idx *recordIndex) set(id string, msg proto.Message) { - _, ok := idx.byID[id] - if ok { - idx.delete(id) - } - - idx.byID[id] = msg - addr4, addr6 := getIndexCIDR(msg) - if addr4 != nil { - idx.byCIDRV4.Set(*addr4, id) - } - if addr6 != nil { - idx.byCIDRV6.Set(*addr6, id) - } -} - -func getIndexCIDR(msg proto.Message) (*patricia.IPv4Address, *patricia.IPv6Address) { - var s *structpb.Struct - if sv, ok := msg.(*structpb.Value); ok { - s = sv.GetStructValue() - } else { - s, _ = msg.(*structpb.Struct) - } - if s == nil { - return nil, nil - } - - f, ok := s.Fields[indexField] - if !ok { - return nil, nil - } - - obj := f.GetStructValue() - if obj == nil { - return nil, nil - } - - cf, ok := obj.Fields[cidrField] - if !ok { - return nil, nil - } - - c := cf.GetStringValue() - if c == "" { - return nil, nil - } - - addr4, addr6, _ := patricia.ParseIPFromString(c) - return addr4, addr6 -} diff --git a/authorize/internal/store/index_test.go b/authorize/internal/store/index_test.go deleted file mode 100644 index ba1e8121a..000000000 --- a/authorize/internal/store/index_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package store - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/types/known/structpb" -) - -func TestByID(t *testing.T) { - idx := newIndex() - - r1 := &structpb.Struct{Fields: map[string]*structpb.Value{ - "id": structpb.NewStringValue("r1"), - }} - - idx.set("example.com/record", "r1", r1) - assert.Equal(t, r1, idx.get("example.com/record", "r1")) - idx.delete("example.com/record", "r1") - assert.Nil(t, idx.get("example.com/record", "r1")) -} - -func TestByCIDR(t *testing.T) { - t.Run("ipv4", func(t *testing.T) { - idx := newIndex() - - r1 := &structpb.Struct{Fields: map[string]*structpb.Value{ - "$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{ - "cidr": structpb.NewStringValue("192.168.0.0/16"), - }}), - "id": structpb.NewStringValue("r1"), - }} - idx.set("example.com/record", "r1", r1) - - r2 := &structpb.Struct{Fields: map[string]*structpb.Value{ - "$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{ - "cidr": structpb.NewStringValue("192.168.0.0/24"), - }}), - "id": structpb.NewStringValue("r2"), - }} - idx.set("example.com/record", "r2", r2) - - assert.Equal(t, r2, idx.find("example.com/record", "192.168.0.7")) - idx.delete("example.com/record", "r2") - assert.Equal(t, r1, idx.find("example.com/record", "192.168.0.7")) - idx.delete("example.com/record", "r1") - assert.Nil(t, idx.find("example.com/record", "192.168.0.7")) - }) - t.Run("ipv6", func(t *testing.T) { - idx := newIndex() - - r1 := &structpb.Struct{Fields: map[string]*structpb.Value{ - "$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{ - "cidr": structpb.NewStringValue("2001:db8::/32"), - }}), - "id": structpb.NewStringValue("r1"), - }} - idx.set("example.com/record", "r1", r1) - - r2 := &structpb.Struct{Fields: map[string]*structpb.Value{ - "$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{ - "cidr": structpb.NewStringValue("2001:db8::/48"), - }}), - "id": structpb.NewStringValue("r2"), - }} - idx.set("example.com/record", "r2", r2) - - assert.Equal(t, r2, idx.find("example.com/record", "2001:db8::")) - idx.delete("example.com/record", "r2") - assert.Equal(t, r1, idx.find("example.com/record", "2001:db8::")) - idx.delete("example.com/record", "r1") - assert.Nil(t, idx.find("example.com/record", "2001:db8::")) - }) -} diff --git a/authorize/internal/store/store.go b/authorize/internal/store/store.go index 33652be57..e4ad99f4b 100644 --- a/authorize/internal/store/store.go +++ b/authorize/internal/store/store.go @@ -5,78 +5,33 @@ import ( "context" "encoding/json" "fmt" - "sync/atomic" "github.com/go-jose/go-jose/v3" - "github.com/google/uuid" "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/rego" - "github.com/open-policy-agent/opa/storage" + opastorage "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/storage/inmem" "github.com/open-policy-agent/opa/types" "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/protoutil" + "github.com/pomerium/pomerium/pkg/storage" ) // A Store stores data for the OPA rego policy evaluation. type Store struct { - storage.Store - index *index - - dataBrokerServerVersion, dataBrokerRecordVersion uint64 + opastorage.Store } // New creates a new Store. func New() *Store { return &Store{ Store: inmem.New(), - index: newIndex(), } } -// NewFromProtos creates a new Store from an existing set of protobuf messages. -func NewFromProtos(serverVersion uint64, msgs ...proto.Message) *Store { - s := New() - for _, msg := range msgs { - any := protoutil.NewAny(msg) - record := new(databroker.Record) - record.ModifiedAt = timestamppb.Now() - record.Version = cryptutil.NewRandomUInt64() - record.Id = uuid.New().String() - record.Data = any - record.Type = any.TypeUrl - if hasID, ok := msg.(interface{ GetId() string }); ok { - record.Id = hasID.GetId() - } - - s.UpdateRecord(serverVersion, record) - } - return s -} - -// ClearRecords removes all the records from the store. -func (s *Store) ClearRecords() { - s.index.clear() -} - -// GetDataBrokerVersions gets the databroker versions. -func (s *Store) GetDataBrokerVersions() (serverVersion, recordVersion uint64) { - return atomic.LoadUint64(&s.dataBrokerServerVersion), - atomic.LoadUint64(&s.dataBrokerRecordVersion) -} - -// GetRecordData gets a record's data from the store. `nil` is returned -// if no record exists for the given type and id. -func (s *Store) GetRecordData(typeURL, idOrValue string) proto.Message { - return s.index.find(typeURL, idOrValue) -} - // UpdateIssuer updates the issuer in the store. The issuer is used as part of JWT construction. func (s *Store) UpdateIssuer(issuer string) { s.write("/issuer", issuer) @@ -98,20 +53,6 @@ func (s *Store) UpdateRoutePolicies(routePolicies []config.Policy) { s.write("/route_policies", routePolicies) } -// UpdateRecord updates a record in the store. -func (s *Store) UpdateRecord(serverVersion uint64, record *databroker.Record) { - if record.GetDeletedAt() != nil { - s.index.delete(record.GetType(), record.GetId()) - } else { - msg, _ := record.GetData().UnmarshalNew() - s.index.set(record.GetType(), record.GetId(), msg) - } - s.write("/databroker_server_version", fmt.Sprint(serverVersion)) - s.write("/databroker_record_version", fmt.Sprint(record.GetVersion())) - atomic.StoreUint64(&s.dataBrokerServerVersion, serverVersion) - atomic.StoreUint64(&s.dataBrokerRecordVersion, record.GetVersion()) -} - // UpdateSigningKey updates the signing key stored in the database. Signing operations // in rego use JWKs, so we take in that format. func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) { @@ -120,7 +61,7 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) { func (s *Store) write(rawPath string, value interface{}) { ctx := context.TODO() - err := storage.Txn(ctx, s.Store, storage.WriteParams, func(txn storage.Transaction) error { + err := opastorage.Txn(ctx, s.Store, opastorage.WriteParams, func(txn opastorage.Transaction) error { return s.writeTxn(txn, rawPath, value) }) if err != nil { @@ -129,23 +70,23 @@ func (s *Store) write(rawPath string, value interface{}) { } } -func (s *Store) writeTxn(txn storage.Transaction, rawPath string, value interface{}) error { - p, ok := storage.ParsePath(rawPath) +func (s *Store) writeTxn(txn opastorage.Transaction, rawPath string, value interface{}) error { + p, ok := opastorage.ParsePath(rawPath) if !ok { return fmt.Errorf("invalid path") } if len(p) > 1 { - err := storage.MakeDir(context.Background(), s, txn, p[:len(p)-1]) + err := opastorage.MakeDir(context.Background(), s, txn, p[:len(p)-1]) if err != nil { return err } } - var op storage.PatchOp = storage.ReplaceOp + var op opastorage.PatchOp = opastorage.ReplaceOp _, err := s.Read(context.Background(), txn, p) - if storage.IsNotFound(err) { - op = storage.AddOp + if opastorage.IsNotFound(err) { + op = opastorage.AddOp } else if err != nil { return err } @@ -167,23 +108,42 @@ func (s *Store) GetDataBrokerRecordOption() func(*rego.Rego) { return nil, fmt.Errorf("invalid record type: %T", op1) } - recordID, ok := op2.Value.(ast.String) + value, ok := op2.Value.(ast.String) if !ok { return nil, fmt.Errorf("invalid record id: %T", op2) } - msg := s.GetRecordData(string(recordType), string(recordID)) - if msg == nil { + req := &databroker.QueryRequest{ + Type: string(recordType), + Limit: 1, + } + req.SetFilterByIDOrIndex(string(value)) + + res, err := storage.GetQuerier(bctx.Context).Query(bctx.Context, req) + if err != nil { + log.Error(bctx.Context).Err(err).Msg("authorize/store: error retrieving record") return ast.NullTerm(), nil } + + if len(res.GetRecords()) == 0 { + return ast.NullTerm(), nil + } + + msg, _ := res.GetRecords()[0].GetData().UnmarshalNew() + if msg == nil { + if msg == nil { + return ast.NullTerm(), nil + } + } obj := toMap(msg) - value, err := ast.InterfaceToValue(obj) + regoValue, err := ast.InterfaceToValue(obj) if err != nil { - return nil, err + log.Error(bctx.Context).Err(err).Msg("authorize/store: error converting object to rego") + return ast.NullTerm(), nil } - return ast.NewTerm(value), nil + return ast.NewTerm(regoValue), nil }) } diff --git a/authorize/internal/store/store_test.go b/authorize/internal/store/store_test.go deleted file mode 100644 index ca743dca8..000000000 --- a/authorize/internal/store/store_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package store - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/types/known/structpb" - "google.golang.org/protobuf/types/known/timestamppb" - - "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/grpc/user" - "github.com/pomerium/pomerium/pkg/protoutil" -) - -func TestStore(t *testing.T) { - t.Run("records", func(t *testing.T) { - s := New() - u := &user.User{ - Version: "v1", - Id: "u1", - Name: "name", - Email: "name@example.com", - } - any := protoutil.NewAny(u) - s.UpdateRecord(0, &databroker.Record{ - Version: 1, - Type: any.GetTypeUrl(), - Id: u.GetId(), - Data: any, - }) - - v := s.GetRecordData(any.GetTypeUrl(), u.GetId()) - assert.Equal(t, map[string]interface{}{ - "version": "v1", - "id": "u1", - "name": "name", - "email": "name@example.com", - }, toMap(v)) - - s.UpdateRecord(0, &databroker.Record{ - Version: 2, - Type: any.GetTypeUrl(), - Id: u.GetId(), - Data: any, - DeletedAt: timestamppb.Now(), - }) - - v = s.GetRecordData(any.GetTypeUrl(), u.GetId()) - assert.Nil(t, v) - - s.UpdateRecord(0, &databroker.Record{ - Version: 3, - Type: any.GetTypeUrl(), - Id: u.GetId(), - Data: any, - }) - - v = s.GetRecordData(any.GetTypeUrl(), u.GetId()) - assert.NotNil(t, v) - - s.ClearRecords() - v = s.GetRecordData(any.GetTypeUrl(), u.GetId()) - assert.Nil(t, v) - }) - t.Run("cidr", func(t *testing.T) { - s := New() - any := protoutil.NewAny(&structpb.Struct{Fields: map[string]*structpb.Value{ - "$index": structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{ - "cidr": structpb.NewStringValue("192.168.0.0/16"), - }}), - "id": structpb.NewStringValue("r1"), - }}) - s.UpdateRecord(0, &databroker.Record{ - Version: 1, - Type: any.GetTypeUrl(), - Id: "r1", - Data: any, - }) - - v := s.GetRecordData(any.GetTypeUrl(), "192.168.0.7") - assert.NotNil(t, v) - }) -} diff --git a/authorize/log.go b/authorize/log.go index 04483e7ec..7628b5277 100644 --- a/authorize/log.go +++ b/authorize/log.go @@ -12,9 +12,11 @@ import ( "github.com/pomerium/pomerium/internal/telemetry/requestid" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/grpc/audit" + "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/grpcutil" + "github.com/pomerium/pomerium/pkg/storage" ) func (a *Authorize) logAuthorizeCheck( @@ -39,7 +41,7 @@ func (a *Authorize) logAuthorizeCheck( // session information if s, ok := s.(*session.Session); ok { - evt = a.populateLogSessionDetails(evt, s) + evt = a.populateLogSessionDetails(ctx, evt, s) } if sa, ok := s.(*user.ServiceAccount); ok { evt = evt.Str("service-account-id", sa.GetId()) @@ -61,8 +63,6 @@ func (a *Authorize) logAuthorizeCheck( } evt = evt.Str("user", u.GetId()) evt = evt.Str("email", u.GetEmail()) - evt = evt.Uint64("databroker_server_version", res.DataBrokerServerVersion) - evt = evt.Uint64("databroker_record_version", res.DataBrokerRecordVersion) } // potentially sensitive, only log if debug mode @@ -80,10 +80,6 @@ func (a *Authorize) logAuthorizeCheck( Request: in, Response: out, } - if res != nil { - record.DatabrokerServerVersion = res.DataBrokerServerVersion - record.DatabrokerRecordVersion = res.DataBrokerRecordVersion - } sealed, err := enc.Encrypt(record) if err != nil { log.Warn(ctx).Err(err).Msg("authorize: error encrypting audit record") @@ -96,26 +92,50 @@ func (a *Authorize) logAuthorizeCheck( } } -func (a *Authorize) populateLogSessionDetails(evt *zerolog.Event, s *session.Session) *zerolog.Event { +func (a *Authorize) populateLogSessionDetails(ctx context.Context, evt *zerolog.Event, s *session.Session) *zerolog.Event { evt = evt.Str("session-id", s.GetId()) if s.GetImpersonateSessionId() == "" { return evt } + querier := storage.GetQuerier(ctx) + evt = evt.Str("impersonate-session-id", s.GetImpersonateSessionId()) - impersonatedSession, ok := a.store.GetRecordData( - grpcutil.GetTypeURL(new(session.Session)), - s.GetImpersonateSessionId(), - ).(*session.Session) + req := &databroker.QueryRequest{ + Type: grpcutil.GetTypeURL(new(session.Session)), + Limit: 1, + } + req.SetFilterByID(s.GetImpersonateSessionId()) + res, err := querier.Query(ctx, req) + if err != nil || len(res.GetRecords()) == 0 { + return evt + } + + impersonatedSessionMsg, err := res.GetRecords()[0].GetData().UnmarshalNew() + if err != nil { + return evt + } + impersonatedSession, ok := impersonatedSessionMsg.(*session.Session) if !ok { return evt } evt = evt.Str("impersonate-user-id", impersonatedSession.GetUserId()) - impersonatedUser, ok := a.store.GetRecordData( - grpcutil.GetTypeURL(new(user.User)), - impersonatedSession.GetUserId(), - ).(*user.User) + req = &databroker.QueryRequest{ + Type: grpcutil.GetTypeURL(new(user.User)), + Limit: 1, + } + req.SetFilterByID(impersonatedSession.GetUserId()) + res, err = querier.Query(ctx, req) + if err != nil || len(res.GetRecords()) == 0 { + return evt + } + + impersonatedUserMsg, err := res.GetRecords()[0].GetData().UnmarshalNew() + if err != nil { + return evt + } + impersonatedUser, ok := impersonatedUserMsg.(*user.User) if !ok { return evt } diff --git a/authorize/sync.go b/authorize/sync.go deleted file mode 100644 index ed3f0d385..000000000 --- a/authorize/sync.go +++ /dev/null @@ -1,209 +0,0 @@ -package authorize - -import ( - "context" - "errors" - "sync" - "time" - - "github.com/cenkalti/backoff/v4" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/proto" - - "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/sessions" - "github.com/pomerium/pomerium/internal/telemetry/trace" - "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/grpc/session" - "github.com/pomerium/pomerium/pkg/grpc/user" - "github.com/pomerium/pomerium/pkg/grpcutil" -) - -const ( - forceSyncRecordMaxWait = 5 * time.Second -) - -type sessionOrServiceAccount interface { - GetUserId() string -} - -type dataBrokerSyncer struct { - *databroker.Syncer - authorize *Authorize - signalOnce sync.Once -} - -func newDataBrokerSyncer(authorize *Authorize) *dataBrokerSyncer { - syncer := &dataBrokerSyncer{ - authorize: authorize, - } - syncer.Syncer = databroker.NewSyncer("authorize", syncer) - return syncer -} - -func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { - return syncer.authorize.state.Load().dataBrokerClient -} - -func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) { - syncer.authorize.stateLock.Lock() - syncer.authorize.store.ClearRecords() - syncer.authorize.stateLock.Unlock() -} - -func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) { - syncer.authorize.stateLock.Lock() - for _, record := range records { - syncer.authorize.store.UpdateRecord(serverVersion, record) - } - syncer.authorize.stateLock.Unlock() - - // the first time we update records we signal the initial sync - syncer.signalOnce.Do(func() { - log.Info(ctx).Msg("initial sync from databroker complete") - close(syncer.authorize.dataBrokerInitialSync) - }) -} - -func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) (sessionOrServiceAccount, *user.User, error) { - ctx, span := trace.StartSpan(ctx, "authorize.forceSync") - defer span.End() - if ss == nil { - return nil, nil, nil - } - - // if the session state has databroker versions, wait for those to finish syncing - if ss.DatabrokerServerVersion != 0 && ss.DatabrokerRecordVersion != 0 { - a.forceSyncToVersion(ctx, ss.DatabrokerServerVersion, ss.DatabrokerRecordVersion) - } - - s := a.forceSyncSession(ctx, ss.ID) - if s == nil { - return nil, nil, errors.New("session not found") - } - u := a.forceSyncUser(ctx, s.GetUserId()) - return s, u, nil -} - -func (a *Authorize) forceSyncToVersion(ctx context.Context, serverVersion, recordVersion uint64) (ready bool) { - ctx, span := trace.StartSpan(ctx, "authorize.forceSyncToVersion") - defer span.End() - - ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait) - defer clearTimeout() - - ticker := time.NewTicker(time.Millisecond * 50) - defer ticker.Stop() - for { - currentServerVersion, currentRecordVersion := a.store.GetDataBrokerVersions() - // check if the local record version is up to date with the expected record version - if currentServerVersion == serverVersion && currentRecordVersion >= recordVersion { - return true - } - - select { - case <-ctx.Done(): - return false - case <-ticker.C: - } - } -} - -func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) sessionOrServiceAccount { - ctx, span := trace.StartSpan(ctx, "authorize.forceSyncSession") - defer span.End() - - ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait) - defer clearTimeout() - - s, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(session.Session)), sessionID).(*session.Session) - if ok { - a.accessTracker.TrackSessionAccess(sessionID) - return s - } - - sa, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID).(*user.ServiceAccount) - if ok { - a.accessTracker.TrackServiceAccountAccess(sessionID) - return sa - } - - // wait for the session to show up - record, err := a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID) - if err != nil { - return nil - } - s, ok = record.(*session.Session) - if !ok { - return nil - } - a.accessTracker.TrackSessionAccess(sessionID) - return s -} - -func (a *Authorize) forceSyncUser(ctx context.Context, userID string) *user.User { - ctx, span := trace.StartSpan(ctx, "authorize.forceSyncUser") - defer span.End() - - ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait) - defer clearTimeout() - - u, ok := a.store.GetRecordData(grpcutil.GetTypeURL(new(user.User)), userID).(*user.User) - if ok { - return u - } - - // wait for the user to show up - record, err := a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(user.User)), userID) - if err != nil { - return nil - } - u, ok = record.(*user.User) - if !ok { - return nil - } - return u -} - -// waitForRecordSync waits for the first sync of a record to complete -func (a *Authorize) waitForRecordSync(ctx context.Context, recordTypeURL, recordID string) (proto.Message, error) { - bo := backoff.NewExponentialBackOff() - bo.InitialInterval = time.Millisecond - bo.MaxElapsedTime = 0 - bo.Reset() - - for { - current := a.store.GetRecordData(recordTypeURL, recordID) - if current != nil { - // record found, so it's already synced - return current, nil - } - - _, err := a.state.Load().dataBrokerClient.Get(ctx, &databroker.GetRequest{ - Type: recordTypeURL, - Id: recordID, - }) - if status.Code(err) == codes.NotFound { - // record not found, so no need to wait - return nil, nil - } else if err != nil { - log.Error(ctx). - Err(err). - Str("type", recordTypeURL). - Str("id", recordID). - Msg("authorize: error retrieving record") - return nil, err - } - - select { - case <-ctx.Done(): - log.Warn(ctx). - Str("type", recordTypeURL). - Str("id", recordID). - Msg("authorize: first sync of record did not complete") - return nil, ctx.Err() - case <-time.After(bo.NextBackOff()): - } - } -} diff --git a/authorize/sync_test.go b/authorize/sync_test.go deleted file mode 100644 index 50843e07e..000000000 --- a/authorize/sync_test.go +++ /dev/null @@ -1,150 +0,0 @@ -package authorize - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/proto" - - "github.com/pomerium/pomerium/config" - "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/grpc/session" - "github.com/pomerium/pomerium/pkg/grpcutil" - "github.com/pomerium/pomerium/pkg/protoutil" -) - -func TestAuthorize_forceSyncToVersion(t *testing.T) { - o := &config.Options{ - AuthenticateURLString: "https://authN.example.com", - DataBrokerURLString: "https://databroker.example.com", - SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=", - Policies: testPolicies(t), - } - a, err := New(&config.Config{Options: o}) - require.NoError(t, err) - - a.store.UpdateRecord(1, &databroker.Record{ - Version: 1, - }) - t.Run("ready", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - assert.True(t, a.forceSyncToVersion(ctx, 1, 1)) - }) - t.Run("not ready", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - assert.False(t, a.forceSyncToVersion(ctx, 1, 2)) - }) - t.Run("becomes ready", func(t *testing.T) { - ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) - defer clearTimeout() - - go func() { - <-time.After(time.Millisecond * 100) - a.store.UpdateRecord(1, &databroker.Record{ - Version: 2, - }) - }() - assert.True(t, a.forceSyncToVersion(ctx, 1, 2)) - }) -} - -func TestAuthorize_waitForRecordSync(t *testing.T) { - ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30) - defer clearTimeout() - - o := &config.Options{ - AuthenticateURLString: "https://authN.example.com", - DataBrokerURLString: "https://databroker.example.com", - SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=", - Policies: testPolicies(t), - } - t.Run("skip if exists", func(t *testing.T) { - a, err := New(&config.Config{Options: o}) - require.NoError(t, err) - - a.store.UpdateRecord(0, newRecord(&session.Session{ - Id: "SESSION_ID", - })) - a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{ - get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { - panic("should never be called") - }, - } - a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID") - }) - t.Run("skip if not found", func(t *testing.T) { - a, err := New(&config.Config{Options: o}) - require.NoError(t, err) - - callCount := 0 - a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{ - get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { - callCount++ - return nil, status.Error(codes.NotFound, "not found") - }, - } - a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID") - assert.Equal(t, 1, callCount, "should be called once") - }) - t.Run("poll", func(t *testing.T) { - a, err := New(&config.Config{Options: o}) - require.NoError(t, err) - - callCount := 0 - a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{ - get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { - callCount++ - switch callCount { - case 1: - s := &session.Session{Id: "SESSION_ID"} - a.store.UpdateRecord(0, newRecord(s)) - return &databroker.GetResponse{Record: newRecord(s)}, nil - default: - panic("should never be called") - } - }, - } - a.waitForRecordSync(ctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID") - }) - t.Run("timeout", func(t *testing.T) { - a, err := New(&config.Config{Options: o}) - require.NoError(t, err) - - tctx, clearTimeout := context.WithTimeout(ctx, time.Millisecond*100) - defer clearTimeout() - - callCount := 0 - a.state.Load().dataBrokerClient = mockDataBrokerServiceClient{ - get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) { - callCount++ - s := &session.Session{Id: "SESSION_ID"} - return &databroker.GetResponse{Record: newRecord(s)}, nil - }, - } - a.waitForRecordSync(tctx, grpcutil.GetTypeURL(new(session.Session)), "SESSION_ID") - assert.Greater(t, callCount, 5) // should be ~ 20, but allow for non-determinism - }) -} - -type storableMessage interface { - proto.Message - GetId() string -} - -func newRecord(msg storableMessage) *databroker.Record { - any := protoutil.NewAny(msg) - return &databroker.Record{ - Version: 1, - Type: any.GetTypeUrl(), - Id: msg.GetId(), - Data: any, - } -} diff --git a/go.mod b/go.mod index 5920475ab..c8a64a897 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,9 @@ require ( contrib.go.opencensus.io/exporter/jaeger v0.2.1 contrib.go.opencensus.io/exporter/prometheus v0.4.1 contrib.go.opencensus.io/exporter/zipkin v0.1.2 + github.com/CAFxX/httpcompression v0.0.8 github.com/DataDog/opencensus-go-exporter-datadog v0.0.0-20200406135749-5c268882acf0 + github.com/VictoriaMetrics/fastcache v1.10.0 github.com/caddyserver/certmagic v0.16.0 github.com/cenkalti/backoff/v4 v4.1.3 github.com/cespare/xxhash/v2 v2.1.2 @@ -29,9 +31,11 @@ require ( github.com/gorilla/handlers v1.5.1 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 - github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/golang-lru v0.5.4 + github.com/jackc/pgconn v1.12.1 + github.com/jackc/pgtype v1.11.0 + github.com/jackc/pgx/v4 v4.16.1 github.com/martinlindhe/base36 v1.1.0 github.com/mholt/acmez v1.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2 @@ -75,14 +79,6 @@ require ( sigs.k8s.io/yaml v1.3.0 ) -require ( - github.com/CAFxX/httpcompression v0.0.8 - github.com/jackc/pgconn v1.12.1 - github.com/jackc/pgtype v1.11.0 - github.com/jackc/pgx/v4 v4.16.1 - github.com/kentik/patricia v1.0.0 -) - require ( 4d63.com/gochecknoglobals v0.1.0 // indirect cloud.google.com/go/compute v1.6.1 // indirect @@ -148,6 +144,7 @@ require ( github.com/gofrs/flock v0.8.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + github.com/golang/snappy v0.0.4 // indirect github.com/golangci/check v0.0.0-20180506172741-cfe4005ccda2 // indirect github.com/golangci/dupl v0.0.0-20180902072040-3e9179ac440a // indirect github.com/golangci/go-misc v0.0.0-20220329215616-d24fe342adfe // indirect @@ -166,6 +163,7 @@ require ( github.com/gostaticanalysis/comment v1.4.2 // indirect github.com/gostaticanalysis/forcetypeassert v0.1.0 // indirect github.com/gostaticanalysis/nilerr v0.1.1 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-version v1.4.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hexops/gotextdiff v1.0.3 // indirect diff --git a/go.sum b/go.sum index e396d34ce..adb967cf9 100644 --- a/go.sum +++ b/go.sum @@ -165,6 +165,8 @@ github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWX github.com/Shopify/sarama v1.30.0/go.mod h1:zujlQQx1kzHsh4jfV1USnptCQrHAEZ2Hk8fTKCulPVs= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= github.com/Shopify/toxiproxy/v2 v2.1.6-0.20210914104332-15ea381dcdae/go.mod h1:/cvHQkZ1fst0EmZnA5dFtiQdWCNCFYzb+uE2vqVgvx0= +github.com/VictoriaMetrics/fastcache v1.10.0 h1:5hDJnLsKLpnUEToub7ETuRu8RCkb40woBZAUiKonXzY= +github.com/VictoriaMetrics/fastcache v1.10.0/go.mod h1:tjiYeEfYXCqacuvYw/7UoDIeJaNxq6132xHICNP77w8= github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= @@ -175,6 +177,8 @@ github.com/alexflint/go-filemutex v0.0.0-20171022225611-72bdc8eae2ae/go.mod h1:C github.com/alexflint/go-filemutex v1.1.0/go.mod h1:7P4iRhttt/nUvUOrYIhcpMzv2G6CY9UnI16Z+UJqRyk= github.com/alexkohler/prealloc v1.0.0 h1:Hbq0/3fJPQhNkN0dR95AVrr6R7tou91y0uHG5pOcUuw= github.com/alexkohler/prealloc v1.0.0/go.mod h1:VetnK3dIgFBBKmg0YnD9F9x6Icjd+9cvfHR56wJVlKE= +github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8= +github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM= github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ= github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= @@ -995,8 +999,6 @@ github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8 github.com/julz/importas v0.1.0 h1:F78HnrsjY3cR7j0etXy5+TU1Zuy7Xt08X/1aJnH5xXY= github.com/julz/importas v0.1.0/go.mod h1:oSFU2R4XK/P7kNBrnL/FEQlDGN1/6WoxXEjSSXO0DV0= github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k= -github.com/kentik/patricia v1.0.0 h1:jx/8kXf0JvQEHNPX4njL+PDzpxxqNKg0RjA8hJcX38A= -github.com/kentik/patricia v1.0.0/go.mod h1:e0nkPLU9NQl8v05ukfHU6+R5ykbKcXO+NqaC3ifTm0Y= github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= @@ -2054,6 +2056,7 @@ golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220502124256-b6088ccd6cba/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/pkg/grpc/audit/audit.pb.go b/pkg/grpc/audit/audit.pb.go index ef838991e..83c4e6038 100644 --- a/pkg/grpc/audit/audit.pb.go +++ b/pkg/grpc/audit/audit.pb.go @@ -26,10 +26,8 @@ type Record struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Request *v3.CheckRequest `protobuf:"bytes,1,opt,name=request,proto3" json:"request,omitempty"` - Response *v3.CheckResponse `protobuf:"bytes,2,opt,name=response,proto3" json:"response,omitempty"` - DatabrokerServerVersion uint64 `protobuf:"varint,3,opt,name=databroker_server_version,json=databrokerServerVersion,proto3" json:"databroker_server_version,omitempty"` - DatabrokerRecordVersion uint64 `protobuf:"varint,4,opt,name=databroker_record_version,json=databrokerRecordVersion,proto3" json:"databroker_record_version,omitempty"` + Request *v3.CheckRequest `protobuf:"bytes,1,opt,name=request,proto3" json:"request,omitempty"` + Response *v3.CheckResponse `protobuf:"bytes,2,opt,name=response,proto3" json:"response,omitempty"` } func (x *Record) Reset() { @@ -78,20 +76,6 @@ func (x *Record) GetResponse() *v3.CheckResponse { return nil } -func (x *Record) GetDatabrokerServerVersion() uint64 { - if x != nil { - return x.DatabrokerServerVersion - } - return 0 -} - -func (x *Record) GetDatabrokerRecordVersion() uint64 { - if x != nil { - return x.DatabrokerRecordVersion - } - return 0 -} - var File_audit_proto protoreflect.FileDescriptor var file_audit_proto_rawDesc = []byte{ @@ -99,7 +83,7 @@ var file_audit_proto_rawDesc = []byte{ 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x61, 0x75, 0x64, 0x69, 0x74, 0x1a, 0x29, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x2f, 0x76, 0x33, 0x2f, 0x65, 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x5f, 0x61, 0x75, - 0x74, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x81, 0x02, 0x0a, 0x06, 0x52, 0x65, 0x63, + 0x74, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x89, 0x01, 0x0a, 0x06, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x3d, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x23, 0x2e, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x33, 0x2e, 0x43, 0x68, 0x65, @@ -108,18 +92,10 @@ var file_audit_proto_rawDesc = []byte{ 0x20, 0x01, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x33, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x52, 0x08, 0x72, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3a, 0x0a, 0x19, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x6f, 0x6b, - 0x65, 0x72, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, - 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x04, 0x52, 0x17, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x6f, - 0x6b, 0x65, 0x72, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, - 0x12, 0x3a, 0x0a, 0x19, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x5f, 0x72, - 0x65, 0x63, 0x6f, 0x72, 0x64, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x04, 0x52, 0x17, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x52, - 0x65, 0x63, 0x6f, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x42, 0x2d, 0x5a, 0x2b, - 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, - 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, - 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x75, 0x64, 0x69, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x33, + 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x2d, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, + 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, + 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x75, + 0x64, 0x69, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/pkg/grpc/audit/audit.proto b/pkg/grpc/audit/audit.proto index c97b361f1..cd86e8121 100644 --- a/pkg/grpc/audit/audit.proto +++ b/pkg/grpc/audit/audit.proto @@ -8,6 +8,4 @@ import "envoy/service/auth/v3/external_auth.proto"; message Record { envoy.service.auth.v3.CheckRequest request = 1; envoy.service.auth.v3.CheckResponse response = 2; - uint64 databroker_server_version = 3; - uint64 databroker_record_version = 4; } diff --git a/pkg/grpc/databroker/databroker.go b/pkg/grpc/databroker/databroker.go index 3e7bf4840..d191a946b 100644 --- a/pkg/grpc/databroker/databroker.go +++ b/pkg/grpc/databroker/databroker.go @@ -7,6 +7,7 @@ import ( "io" "google.golang.org/protobuf/proto" + structpb "google.golang.org/protobuf/types/known/structpb" "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/protoutil" @@ -118,6 +119,27 @@ func (x *PutResponse) GetRecord() *Record { return records[0] } +// SetFilterByID sets the filter to an id. +func (x *QueryRequest) SetFilterByID(id string) { + x.Filter = &structpb.Struct{Fields: map[string]*structpb.Value{ + "id": structpb.NewStringValue(id), + }} +} + +// SetFilterByIDOrIndex sets the filter to an id or an index. +func (x *QueryRequest) SetFilterByIDOrIndex(idOrIndex string) { + x.Filter = &structpb.Struct{Fields: map[string]*structpb.Value{ + "$or": structpb.NewListValue(&structpb.ListValue{Values: []*structpb.Value{ + structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{ + "id": structpb.NewStringValue(idOrIndex), + }}), + structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{ + "$index": structpb.NewStringValue(idOrIndex), + }}), + }}), + }} +} + // default is 4MB, but we'll do 1MB const maxMessageSize = 1024 * 1024 * 1 diff --git a/pkg/storage/cache.go b/pkg/storage/cache.go new file mode 100644 index 000000000..b780e1e81 --- /dev/null +++ b/pkg/storage/cache.go @@ -0,0 +1,155 @@ +package storage + +import ( + "context" + "encoding/binary" + "sync" + "time" + + "github.com/VictoriaMetrics/fastcache" + "golang.org/x/sync/singleflight" +) + +// A Cache will return cached data when available or call update when not. +type Cache interface { + GetOrUpdate( + ctx context.Context, + key []byte, + update func(ctx context.Context) ([]byte, error), + ) ([]byte, error) + Invalidate(key []byte) +} + +type localCache struct { + singleflight singleflight.Group + mu sync.RWMutex + m map[string][]byte +} + +// NewLocalCache creates a new Cache backed by a map. +func NewLocalCache() Cache { + return &localCache{ + m: make(map[string][]byte), + } +} + +func (cache *localCache) GetOrUpdate( + ctx context.Context, + key []byte, + update func(ctx context.Context) ([]byte, error), +) ([]byte, error) { + strkey := string(key) + + cache.mu.RLock() + cached, ok := cache.m[strkey] + cache.mu.RUnlock() + if ok { + return cached, nil + } + + v, err, _ := cache.singleflight.Do(strkey, func() (interface{}, error) { + cache.mu.RLock() + cached, ok := cache.m[strkey] + cache.mu.RUnlock() + if ok { + return cached, nil + } + + result, err := update(ctx) + if err != nil { + return nil, err + } + + cache.mu.Lock() + cache.m[strkey] = result + cache.mu.Unlock() + + return result, nil + }) + if err != nil { + return nil, err + } + return v.([]byte), nil +} + +func (cache *localCache) Invalidate(key []byte) { + cache.mu.Lock() + delete(cache.m, string(key)) + cache.mu.Unlock() +} + +type globalCache struct { + ttl time.Duration + + singleflight singleflight.Group + mu sync.RWMutex + fastcache *fastcache.Cache +} + +// NewGlobalCache creates a new Cache backed by fastcache and a TTL. +func NewGlobalCache(ttl time.Duration) Cache { + return &globalCache{ + ttl: ttl, + fastcache: fastcache.New(256 * 1024 * 1024), // up to 256MB of RAM + } +} + +func (cache *globalCache) GetOrUpdate( + ctx context.Context, + key []byte, + update func(ctx context.Context) ([]byte, error), +) ([]byte, error) { + now := time.Now() + data, expiry, ok := cache.get(key) + if ok && now.Before(expiry) { + return data, nil + } + + v, err, _ := cache.singleflight.Do(string(key), func() (interface{}, error) { + data, expiry, ok := cache.get(key) + if ok && now.Before(expiry) { + return data, nil + } + + value, err := update(ctx) + if err != nil { + return nil, err + } + cache.set(key, value) + return value, nil + }) + if err != nil { + return nil, err + } + return v.([]byte), nil +} + +func (cache *globalCache) Invalidate(key []byte) { + cache.mu.Lock() + cache.fastcache.Del(key) + cache.mu.Unlock() +} + +func (cache *globalCache) get(k []byte) (data []byte, expiry time.Time, ok bool) { + cache.mu.RLock() + item := cache.fastcache.Get(nil, k) + cache.mu.RUnlock() + if len(item) < 8 { + return data, expiry, false + } + + unix, data := binary.LittleEndian.Uint64(item), item[8:] + expiry = time.UnixMilli(int64(unix)) + return data, expiry, true +} + +func (cache *globalCache) set(k, v []byte) { + unix := time.Now().Add(cache.ttl).UnixMilli() + item := make([]byte, len(v)+8) + binary.LittleEndian.PutUint64(item, uint64(unix)) + copy(item[8:], v) + + cache.mu.Lock() + cache.fastcache.Set(k, item) + cache.mu.Unlock() +} diff --git a/pkg/storage/cache_test.go b/pkg/storage/cache_test.go new file mode 100644 index 000000000..1f73db6e8 --- /dev/null +++ b/pkg/storage/cache_test.go @@ -0,0 +1,73 @@ +package storage + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestLocalCache(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + callCount := 0 + update := func(ctx context.Context) ([]byte, error) { + callCount++ + return []byte("v1"), nil + } + c := NewLocalCache() + v, err := c.GetOrUpdate(ctx, []byte("k1"), update) + assert.NoError(t, err) + assert.Equal(t, []byte("v1"), v) + assert.Equal(t, 1, callCount) + + v, err = c.GetOrUpdate(ctx, []byte("k1"), update) + assert.NoError(t, err) + assert.Equal(t, []byte("v1"), v) + assert.Equal(t, 1, callCount) + + c.Invalidate([]byte("k1")) + + v, err = c.GetOrUpdate(ctx, []byte("k1"), update) + assert.NoError(t, err) + assert.Equal(t, []byte("v1"), v) + assert.Equal(t, 2, callCount) +} + +func TestGlobalCache(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + callCount := 0 + update := func(ctx context.Context) ([]byte, error) { + callCount++ + return []byte("v1"), nil + } + c := NewGlobalCache(time.Millisecond * 100) + v, err := c.GetOrUpdate(ctx, []byte("k1"), update) + assert.NoError(t, err) + assert.Equal(t, []byte("v1"), v) + assert.Equal(t, 1, callCount) + + v, err = c.GetOrUpdate(ctx, []byte("k1"), update) + assert.NoError(t, err) + assert.Equal(t, []byte("v1"), v) + assert.Equal(t, 1, callCount) + + c.Invalidate([]byte("k1")) + + v, err = c.GetOrUpdate(ctx, []byte("k1"), update) + assert.NoError(t, err) + assert.Equal(t, []byte("v1"), v) + assert.Equal(t, 2, callCount) + + assert.Eventually(t, func() bool { + _, err := c.GetOrUpdate(ctx, []byte("k1"), func(ctx context.Context) ([]byte, error) { + return nil, fmt.Errorf("ERROR") + }) + return err != nil + }, time.Second, time.Millisecond*10, "should honor TTL") +} diff --git a/pkg/storage/querier.go b/pkg/storage/querier.go new file mode 100644 index 000000000..4c8a992ab --- /dev/null +++ b/pkg/storage/querier.go @@ -0,0 +1,210 @@ +package storage + +import ( + "context" + "sync" + + "github.com/google/uuid" + grpc "google.golang.org/grpc" + "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +// A Querier is a read-only subset of the client methods +type Querier interface { + Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) +} + +// nilQuerier always returns NotFound. +type nilQuerier struct{} + +func (nilQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { + return nil, status.Error(codes.NotFound, "not found") +} + +var querierKey struct{} + +// GetQuerier gets the databroker Querier from the context. +func GetQuerier(ctx context.Context) Querier { + q, ok := ctx.Value(querierKey).(Querier) + if !ok { + q = nilQuerier{} + } + return q +} + +// WithQuerier sets the databroker Querier on a context. +func WithQuerier(ctx context.Context, querier Querier) context.Context { + return context.WithValue(ctx, querierKey, querier) +} + +type staticQuerier struct { + records []*databroker.Record +} + +// NewStaticQuerier creates a Querier that returns statically defined protobuf records. +func NewStaticQuerier(msgs ...proto.Message) Querier { + getter := &staticQuerier{} + for _, msg := range msgs { + any := protoutil.NewAny(msg) + record := new(databroker.Record) + record.ModifiedAt = timestamppb.Now() + record.Version = cryptutil.NewRandomUInt64() + record.Id = uuid.New().String() + record.Data = any + record.Type = any.TypeUrl + if hasID, ok := msg.(interface{ GetId() string }); ok { + record.Id = hasID.GetId() + } + getter.records = append(getter.records, record) + } + return getter +} + +// Query queries for records. +func (q *staticQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { + expr, err := FilterExpressionFromStruct(in.GetFilter()) + if err != nil { + return nil, err + } + + filter, err := RecordStreamFilterFromFilterExpression(expr) + if err != nil { + return nil, err + } + + res := new(databroker.QueryResponse) + for _, record := range q.records { + if record.GetType() != in.GetType() { + continue + } + + if !filter(record) { + continue + } + + if in.GetQuery() != "" && !MatchAny(record.GetData(), in.GetQuery()) { + continue + } + + res.Records = append(res.Records, record) + } + + var total int + res.Records, total = databroker.ApplyOffsetAndLimit( + res.Records, + int(in.GetOffset()), + int(in.GetLimit()), + ) + res.TotalCount = int64(total) + return res, nil +} + +type clientQuerier struct { + client databroker.DataBrokerServiceClient +} + +// NewQuerier creates a new Querier that implements the Querier interface by making calls to the databroker over gRPC. +func NewQuerier(client databroker.DataBrokerServiceClient) Querier { + return &clientQuerier{client: client} +} + +// Query queries for records. +func (q *clientQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { + return q.client.Query(ctx, in, opts...) +} + +// A TracingQuerier records calls to Query. +type TracingQuerier struct { + underlying Querier + + mu sync.Mutex + traces []QueryTrace +} + +// A QueryTrace traces a call to Query. +type QueryTrace struct { + ServerVersion, RecordVersion uint64 + + RecordType string + Query string + Filter *structpb.Struct +} + +// NewTracingQuerier creates a new TracingQuerier. +func NewTracingQuerier(q Querier) *TracingQuerier { + return &TracingQuerier{ + underlying: q, + } +} + +// Query queries for records. +func (q *TracingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { + res, err := q.underlying.Query(ctx, in, opts...) + if err == nil { + q.mu.Lock() + q.traces = append(q.traces, QueryTrace{ + RecordType: in.GetType(), + Query: in.GetQuery(), + Filter: in.GetFilter(), + }) + q.mu.Unlock() + } + return res, err +} + +// Traces returns all the traces. +func (q *TracingQuerier) Traces() []QueryTrace { + q.mu.Lock() + traces := make([]QueryTrace, len(q.traces)) + copy(traces, q.traces) + q.mu.Unlock() + return traces +} + +type cachingQuerier struct { + q Querier + cache Cache +} + +// NewCachingQuerier creates a new querier that caches results in a Cache. +func NewCachingQuerier(q Querier, cache Cache) Querier { + return &cachingQuerier{ + q: q, + cache: cache, + } +} + +func (q *cachingQuerier) Query(ctx context.Context, in *databroker.QueryRequest, opts ...grpc.CallOption) (*databroker.QueryResponse, error) { + key, err := (&proto.MarshalOptions{ + Deterministic: true, + }).Marshal(in) + if err != nil { + return nil, err + } + + rawResult, err := q.cache.GetOrUpdate(ctx, key, func(ctx context.Context) ([]byte, error) { + res, err := q.q.Query(ctx, in, opts...) + if err != nil { + return nil, err + } + return proto.Marshal(res) + }) + if err != nil { + return nil, err + } + + var res databroker.QueryResponse + err = proto.Unmarshal(rawResult, &res) + if err != nil { + return nil, err + } + return &res, nil +}