From 47890e9ee1ed55538e07b30e3ac01cad7c198db0 Mon Sep 17 00:00:00 2001 From: Kenneth Jenkins <51246568+kenjenkins@users.noreply.github.com> Date: Thu, 2 Nov 2023 11:03:00 -0700 Subject: [PATCH] storage/inmemory: implement patch operation (#4654) Add a new Patch() method that updates specific fields of an existing record's data, based on a field mask. Extract some logic from the existing Get() and Put() methods so it can be shared with the new Patch() method. --- pkg/storage/inmemory/backend.go | 90 ++++++++++++++++---- pkg/storage/inmemory/backend_test.go | 4 + pkg/storage/patch.go | 36 ++++++++ pkg/storage/patch_test.go | 45 ++++++++++ pkg/storage/storagetest/storagetest.go | 111 +++++++++++++++++++++++++ 5 files changed, 270 insertions(+), 16 deletions(-) create mode 100644 pkg/storage/patch.go create mode 100644 pkg/storage/patch_test.go create mode 100644 pkg/storage/storagetest/storagetest.go diff --git a/pkg/storage/inmemory/backend.go b/pkg/storage/inmemory/backend.go index 42199bca7..98357cda7 100644 --- a/pkg/storage/inmemory/backend.go +++ b/pkg/storage/inmemory/backend.go @@ -13,6 +13,7 @@ import ( "github.com/rs/zerolog" "golang.org/x/exp/maps" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/fieldmaskpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/log" @@ -130,18 +131,25 @@ func (backend *Backend) Close() error { func (backend *Backend) Get(_ context.Context, recordType, id string) (*databroker.Record, error) { backend.mu.RLock() defer backend.mu.RUnlock() + if record := backend.get(recordType, id); record != nil { + return record, nil + } + return nil, storage.ErrNotFound +} +// get gets a record from the in-memory store, assuming the RWMutex is held. +func (backend *Backend) get(recordType, id string) *databroker.Record { records := backend.lookup[recordType] if records == nil { - return nil, storage.ErrNotFound + return nil } record := records.Get(id) if record == nil { - return nil, storage.ErrNotFound + return nil } - return dup(record), nil + return dup(record) } // GetOptions returns the options for a type in the in-memory store. @@ -216,19 +224,7 @@ func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) ( Str("db_type", record.Type) }) - backend.recordChange(record) - - c, ok := backend.lookup[record.GetType()] - if !ok { - c = NewRecordCollection() - backend.lookup[record.GetType()] = c - } - - if record.GetDeletedAt() != nil { - c.Delete(record.GetId()) - } else { - c.Put(dup(record)) - } + backend.update(record) recordTypes[record.GetType()] = struct{}{} } @@ -239,6 +235,68 @@ func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) ( return backend.serverVersion, nil } +// update stores a record into the in-memory store, assuming the RWMutex is held. +func (backend *Backend) update(record *databroker.Record) { + backend.recordChange(record) + + c, ok := backend.lookup[record.GetType()] + if !ok { + c = NewRecordCollection() + backend.lookup[record.GetType()] = c + } + + if record.GetDeletedAt() != nil { + c.Delete(record.GetId()) + } else { + c.Put(dup(record)) + } +} + +// Patch updates the specified fields of existing record(s). +func (backend *Backend) Patch( + ctx context.Context, records []*databroker.Record, fields *fieldmaskpb.FieldMask, +) (serverVersion uint64, patchedRecords []*databroker.Record, err error) { + backend.mu.Lock() + defer backend.mu.Unlock() + defer backend.onChange.Broadcast(ctx) + + serverVersion = backend.serverVersion + patchedRecords = make([]*databroker.Record, 0, len(records)) + + for _, record := range records { + err = backend.patch(record, fields) + if storage.IsNotFound(err) { + // Skip any record that does not currently exist. + continue + } else if err != nil { + return + } + patchedRecords = append(patchedRecords, record) + } + + return +} + +// patch updates the specified fields of an existing record, assuming the RWMutex is held. +func (backend *Backend) patch(record *databroker.Record, fields *fieldmaskpb.FieldMask) error { + if record == nil { + return fmt.Errorf("cannot patch using a nil record") + } + + existing := backend.get(record.GetType(), record.GetId()) + if existing == nil { + return storage.ErrNotFound + } + + if err := storage.PatchRecord(existing, record, fields); err != nil { + return err + } + + backend.update(record) + + return nil +} + // SetOptions sets the options for a type in the in-memory store. func (backend *Backend) SetOptions(_ context.Context, recordType string, options *databroker.Options) error { backend.mu.Lock() diff --git a/pkg/storage/inmemory/backend_test.go b/pkg/storage/inmemory/backend_test.go index 643afd460..ebbdc7033 100644 --- a/pkg/storage/inmemory/backend_test.go +++ b/pkg/storage/inmemory/backend_test.go @@ -15,6 +15,7 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" + "github.com/pomerium/pomerium/pkg/storage/storagetest" ) func TestBackend(t *testing.T) { @@ -72,6 +73,9 @@ func TestBackend(t *testing.T) { assert.Error(t, err) assert.Nil(t, record) }) + t.Run("patch", func(t *testing.T) { + storagetest.TestBackendPatch(t, ctx, backend) + }) } func TestExpiry(t *testing.T) { diff --git a/pkg/storage/patch.go b/pkg/storage/patch.go new file mode 100644 index 000000000..19b0843d4 --- /dev/null +++ b/pkg/storage/patch.go @@ -0,0 +1,36 @@ +package storage + +import ( + "fmt" + + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/fieldmaskpb" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +// PatchRecord extracts the data from existing and record, updates the existing +// data subject to the provided field mask, and stores the result back into +// record. The existing record is not modified. +func PatchRecord(existing, record *databroker.Record, fields *fieldmaskpb.FieldMask) error { + dst, err := existing.GetData().UnmarshalNew() + if err != nil { + return fmt.Errorf("could not unmarshal existing record data: %w", err) + } + + src, err := record.GetData().UnmarshalNew() + if err != nil { + return fmt.Errorf("could not unmarshal new record data: %w", err) + } + + if err := protoutil.OverwriteMasked(dst, src, fields); err != nil { + return fmt.Errorf("cannot patch record: %w", err) + } + + record.Data, err = anypb.New(dst) + if err != nil { + return fmt.Errorf("could not marshal new record data: %w", err) + } + return nil +} diff --git a/pkg/storage/patch_test.go b/pkg/storage/patch_test.go new file mode 100644 index 000000000..9a868183c --- /dev/null +++ b/pkg/storage/patch_test.go @@ -0,0 +1,45 @@ +package storage_test + +import ( + "testing" + "time" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/fieldmaskpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/storage" +) + +func TestPatchRecord(t *testing.T) { + tm := timestamppb.New(time.Date(2023, 10, 31, 12, 0, 0, 0, time.UTC)) + + s1 := &session.Session{Id: "session-id"} + a1, _ := anypb.New(s1) + r1 := &databroker.Record{Data: a1} + + s2 := &session.Session{Id: "new-session-id", AccessedAt: tm} + a2, _ := anypb.New(s2) + r2 := &databroker.Record{Data: a2} + + originalR1 := proto.Clone(r1).(*databroker.Record) + + m, _ := fieldmaskpb.New(&session.Session{}, "accessed_at") + + storage.PatchRecord(r1, r2, m) + + testutil.AssertProtoJSONEqual(t, `{ + "data": { + "@type": "type.googleapis.com/session.Session", + "accessedAt": "2023-10-31T12:00:00Z", + "id": "session-id" + } + }`, r2) + + // The existing record should not be modified. + testutil.AssertProtoEqual(t, originalR1, r1) +} diff --git a/pkg/storage/storagetest/storagetest.go b/pkg/storage/storagetest/storagetest.go new file mode 100644 index 000000000..b8a290550 --- /dev/null +++ b/pkg/storage/storagetest/storagetest.go @@ -0,0 +1,111 @@ +// Package storagetest contains test cases for use in verifying the behavior of +// a storage.Backend implementation. +package storagetest + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/fieldmaskpb" + + "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/storage" +) + +// BackendWithPatch is a storage.Backend with an additional Patch() method. +// TODO: delete this type once Patch() is added to the storage.Backend interface +type BackendWithPatch interface { + storage.Backend + Patch(context.Context, []*databroker.Record, *fieldmaskpb.FieldMask) (uint64, []*databroker.Record, error) +} + +// TestBackendPatch verifies the behavior of the backend Patch() method. +func TestBackendPatch(t *testing.T, ctx context.Context, backend BackendWithPatch) { //nolint:revive + mkRecord := func(s *session.Session) *databroker.Record { + a, _ := anypb.New(s) + return &databroker.Record{ + Type: a.TypeUrl, + Id: s.Id, + Data: a, + } + } + + // Populate an initial set of session records. + s1 := &session.Session{ + Id: "session-1", + IdToken: &session.IDToken{Issuer: "issuer-1"}, + OauthToken: &session.OAuthToken{AccessToken: "access-token-1"}, + } + s2 := &session.Session{ + Id: "session-2", + IdToken: &session.IDToken{Issuer: "issuer-2"}, + OauthToken: &session.OAuthToken{AccessToken: "access-token-2"}, + } + s3 := &session.Session{ + Id: "session-3", + IdToken: &session.IDToken{Issuer: "issuer-3"}, + OauthToken: &session.OAuthToken{AccessToken: "access-token-3"}, + } + initial := []*databroker.Record{mkRecord(s1), mkRecord(s2), mkRecord(s3)} + + _, err := backend.Put(ctx, initial) + require.NoError(t, err) + + // Now patch just the oauth_token field. + u1 := &session.Session{ + Id: "session-1", + OauthToken: &session.OAuthToken{AccessToken: "access-token-1-new"}, + } + u2 := &session.Session{ + Id: "session-4-does-not-exist", + OauthToken: &session.OAuthToken{AccessToken: "access-token-4-new"}, + } + u3 := &session.Session{ + Id: "session-3", + OauthToken: &session.OAuthToken{AccessToken: "access-token-3-new"}, + } + + mask, _ := fieldmaskpb.New(&session.Session{}, "oauth_token") + + _, updated, err := backend.Patch( + ctx, []*databroker.Record{mkRecord(u1), mkRecord(u2), mkRecord(u3)}, mask) + require.NoError(t, err) + + // The OAuthToken message should be updated but the IDToken message should + // be unchanged, as it was not included in the field mask. The results + // should indicate that only two records were updated (one did not exist). + assert.Equal(t, 2, len(updated)) + assert.Greater(t, updated[0].Version, initial[0].Version) + assert.Greater(t, updated[1].Version, initial[2].Version) + testutil.AssertProtoJSONEqual(t, `{ + "@type": "type.googleapis.com/session.Session", + "id": "session-1", + "idToken": { + "issuer": "issuer-1" + }, + "oauthToken": { + "accessToken": "access-token-1-new" + } + }`, updated[0].Data) + testutil.AssertProtoJSONEqual(t, `{ + "@type": "type.googleapis.com/session.Session", + "id": "session-3", + "idToken": { + "issuer": "issuer-3" + }, + "oauthToken": { + "accessToken": "access-token-3-new" + } + }`, updated[1].Data) + + // Verify that the updates will indeed be seen by a subsequent Get(). + r1, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-1") + testutil.AssertProtoEqual(t, updated[0], r1) + r3, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-3") + testutil.AssertProtoEqual(t, updated[1], r3) +}