diff --git a/pkg/grpc/databroker/changeset.go b/pkg/grpc/databroker/changeset.go new file mode 100644 index 000000000..20b6889f8 --- /dev/null +++ b/pkg/grpc/databroker/changeset.go @@ -0,0 +1,66 @@ +package databroker + +import ( + "context" + "fmt" + + "google.golang.org/protobuf/types/known/timestamppb" +) + +// GetChangeSet returns list of changes between the current and target record sets, +// that may be applied to the databroker to bring it to the target state. +func GetChangeSet(current, target RecordSetBundle) []*Record { + cs := &changeSet{now: timestamppb.Now()} + + for _, rec := range current.GetRemoved(target).Flatten() { + cs.Remove(rec.GetType(), rec.GetId()) + } + for _, rec := range current.GetModified(target).Flatten() { + cs.Upsert(rec) + } + for _, rec := range current.GetAdded(target).Flatten() { + cs.Upsert(rec) + } + + return cs.updates +} + +// changeSet is a set of databroker changes. +type changeSet struct { + now *timestamppb.Timestamp + updates []*Record +} + +// Remove adds a record to the change set. +func (cs *changeSet) Remove(typ string, id string) { + cs.updates = append(cs.updates, &Record{ + Type: typ, + Id: id, + DeletedAt: cs.now, + }) +} + +// Upsert adds a record to the change set. +func (cs *changeSet) Upsert(record *Record) { + cs.updates = append(cs.updates, &Record{ + Type: record.Type, + Id: record.Id, + Data: record.Data, + }) +} + +// PutMulti puts the records into the databroker in batches. +func PutMulti(ctx context.Context, client DataBrokerServiceClient, records ...*Record) error { + if len(records) == 0 { + return nil + } + + updates := OptimumPutRequestsFromRecords(records) + for _, req := range updates { + _, err := client.Put(ctx, req) + if err != nil { + return fmt.Errorf("put databroker record: %w", err) + } + } + return nil +} diff --git a/pkg/grpc/databroker/recordset.go b/pkg/grpc/databroker/recordset.go new file mode 100644 index 000000000..1241830b9 --- /dev/null +++ b/pkg/grpc/databroker/recordset.go @@ -0,0 +1,140 @@ +package databroker + +import ( + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +// RecordSetBundle is an index of databroker records by type +type RecordSetBundle map[string]RecordSet + +// RecordSet is an index of databroker records by their id. +type RecordSet map[string]*Record + +// RecordTypes returns the types of records in the bundle. +func (rsb RecordSetBundle) RecordTypes() []string { + types := make([]string, 0, len(rsb)) + for typ := range rsb { + types = append(types, typ) + } + return types +} + +// Add adds a record to the bundle. +func (rsb RecordSetBundle) Add(record *Record) { + rs, ok := rsb[record.GetType()] + if !ok { + rs = make(RecordSet) + rsb[record.GetType()] = rs + } + rs[record.GetId()] = record +} + +// GetAdded returns the records that are in other but not in rsb. +func (rsb RecordSetBundle) GetAdded(other RecordSetBundle) RecordSetBundle { + added := make(RecordSetBundle) + for otherType, otherRS := range other { + rs, ok := rsb[otherType] + if !ok { + added[otherType] = otherRS + continue + } + rss := rs.GetAdded(other[otherType]) + if len(rss) > 0 { + added[otherType] = rss + } + } + return added +} + +// GetRemoved returns the records that are in rs but not in other. +func (rsb RecordSetBundle) GetRemoved(other RecordSetBundle) RecordSetBundle { + return other.GetAdded(rsb) +} + +// GetModified returns the records that are in both rs and other but have different data. +func (rsb RecordSetBundle) GetModified(other RecordSetBundle) RecordSetBundle { + modified := make(RecordSetBundle) + for otherType, otherRS := range other { + rs, ok := rsb[otherType] + if !ok { + continue + } + m := rs.GetModified(otherRS) + if len(m) > 0 { + modified[otherType] = m + } + } + return modified +} + +// GetAdded returns the records that are in other but not in rs. +func (rs RecordSet) GetAdded(other RecordSet) RecordSet { + added := make(RecordSet) + for id, record := range other { + if _, ok := rs[id]; !ok { + added[id] = record + } + } + return added +} + +// GetRemoved returns the records that are in rs but not in other. +func (rs RecordSet) GetRemoved(other RecordSet) RecordSet { + return other.GetAdded(rs) +} + +// GetModified returns the records that are in both rs and other but have different data. +// by comparing the protobuf bytes of the payload. +func (rs RecordSet) GetModified(other RecordSet) RecordSet { + modified := make(RecordSet) + for id, record := range other { + otherRecord, ok := rs[id] + if !ok { + continue + } + + if !proto.Equal(record, otherRecord) { + modified[id] = record + } + } + return modified +} + +// Flatten returns all records in the set. +func (rs RecordSet) Flatten() []*Record { + records := make([]*Record, 0, len(rs)) + for _, record := range rs { + records = append(records, record) + } + return records +} + +// Flatten returns all records in the bundle. +func (rsb RecordSetBundle) Flatten() []*Record { + records := make([]*Record, 0) + for _, rs := range rsb { + records = append(records, rs.Flatten()...) + } + return records +} + +// Get returns a record by type and id. +func (rsb RecordSetBundle) Get(typeName, id string) (record *Record, ok bool) { + rs, ok := rsb[typeName] + if !ok { + return + } + record, ok = rs[id] + return +} + +// MarshalJSON marshals the record to JSON. +func (r *Record) MarshalJSON() ([]byte, error) { + return protojson.Marshal(r) +} + +// UnmarshalJSON unmarshals the record from JSON. +func (r *Record) UnmarshalJSON(data []byte) error { + return protojson.Unmarshal(data, r) +} diff --git a/pkg/grpc/databroker/recordset_test.go b/pkg/grpc/databroker/recordset_test.go new file mode 100644 index 000000000..f7277ff58 --- /dev/null +++ b/pkg/grpc/databroker/recordset_test.go @@ -0,0 +1,77 @@ +package databroker_test + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +func TestRecords(t *testing.T) { + tr := func(id, typ, val string) *databroker.Record { + return &databroker.Record{ + Id: id, + Type: typ, + Data: protoutil.NewAnyString(val), + } + } + + initial := make(databroker.RecordSetBundle) + initial.Add(tr("1", "a", "a-1")) + initial.Add(tr("2", "a", "a-2")) + initial.Add(tr("1", "b", "b-1")) + + // test record types + assert.ElementsMatch(t, []string{"a", "b"}, initial.RecordTypes()) + + // test added, deleted and modified + updated := make(databroker.RecordSetBundle) + updated.Add(tr("1", "a", "a-1-1")) + updated.Add(tr("3", "a", "a-3")) + updated.Add(tr("1", "b", "b-1")) + updated.Add(tr("2", "b", "b-2")) + updated.Add(tr("1", "c", "c-1")) + + assert.ElementsMatch(t, []string{"a", "b", "c"}, updated.RecordTypes()) + + equalJSON := func(a, b databroker.RecordSetBundle) { + t.Helper() + var txt [2]string + for i, x := range [2]databroker.RecordSetBundle{a, b} { + data, err := json.Marshal(x) + assert.NoError(t, err) + txt[i] = string(data) + } + assert.JSONEq(t, txt[0], txt[1]) + } + + added := initial.GetAdded(updated) + equalJSON(added, databroker.RecordSetBundle{ + "a": databroker.RecordSet{ + "3": tr("3", "a", "a-3"), + }, + "b": databroker.RecordSet{ + "2": tr("2", "b", "b-2"), + }, + "c": databroker.RecordSet{ + "1": tr("1", "c", "c-1"), + }, + }) + + removed := initial.GetRemoved(updated) + equalJSON(removed, databroker.RecordSetBundle{ + "a": databroker.RecordSet{ + "2": tr("2", "a", "a-2"), + }, + }) + + modified := initial.GetModified(updated) + equalJSON(modified, databroker.RecordSetBundle{ + "a": databroker.RecordSet{ + "1": tr("1", "a", "a-1-1"), + }, + }) +}