storage/inmemory: implement patch operation (#4654)

Add a new Patch() method that updates specific fields of an existing
record's data, based on a field mask.

Extract some logic from the existing Get() and Put() methods so it can
be shared with the new Patch() method.
This commit is contained in:
Kenneth Jenkins 2023-11-02 11:03:00 -07:00 committed by GitHub
parent 5f4e13e130
commit 47890e9ee1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 270 additions and 16 deletions

View file

@ -13,6 +13,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
@ -130,18 +131,25 @@ func (backend *Backend) Close() error {
func (backend *Backend) Get(_ context.Context, recordType, id string) (*databroker.Record, error) { func (backend *Backend) Get(_ context.Context, recordType, id string) (*databroker.Record, error) {
backend.mu.RLock() backend.mu.RLock()
defer backend.mu.RUnlock() defer backend.mu.RUnlock()
if record := backend.get(recordType, id); record != nil {
return record, nil
}
return nil, storage.ErrNotFound
}
// get gets a record from the in-memory store, assuming the RWMutex is held.
func (backend *Backend) get(recordType, id string) *databroker.Record {
records := backend.lookup[recordType] records := backend.lookup[recordType]
if records == nil { if records == nil {
return nil, storage.ErrNotFound return nil
} }
record := records.Get(id) record := records.Get(id)
if record == nil { if record == nil {
return nil, storage.ErrNotFound return nil
} }
return dup(record), nil return dup(record)
} }
// GetOptions returns the options for a type in the in-memory store. // GetOptions returns the options for a type in the in-memory store.
@ -216,19 +224,7 @@ func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) (
Str("db_type", record.Type) Str("db_type", record.Type)
}) })
backend.recordChange(record) backend.update(record)
c, ok := backend.lookup[record.GetType()]
if !ok {
c = NewRecordCollection()
backend.lookup[record.GetType()] = c
}
if record.GetDeletedAt() != nil {
c.Delete(record.GetId())
} else {
c.Put(dup(record))
}
recordTypes[record.GetType()] = struct{}{} recordTypes[record.GetType()] = struct{}{}
} }
@ -239,6 +235,68 @@ func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) (
return backend.serverVersion, nil return backend.serverVersion, nil
} }
// update stores a record into the in-memory store, assuming the RWMutex is held.
func (backend *Backend) update(record *databroker.Record) {
backend.recordChange(record)
c, ok := backend.lookup[record.GetType()]
if !ok {
c = NewRecordCollection()
backend.lookup[record.GetType()] = c
}
if record.GetDeletedAt() != nil {
c.Delete(record.GetId())
} else {
c.Put(dup(record))
}
}
// Patch updates the specified fields of existing record(s).
func (backend *Backend) Patch(
ctx context.Context, records []*databroker.Record, fields *fieldmaskpb.FieldMask,
) (serverVersion uint64, patchedRecords []*databroker.Record, err error) {
backend.mu.Lock()
defer backend.mu.Unlock()
defer backend.onChange.Broadcast(ctx)
serverVersion = backend.serverVersion
patchedRecords = make([]*databroker.Record, 0, len(records))
for _, record := range records {
err = backend.patch(record, fields)
if storage.IsNotFound(err) {
// Skip any record that does not currently exist.
continue
} else if err != nil {
return
}
patchedRecords = append(patchedRecords, record)
}
return
}
// patch updates the specified fields of an existing record, assuming the RWMutex is held.
func (backend *Backend) patch(record *databroker.Record, fields *fieldmaskpb.FieldMask) error {
if record == nil {
return fmt.Errorf("cannot patch using a nil record")
}
existing := backend.get(record.GetType(), record.GetId())
if existing == nil {
return storage.ErrNotFound
}
if err := storage.PatchRecord(existing, record, fields); err != nil {
return err
}
backend.update(record)
return nil
}
// SetOptions sets the options for a type in the in-memory store. // SetOptions sets the options for a type in the in-memory store.
func (backend *Backend) SetOptions(_ context.Context, recordType string, options *databroker.Options) error { func (backend *Backend) SetOptions(_ context.Context, recordType string, options *databroker.Options) error {
backend.mu.Lock() backend.mu.Lock()

View file

@ -15,6 +15,7 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/storage/storagetest"
) )
func TestBackend(t *testing.T) { func TestBackend(t *testing.T) {
@ -72,6 +73,9 @@ func TestBackend(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, record) assert.Nil(t, record)
}) })
t.Run("patch", func(t *testing.T) {
storagetest.TestBackendPatch(t, ctx, backend)
})
} }
func TestExpiry(t *testing.T) { func TestExpiry(t *testing.T) {

36
pkg/storage/patch.go Normal file
View file

@ -0,0 +1,36 @@
package storage
import (
"fmt"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
)
// PatchRecord extracts the data from existing and record, updates the existing
// data subject to the provided field mask, and stores the result back into
// record. The existing record is not modified.
func PatchRecord(existing, record *databroker.Record, fields *fieldmaskpb.FieldMask) error {
dst, err := existing.GetData().UnmarshalNew()
if err != nil {
return fmt.Errorf("could not unmarshal existing record data: %w", err)
}
src, err := record.GetData().UnmarshalNew()
if err != nil {
return fmt.Errorf("could not unmarshal new record data: %w", err)
}
if err := protoutil.OverwriteMasked(dst, src, fields); err != nil {
return fmt.Errorf("cannot patch record: %w", err)
}
record.Data, err = anypb.New(dst)
if err != nil {
return fmt.Errorf("could not marshal new record data: %w", err)
}
return nil
}

45
pkg/storage/patch_test.go Normal file
View file

@ -0,0 +1,45 @@
package storage_test
import (
"testing"
"time"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/storage"
)
func TestPatchRecord(t *testing.T) {
tm := timestamppb.New(time.Date(2023, 10, 31, 12, 0, 0, 0, time.UTC))
s1 := &session.Session{Id: "session-id"}
a1, _ := anypb.New(s1)
r1 := &databroker.Record{Data: a1}
s2 := &session.Session{Id: "new-session-id", AccessedAt: tm}
a2, _ := anypb.New(s2)
r2 := &databroker.Record{Data: a2}
originalR1 := proto.Clone(r1).(*databroker.Record)
m, _ := fieldmaskpb.New(&session.Session{}, "accessed_at")
storage.PatchRecord(r1, r2, m)
testutil.AssertProtoJSONEqual(t, `{
"data": {
"@type": "type.googleapis.com/session.Session",
"accessedAt": "2023-10-31T12:00:00Z",
"id": "session-id"
}
}`, r2)
// The existing record should not be modified.
testutil.AssertProtoEqual(t, originalR1, r1)
}

View file

@ -0,0 +1,111 @@
// Package storagetest contains test cases for use in verifying the behavior of
// a storage.Backend implementation.
package storagetest
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/storage"
)
// BackendWithPatch is a storage.Backend with an additional Patch() method.
// TODO: delete this type once Patch() is added to the storage.Backend interface
type BackendWithPatch interface {
storage.Backend
Patch(context.Context, []*databroker.Record, *fieldmaskpb.FieldMask) (uint64, []*databroker.Record, error)
}
// TestBackendPatch verifies the behavior of the backend Patch() method.
func TestBackendPatch(t *testing.T, ctx context.Context, backend BackendWithPatch) { //nolint:revive
mkRecord := func(s *session.Session) *databroker.Record {
a, _ := anypb.New(s)
return &databroker.Record{
Type: a.TypeUrl,
Id: s.Id,
Data: a,
}
}
// Populate an initial set of session records.
s1 := &session.Session{
Id: "session-1",
IdToken: &session.IDToken{Issuer: "issuer-1"},
OauthToken: &session.OAuthToken{AccessToken: "access-token-1"},
}
s2 := &session.Session{
Id: "session-2",
IdToken: &session.IDToken{Issuer: "issuer-2"},
OauthToken: &session.OAuthToken{AccessToken: "access-token-2"},
}
s3 := &session.Session{
Id: "session-3",
IdToken: &session.IDToken{Issuer: "issuer-3"},
OauthToken: &session.OAuthToken{AccessToken: "access-token-3"},
}
initial := []*databroker.Record{mkRecord(s1), mkRecord(s2), mkRecord(s3)}
_, err := backend.Put(ctx, initial)
require.NoError(t, err)
// Now patch just the oauth_token field.
u1 := &session.Session{
Id: "session-1",
OauthToken: &session.OAuthToken{AccessToken: "access-token-1-new"},
}
u2 := &session.Session{
Id: "session-4-does-not-exist",
OauthToken: &session.OAuthToken{AccessToken: "access-token-4-new"},
}
u3 := &session.Session{
Id: "session-3",
OauthToken: &session.OAuthToken{AccessToken: "access-token-3-new"},
}
mask, _ := fieldmaskpb.New(&session.Session{}, "oauth_token")
_, updated, err := backend.Patch(
ctx, []*databroker.Record{mkRecord(u1), mkRecord(u2), mkRecord(u3)}, mask)
require.NoError(t, err)
// The OAuthToken message should be updated but the IDToken message should
// be unchanged, as it was not included in the field mask. The results
// should indicate that only two records were updated (one did not exist).
assert.Equal(t, 2, len(updated))
assert.Greater(t, updated[0].Version, initial[0].Version)
assert.Greater(t, updated[1].Version, initial[2].Version)
testutil.AssertProtoJSONEqual(t, `{
"@type": "type.googleapis.com/session.Session",
"id": "session-1",
"idToken": {
"issuer": "issuer-1"
},
"oauthToken": {
"accessToken": "access-token-1-new"
}
}`, updated[0].Data)
testutil.AssertProtoJSONEqual(t, `{
"@type": "type.googleapis.com/session.Session",
"id": "session-3",
"idToken": {
"issuer": "issuer-3"
},
"oauthToken": {
"accessToken": "access-token-3-new"
}
}`, updated[1].Data)
// Verify that the updates will indeed be seen by a subsequent Get().
r1, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-1")
testutil.AssertProtoEqual(t, updated[0], r1)
r3, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-3")
testutil.AssertProtoEqual(t, updated[1], r3)
}