diff --git a/pkg/grpc/databroker/changeset.go b/pkg/grpc/databroker/changeset.go new file mode 100644 index 000000000..94804ca98 --- /dev/null +++ b/pkg/grpc/databroker/changeset.go @@ -0,0 +1,51 @@ +package databroker + +import ( + "context" + "fmt" + + "google.golang.org/protobuf/types/known/timestamppb" +) + +// ChangeSet is a set of databroker changes. +type ChangeSet struct { + now *timestamppb.Timestamp + updates []*Record +} + +// NewChangeSet creates a new databroker change set. +func NewChangeSet() *ChangeSet { + return &ChangeSet{ + now: timestamppb.Now(), + } +} + +// Remove adds a record to the change set. +func (cs *ChangeSet) Remove(typ string, id string) { + cs.updates = append(cs.updates, &Record{ + Type: typ, + Id: id, + DeletedAt: cs.now, + }) +} + +// Upsert adds a record to the change set. +func (cs *ChangeSet) Upsert(record *Record) { + cs.updates = append(cs.updates, &Record{ + Type: record.Type, + Id: record.Id, + Data: record.Data, + }) +} + +// ApplyChanges applies the changes to the databroker. +func ApplyChanges(ctx context.Context, client DataBrokerServiceClient, changes *ChangeSet) error { + updates := OptimumPutRequestsFromRecords(changes.updates) + for _, req := range updates { + _, err := client.Put(ctx, req) + if err != nil { + return fmt.Errorf("put databroker record: %w", err) + } + } + return nil +} diff --git a/pkg/grpc/databroker/recordset.go b/pkg/grpc/databroker/recordset.go new file mode 100644 index 000000000..919a3c538 --- /dev/null +++ b/pkg/grpc/databroker/recordset.go @@ -0,0 +1,132 @@ +package databroker + +// RecordSetBundle is an index of databroker records by type +type RecordSetBundle[T IRecord[T]] map[string]RecordSet[T] + +// RecordSet is an index of databroker records by their id. +type RecordSet[T IRecord[T]] map[string]T + +// IRecord is an abstract record +type IRecord[T any] interface { + GetID() string + GetType() string + Equal(other T) bool +} + +// RecordTypes returns the types of records in the bundle. +func (rsb RecordSetBundle[T]) RecordTypes() []string { + types := make([]string, 0, len(rsb)) + for typ := range rsb { + types = append(types, typ) + } + return types +} + +// Add adds a record to the bundle. +func (rsb RecordSetBundle[T]) Add(record T) { + rs, ok := rsb[record.GetType()] + if !ok { + rs = make(RecordSet[T]) + rsb[record.GetType()] = rs + } + rs[record.GetID()] = record +} + +// GetAdded returns the records that are in other but not in rsb. +func (rsb RecordSetBundle[T]) GetAdded(other RecordSetBundle[T]) RecordSetBundle[T] { + added := make(RecordSetBundle[T]) + for otherType, otherRS := range other { + rs, ok := rsb[otherType] + if !ok { + added[otherType] = otherRS + continue + } + rss := rs.GetAdded(other[otherType]) + if len(rss) > 0 { + added[otherType] = rss + } + } + return added +} + +// GetRemoved returns the records that are in rs but not in other. +func (rsb RecordSetBundle[T]) GetRemoved(other RecordSetBundle[T]) RecordSetBundle[T] { + return other.GetAdded(rsb) +} + +// GetModified returns the records that are in both rs and other but have different data. +func (rsb RecordSetBundle[T]) GetModified(other RecordSetBundle[T]) RecordSetBundle[T] { + modified := make(RecordSetBundle[T]) + for otherType, otherRS := range other { + rs, ok := rsb[otherType] + if !ok { + continue + } + m := rs.GetModified(otherRS) + if len(m) > 0 { + modified[otherType] = m + } + } + return modified +} + +// GetAdded returns the records that are in other but not in rs. +func (rs RecordSet[T]) GetAdded(other RecordSet[T]) RecordSet[T] { + added := make(RecordSet[T]) + for id, record := range other { + if _, ok := rs[id]; !ok { + added[id] = record + } + } + return added +} + +// GetRemoved returns the records that are in rs but not in other. +func (rs RecordSet[T]) GetRemoved(other RecordSet[T]) RecordSet[T] { + return other.GetAdded(rs) +} + +// GetModified returns the records that are in both rs and other but have different data. +// by comparing the protobuf bytes of the payload. +func (rs RecordSet[T]) GetModified(other RecordSet[T]) RecordSet[T] { + modified := make(RecordSet[T]) + for id, record := range other { + otherRecord, ok := rs[id] + if !ok { + continue + } + + if !record.Equal(otherRecord) { + modified[id] = record + } + } + return modified +} + +// Flatten returns all records in the set. +func (rs RecordSet[T]) Flatten() []T { + records := make([]T, 0, len(rs)) + for _, record := range rs { + records = append(records, record) + } + return records +} + +// Flatten returns all records in the bundle. +func (rsb RecordSetBundle[T]) Flatten() []T { + records := make([]T, 0) + for _, rs := range rsb { + records = append(records, rs.Flatten()...) + } + return records +} + +// Get returns a record by type and id. +func (rsb RecordSetBundle[T]) Get(typeName, id string) (record T, ok bool) { + rs, ok := rsb[typeName] + if !ok { + return + } + record, ok = rs[id] + return +} diff --git a/pkg/grpc/databroker/recordset_test.go b/pkg/grpc/databroker/recordset_test.go new file mode 100644 index 000000000..a4372b63c --- /dev/null +++ b/pkg/grpc/databroker/recordset_test.go @@ -0,0 +1,77 @@ +package databroker_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +type testRecord struct { + Type string + ID string + Val string +} + +func (r testRecord) GetID() string { + return r.ID +} + +func (r testRecord) GetType() string { + return r.Type +} + +func (r testRecord) Equal(other testRecord) bool { + return r.ID == other.ID && r.Type == other.Type && r.Val == other.Val +} + +func TestRecords(t *testing.T) { + initial := make(databroker.RecordSetBundle[testRecord]) + initial.Add(testRecord{ID: "1", Type: "a", Val: "a-1"}) + initial.Add(testRecord{ID: "2", Type: "a", Val: "a-2"}) + initial.Add(testRecord{ID: "1", Type: "b", Val: "b-1"}) + + // test record types + assert.ElementsMatch(t, []string{"a", "b"}, initial.RecordTypes()) + + // test added, deleted and modified + updated := make(databroker.RecordSetBundle[testRecord]) + updated.Add(testRecord{ID: "1", Type: "a", Val: "a-1-1"}) + updated.Add(testRecord{ID: "3", Type: "a", Val: "a-3"}) + updated.Add(testRecord{ID: "1", Type: "b", Val: "b-1"}) + updated.Add(testRecord{ID: "2", Type: "b", Val: "b-2"}) + updated.Add(testRecord{ID: "1", Type: "c", Val: "c-1"}) + + assert.ElementsMatch(t, []string{"a", "b", "c"}, updated.RecordTypes()) + + added := initial.GetAdded(updated) + assert.Equal(t, + databroker.RecordSetBundle[testRecord]{ + "a": databroker.RecordSet[testRecord]{ + "3": {ID: "3", Type: "a", Val: "a-3"}, + }, + "b": databroker.RecordSet[testRecord]{ + "2": {ID: "2", Type: "b", Val: "b-2"}, + }, + "c": databroker.RecordSet[testRecord]{ + "1": {ID: "1", Type: "c", Val: "c-1"}, + }, + }, added) + + removed := initial.GetRemoved(updated) + assert.Equal(t, + databroker.RecordSetBundle[testRecord]{ + "a": databroker.RecordSet[testRecord]{ + "2": {ID: "2", Type: "a", Val: "a-2"}, + }, + }, removed) + + modified := initial.GetModified(updated) + assert.Equal(t, + databroker.RecordSetBundle[testRecord]{ + "a": databroker.RecordSet[testRecord]{ + "1": {ID: "1", Type: "a", Val: "a-1-1"}, + }, + }, modified) +}