diff --git a/pkg/grpc/databroker/changeset.go b/pkg/grpc/databroker/changeset.go index cc72f0e2b..0d0d3ba81 100644 --- a/pkg/grpc/databroker/changeset.go +++ b/pkg/grpc/databroker/changeset.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -14,7 +14,7 @@ func GetChangeSet(current, target RecordSetBundle, cmpFn RecordCompareFn) []*Rec cs := &changeSet{now: timestamppb.Now()} for _, rec := range current.GetRemoved(target).Flatten() { - cs.Remove(rec.GetType(), rec.GetId()) + cs.Remove(rec) } for _, rec := range current.GetModified(target, cmpFn).Flatten() { cs.Upsert(rec) @@ -33,13 +33,10 @@ type changeSet struct { } // Remove adds a record to the change set. -func (cs *changeSet) Remove(typ string, id string) { - cs.updates = append(cs.updates, &Record{ - Type: typ, - Id: id, - DeletedAt: cs.now, - Data: &anypb.Any{TypeUrl: typ}, - }) +func (cs *changeSet) Remove(record *Record) { + record = proto.Clone(record).(*Record) + record.DeletedAt = cs.now + cs.updates = append(cs.updates, record) } // Upsert adds a record to the change set. diff --git a/pkg/grpc/databroker/changeset_test.go b/pkg/grpc/databroker/changeset_test.go new file mode 100644 index 000000000..14c101055 --- /dev/null +++ b/pkg/grpc/databroker/changeset_test.go @@ -0,0 +1,52 @@ +package databroker_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/pomerium/datasource/pkg/directory" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +func TestGetChangeset(t *testing.T) { + t.Parallel() + + rsb1 := databroker.RecordSetBundle{} + rsb2 := databroker.RecordSetBundle{} + updates := databroker.GetChangeSet(rsb1, rsb2, func(record1, record2 *databroker.Record) bool { + return cmp.Equal(record1, record2, protocmp.Transform()) + }) + assert.Len(t, updates, 0) + + rsb1 = databroker.RecordSetBundle{} + rsb1.Add(&databroker.Record{ + Type: directory.UserRecordType, + Id: "user-1", + Data: protoutil.NewAny(mustNewStruct(map[string]any{ + "email": "user-1@example.com", + })), + }) + rsb2 = databroker.RecordSetBundle{} + updates = databroker.GetChangeSet(rsb1, rsb2, func(record1, record2 *databroker.Record) bool { + return cmp.Equal(record1, record2, protocmp.Transform()) + }) + if assert.Len(t, updates, 1) { + assert.Equal(t, directory.UserRecordType, updates[0].GetType()) + assert.Equal(t, "type.googleapis.com/google.protobuf.Struct", updates[0].GetData().GetTypeUrl(), + "should preserve data type") + assert.NotNil(t, updates[0].GetDeletedAt()) + } +} + +func mustNewStruct(m map[string]any) *structpb.Struct { + s, err := structpb.NewStruct(m) + if err != nil { + panic(err) + } + return s +}