Mirror of https://github.com/pomerium/pomerium.git
storage/inmemory: implement patch operation (#4654)
Add a new Patch() method that updates specific fields of an existing record's data, based on a field mask. Extract some logic from the existing Get() and Put() methods so it can be shared with the new Patch() method.
Parent: 5f4e13e130
Commit: 47890e9ee1
5 changed files with 270 additions and 16 deletions
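Before the diff itself, a minimal caller-side sketch of what the new operation looks like. The Patch signature and the record/field-mask construction mirror the code in the diff below; the surrounding helper, its name, and how the backend is obtained are assumptions for illustration only.

package example

import (
	"context"

	"google.golang.org/protobuf/types/known/anypb"
	"google.golang.org/protobuf/types/known/fieldmaskpb"

	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/grpc/session"
)

// patcher lists only the new Patch method; the in-memory Backend in this
// commit provides it with exactly this signature.
type patcher interface {
	Patch(context.Context, []*databroker.Record, *fieldmaskpb.FieldMask) (uint64, []*databroker.Record, error)
}

// updateAccessToken is a hypothetical caller showing how the method might be
// used to overwrite only the oauth_token field of a stored session.
func updateAccessToken(ctx context.Context, backend patcher, sessionID, token string) error {
	s := &session.Session{
		Id:         sessionID,
		OauthToken: &session.OAuthToken{AccessToken: token},
	}
	data, err := anypb.New(s)
	if err != nil {
		return err
	}
	record := &databroker.Record{Type: data.TypeUrl, Id: s.Id, Data: data}

	// Only the fields named in the mask are replaced; other fields of the
	// stored record (for example id_token) are left untouched.
	mask, err := fieldmaskpb.New(&session.Session{}, "oauth_token")
	if err != nil {
		return err
	}

	// Records that do not already exist are skipped, so the returned slice may
	// be shorter than the input.
	_, patched, err := backend.Patch(ctx, []*databroker.Record{record}, mask)
	if err != nil {
		return err
	}
	_ = patched
	return nil
}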
Changed: in-memory storage backend

@@ -13,6 +13,7 @@ import (
 	"github.com/rs/zerolog"
 	"golang.org/x/exp/maps"
 	"google.golang.org/protobuf/proto"
+	"google.golang.org/protobuf/types/known/fieldmaskpb"
 	"google.golang.org/protobuf/types/known/timestamppb"
 
 	"github.com/pomerium/pomerium/internal/log"
@@ -130,18 +131,25 @@ func (backend *Backend) Close() error {
 func (backend *Backend) Get(_ context.Context, recordType, id string) (*databroker.Record, error) {
 	backend.mu.RLock()
 	defer backend.mu.RUnlock()
+	if record := backend.get(recordType, id); record != nil {
+		return record, nil
+	}
+	return nil, storage.ErrNotFound
+}
+
+// get gets a record from the in-memory store, assuming the RWMutex is held.
+func (backend *Backend) get(recordType, id string) *databroker.Record {
 	records := backend.lookup[recordType]
 	if records == nil {
-		return nil, storage.ErrNotFound
+		return nil
 	}
 
 	record := records.Get(id)
 	if record == nil {
-		return nil, storage.ErrNotFound
+		return nil
 	}
 
-	return dup(record), nil
+	return dup(record)
 }
 
 // GetOptions returns the options for a type in the in-memory store.
@@ -216,6 +224,19 @@ func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) (
 			Str("db_type", record.Type)
 		})
 
+		backend.update(record)
+
+		recordTypes[record.GetType()] = struct{}{}
+	}
+	for recordType := range recordTypes {
+		backend.enforceCapacity(recordType)
+	}
+
+	return backend.serverVersion, nil
+}
+
+// update stores a record into the in-memory store, assuming the RWMutex is held.
+func (backend *Backend) update(record *databroker.Record) {
 	backend.recordChange(record)
 
 	c, ok := backend.lookup[record.GetType()]
@@ -229,14 +250,51 @@ func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) (
 	} else {
 		c.Put(dup(record))
 	}
-
-		recordTypes[record.GetType()] = struct{}{}
-	}
-	for recordType := range recordTypes {
-		backend.enforceCapacity(recordType)
-	}
-
-	return backend.serverVersion, nil
+}
+
+// Patch updates the specified fields of existing record(s).
+func (backend *Backend) Patch(
+	ctx context.Context, records []*databroker.Record, fields *fieldmaskpb.FieldMask,
+) (serverVersion uint64, patchedRecords []*databroker.Record, err error) {
+	backend.mu.Lock()
+	defer backend.mu.Unlock()
+	defer backend.onChange.Broadcast(ctx)
+
+	serverVersion = backend.serverVersion
+	patchedRecords = make([]*databroker.Record, 0, len(records))
+
+	for _, record := range records {
+		err = backend.patch(record, fields)
+		if storage.IsNotFound(err) {
+			// Skip any record that does not currently exist.
+			continue
+		} else if err != nil {
+			return
+		}
+		patchedRecords = append(patchedRecords, record)
+	}
+
+	return
+}
+
+// patch updates the specified fields of an existing record, assuming the RWMutex is held.
+func (backend *Backend) patch(record *databroker.Record, fields *fieldmaskpb.FieldMask) error {
+	if record == nil {
+		return fmt.Errorf("cannot patch using a nil record")
+	}
+
+	existing := backend.get(record.GetType(), record.GetId())
+	if existing == nil {
+		return storage.ErrNotFound
+	}
+
+	if err := storage.PatchRecord(existing, record, fields); err != nil {
+		return err
+	}
+
+	backend.update(record)
+
+	return nil
 }
 
 // SetOptions sets the options for a type in the in-memory store.
Changed: in-memory backend tests

@@ -15,6 +15,7 @@ import (
 
 	"github.com/pomerium/pomerium/pkg/grpc/databroker"
 	"github.com/pomerium/pomerium/pkg/storage"
+	"github.com/pomerium/pomerium/pkg/storage/storagetest"
 )
 
 func TestBackend(t *testing.T) {
@@ -72,6 +73,9 @@ func TestBackend(t *testing.T) {
 		assert.Error(t, err)
 		assert.Nil(t, record)
 	})
+	t.Run("patch", func(t *testing.T) {
+		storagetest.TestBackendPatch(t, ctx, backend)
+	})
 }
 
 func TestExpiry(t *testing.T) {
pkg/storage/patch.go (new file, 36 lines)

@@ -0,0 +1,36 @@
package storage

import (
	"fmt"

	"google.golang.org/protobuf/types/known/anypb"
	"google.golang.org/protobuf/types/known/fieldmaskpb"

	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/protoutil"
)

// PatchRecord extracts the data from existing and record, updates the existing
// data subject to the provided field mask, and stores the result back into
// record. The existing record is not modified.
func PatchRecord(existing, record *databroker.Record, fields *fieldmaskpb.FieldMask) error {
	dst, err := existing.GetData().UnmarshalNew()
	if err != nil {
		return fmt.Errorf("could not unmarshal existing record data: %w", err)
	}

	src, err := record.GetData().UnmarshalNew()
	if err != nil {
		return fmt.Errorf("could not unmarshal new record data: %w", err)
	}

	if err := protoutil.OverwriteMasked(dst, src, fields); err != nil {
		return fmt.Errorf("cannot patch record: %w", err)
	}

	record.Data, err = anypb.New(dst)
	if err != nil {
		return fmt.Errorf("could not marshal new record data: %w", err)
	}
	return nil
}
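The actual merge is delegated to protoutil.OverwriteMasked, which is not part of this diff. As a rough sketch of what a field-mask overwrite does (top-level paths only, written against the protobuf reflection API; this is not the protoutil implementation), something like the following:

package example

import (
	"fmt"

	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/types/known/fieldmaskpb"
)

// overwriteTopLevelMasked copies each top-level field named in the mask from
// src into dst, clearing fields that are unset in src. Nested paths such as
// "a.b" are not handled here; this only conveys the general idea.
func overwriteTopLevelMasked(dst, src proto.Message, mask *fieldmaskpb.FieldMask) error {
	dstRef, srcRef := dst.ProtoReflect(), src.ProtoReflect()
	if dstRef.Descriptor().FullName() != srcRef.Descriptor().FullName() {
		return fmt.Errorf("message types do not match")
	}
	fields := dstRef.Descriptor().Fields()
	for _, path := range mask.GetPaths() {
		fd := fields.ByName(protoreflect.Name(path))
		if fd == nil {
			return fmt.Errorf("unknown field: %q", path)
		}
		if srcRef.Has(fd) {
			dstRef.Set(fd, srcRef.Get(fd))
		} else {
			dstRef.Clear(fd)
		}
	}
	return nil
}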
pkg/storage/patch_test.go (new file, 45 lines)

@@ -0,0 +1,45 @@
package storage_test

import (
	"testing"
	"time"

	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/anypb"
	"google.golang.org/protobuf/types/known/fieldmaskpb"
	"google.golang.org/protobuf/types/known/timestamppb"

	"github.com/pomerium/pomerium/internal/testutil"
	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/grpc/session"
	"github.com/pomerium/pomerium/pkg/storage"
)

func TestPatchRecord(t *testing.T) {
	tm := timestamppb.New(time.Date(2023, 10, 31, 12, 0, 0, 0, time.UTC))

	s1 := &session.Session{Id: "session-id"}
	a1, _ := anypb.New(s1)
	r1 := &databroker.Record{Data: a1}

	s2 := &session.Session{Id: "new-session-id", AccessedAt: tm}
	a2, _ := anypb.New(s2)
	r2 := &databroker.Record{Data: a2}

	originalR1 := proto.Clone(r1).(*databroker.Record)

	m, _ := fieldmaskpb.New(&session.Session{}, "accessed_at")

	storage.PatchRecord(r1, r2, m)

	testutil.AssertProtoJSONEqual(t, `{
		"data": {
			"@type": "type.googleapis.com/session.Session",
			"accessedAt": "2023-10-31T12:00:00Z",
			"id": "session-id"
		}
	}`, r2)

	// The existing record should not be modified.
	testutil.AssertProtoEqual(t, originalR1, r1)
}
pkg/storage/storagetest/storagetest.go (new file, 111 lines)

@@ -0,0 +1,111 @@
// Package storagetest contains test cases for use in verifying the behavior of
// a storage.Backend implementation.
package storagetest

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"google.golang.org/protobuf/types/known/anypb"
	"google.golang.org/protobuf/types/known/fieldmaskpb"

	"github.com/pomerium/pomerium/internal/testutil"
	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/grpc/session"
	"github.com/pomerium/pomerium/pkg/storage"
)

// BackendWithPatch is a storage.Backend with an additional Patch() method.
// TODO: delete this type once Patch() is added to the storage.Backend interface
type BackendWithPatch interface {
	storage.Backend
	Patch(context.Context, []*databroker.Record, *fieldmaskpb.FieldMask) (uint64, []*databroker.Record, error)
}

// TestBackendPatch verifies the behavior of the backend Patch() method.
func TestBackendPatch(t *testing.T, ctx context.Context, backend BackendWithPatch) { //nolint:revive
	mkRecord := func(s *session.Session) *databroker.Record {
		a, _ := anypb.New(s)
		return &databroker.Record{
			Type: a.TypeUrl,
			Id:   s.Id,
			Data: a,
		}
	}

	// Populate an initial set of session records.
	s1 := &session.Session{
		Id:         "session-1",
		IdToken:    &session.IDToken{Issuer: "issuer-1"},
		OauthToken: &session.OAuthToken{AccessToken: "access-token-1"},
	}
	s2 := &session.Session{
		Id:         "session-2",
		IdToken:    &session.IDToken{Issuer: "issuer-2"},
		OauthToken: &session.OAuthToken{AccessToken: "access-token-2"},
	}
	s3 := &session.Session{
		Id:         "session-3",
		IdToken:    &session.IDToken{Issuer: "issuer-3"},
		OauthToken: &session.OAuthToken{AccessToken: "access-token-3"},
	}
	initial := []*databroker.Record{mkRecord(s1), mkRecord(s2), mkRecord(s3)}

	_, err := backend.Put(ctx, initial)
	require.NoError(t, err)

	// Now patch just the oauth_token field.
	u1 := &session.Session{
		Id:         "session-1",
		OauthToken: &session.OAuthToken{AccessToken: "access-token-1-new"},
	}
	u2 := &session.Session{
		Id:         "session-4-does-not-exist",
		OauthToken: &session.OAuthToken{AccessToken: "access-token-4-new"},
	}
	u3 := &session.Session{
		Id:         "session-3",
		OauthToken: &session.OAuthToken{AccessToken: "access-token-3-new"},
	}

	mask, _ := fieldmaskpb.New(&session.Session{}, "oauth_token")

	_, updated, err := backend.Patch(
		ctx, []*databroker.Record{mkRecord(u1), mkRecord(u2), mkRecord(u3)}, mask)
	require.NoError(t, err)

	// The OAuthToken message should be updated but the IDToken message should
	// be unchanged, as it was not included in the field mask. The results
	// should indicate that only two records were updated (one did not exist).
	assert.Equal(t, 2, len(updated))
	assert.Greater(t, updated[0].Version, initial[0].Version)
	assert.Greater(t, updated[1].Version, initial[2].Version)
	testutil.AssertProtoJSONEqual(t, `{
		"@type": "type.googleapis.com/session.Session",
		"id": "session-1",
		"idToken": {
			"issuer": "issuer-1"
		},
		"oauthToken": {
			"accessToken": "access-token-1-new"
		}
	}`, updated[0].Data)
	testutil.AssertProtoJSONEqual(t, `{
		"@type": "type.googleapis.com/session.Session",
		"id": "session-3",
		"idToken": {
			"issuer": "issuer-3"
		},
		"oauthToken": {
			"accessToken": "access-token-3-new"
		}
	}`, updated[1].Data)

	// Verify that the updates will indeed be seen by a subsequent Get().
	r1, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-1")
	testutil.AssertProtoEqual(t, updated[0], r1)
	r3, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-3")
	testutil.AssertProtoEqual(t, updated[1], r3)
}
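The in-memory backend's test (see the hunk above that adds t.Run("patch", ...)) shows the intended use of this package. A hedged sketch of how another storage.Backend implementation might reuse the same helper, where newTestBackend is a hypothetical constructor for the backend under test:

package otherbackend_test

import (
	"context"
	"testing"

	"github.com/pomerium/pomerium/pkg/storage/storagetest"
)

func TestBackend(t *testing.T) {
	ctx := context.Background()

	// newTestBackend is a hypothetical helper that constructs the backend under
	// test; whatever it returns must satisfy storagetest.BackendWithPatch.
	backend := newTestBackend(t)

	t.Run("patch", func(t *testing.T) {
		storagetest.TestBackendPatch(t, ctx, backend)
	})
}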