databroker: add utility recordset and changeset (#4701)

This commit is contained in:
Denis Mishin 2023-11-03 11:26:59 -04:00 committed by GitHub
parent 45b72bc9b5
commit 6d5558cb97
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 283 additions and 0 deletions

View file

@ -0,0 +1,66 @@
package databroker
import (
"context"
"fmt"
"google.golang.org/protobuf/types/known/timestamppb"
)
// GetChangeSet returns list of changes between the current and target record sets,
// that may be applied to the databroker to bring it to the target state.
func GetChangeSet(current, target RecordSetBundle) []*Record {
cs := &changeSet{now: timestamppb.Now()}
for _, rec := range current.GetRemoved(target).Flatten() {
cs.Remove(rec.GetType(), rec.GetId())
}
for _, rec := range current.GetModified(target).Flatten() {
cs.Upsert(rec)
}
for _, rec := range current.GetAdded(target).Flatten() {
cs.Upsert(rec)
}
return cs.updates
}
// changeSet is a set of databroker changes.
type changeSet struct {
now *timestamppb.Timestamp
updates []*Record
}
// Remove adds a record to the change set.
func (cs *changeSet) Remove(typ string, id string) {
cs.updates = append(cs.updates, &Record{
Type: typ,
Id: id,
DeletedAt: cs.now,
})
}
// Upsert adds a record to the change set.
func (cs *changeSet) Upsert(record *Record) {
cs.updates = append(cs.updates, &Record{
Type: record.Type,
Id: record.Id,
Data: record.Data,
})
}
// PutMulti puts the records into the databroker in batches.
func PutMulti(ctx context.Context, client DataBrokerServiceClient, records ...*Record) error {
if len(records) == 0 {
return nil
}
updates := OptimumPutRequestsFromRecords(records)
for _, req := range updates {
_, err := client.Put(ctx, req)
if err != nil {
return fmt.Errorf("put databroker record: %w", err)
}
}
return nil
}

View file

@ -0,0 +1,140 @@
package databroker
import (
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)
// RecordSetBundle is an index of databroker records by type
type RecordSetBundle map[string]RecordSet
// RecordSet is an index of databroker records by their id.
type RecordSet map[string]*Record
// RecordTypes returns the types of records in the bundle.
func (rsb RecordSetBundle) RecordTypes() []string {
types := make([]string, 0, len(rsb))
for typ := range rsb {
types = append(types, typ)
}
return types
}
// Add adds a record to the bundle.
func (rsb RecordSetBundle) Add(record *Record) {
rs, ok := rsb[record.GetType()]
if !ok {
rs = make(RecordSet)
rsb[record.GetType()] = rs
}
rs[record.GetId()] = record
}
// GetAdded returns the records that are in other but not in rsb.
func (rsb RecordSetBundle) GetAdded(other RecordSetBundle) RecordSetBundle {
added := make(RecordSetBundle)
for otherType, otherRS := range other {
rs, ok := rsb[otherType]
if !ok {
added[otherType] = otherRS
continue
}
rss := rs.GetAdded(other[otherType])
if len(rss) > 0 {
added[otherType] = rss
}
}
return added
}
// GetRemoved returns the records that are in rs but not in other.
func (rsb RecordSetBundle) GetRemoved(other RecordSetBundle) RecordSetBundle {
return other.GetAdded(rsb)
}
// GetModified returns the records that are in both rs and other but have different data.
func (rsb RecordSetBundle) GetModified(other RecordSetBundle) RecordSetBundle {
modified := make(RecordSetBundle)
for otherType, otherRS := range other {
rs, ok := rsb[otherType]
if !ok {
continue
}
m := rs.GetModified(otherRS)
if len(m) > 0 {
modified[otherType] = m
}
}
return modified
}
// GetAdded returns the records that are in other but not in rs.
func (rs RecordSet) GetAdded(other RecordSet) RecordSet {
added := make(RecordSet)
for id, record := range other {
if _, ok := rs[id]; !ok {
added[id] = record
}
}
return added
}
// GetRemoved returns the records that are in rs but not in other.
func (rs RecordSet) GetRemoved(other RecordSet) RecordSet {
return other.GetAdded(rs)
}
// GetModified returns the records that are in both rs and other but have different data.
// by comparing the protobuf bytes of the payload.
func (rs RecordSet) GetModified(other RecordSet) RecordSet {
modified := make(RecordSet)
for id, record := range other {
otherRecord, ok := rs[id]
if !ok {
continue
}
if !proto.Equal(record, otherRecord) {
modified[id] = record
}
}
return modified
}
// Flatten returns all records in the set.
func (rs RecordSet) Flatten() []*Record {
records := make([]*Record, 0, len(rs))
for _, record := range rs {
records = append(records, record)
}
return records
}
// Flatten returns all records in the bundle.
func (rsb RecordSetBundle) Flatten() []*Record {
records := make([]*Record, 0)
for _, rs := range rsb {
records = append(records, rs.Flatten()...)
}
return records
}
// Get returns a record by type and id.
func (rsb RecordSetBundle) Get(typeName, id string) (record *Record, ok bool) {
rs, ok := rsb[typeName]
if !ok {
return
}
record, ok = rs[id]
return
}
// MarshalJSON marshals the record to JSON.
func (r *Record) MarshalJSON() ([]byte, error) {
return protojson.Marshal(r)
}
// UnmarshalJSON unmarshals the record from JSON.
func (r *Record) UnmarshalJSON(data []byte) error {
return protojson.Unmarshal(data, r)
}

View file

@ -0,0 +1,77 @@
package databroker_test
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func TestRecords(t *testing.T) {
tr := func(id, typ, val string) *databroker.Record {
return &databroker.Record{
Id: id,
Type: typ,
Data: protoutil.NewAnyString(val),
}
}
initial := make(databroker.RecordSetBundle)
initial.Add(tr("1", "a", "a-1"))
initial.Add(tr("2", "a", "a-2"))
initial.Add(tr("1", "b", "b-1"))
// test record types
assert.ElementsMatch(t, []string{"a", "b"}, initial.RecordTypes())
// test added, deleted and modified
updated := make(databroker.RecordSetBundle)
updated.Add(tr("1", "a", "a-1-1"))
updated.Add(tr("3", "a", "a-3"))
updated.Add(tr("1", "b", "b-1"))
updated.Add(tr("2", "b", "b-2"))
updated.Add(tr("1", "c", "c-1"))
assert.ElementsMatch(t, []string{"a", "b", "c"}, updated.RecordTypes())
equalJSON := func(a, b databroker.RecordSetBundle) {
t.Helper()
var txt [2]string
for i, x := range [2]databroker.RecordSetBundle{a, b} {
data, err := json.Marshal(x)
assert.NoError(t, err)
txt[i] = string(data)
}
assert.JSONEq(t, txt[0], txt[1])
}
added := initial.GetAdded(updated)
equalJSON(added, databroker.RecordSetBundle{
"a": databroker.RecordSet{
"3": tr("3", "a", "a-3"),
},
"b": databroker.RecordSet{
"2": tr("2", "b", "b-2"),
},
"c": databroker.RecordSet{
"1": tr("1", "c", "c-1"),
},
})
removed := initial.GetRemoved(updated)
equalJSON(removed, databroker.RecordSetBundle{
"a": databroker.RecordSet{
"2": tr("2", "a", "a-2"),
},
})
modified := initial.GetModified(updated)
equalJSON(modified, databroker.RecordSetBundle{
"a": databroker.RecordSet{
"1": tr("1", "a", "a-1-1"),
},
})
}