From cc6592b6fd77afea09972c6750b98fd715f435ed Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Wed, 8 Nov 2023 20:11:49 -0500 Subject: [PATCH] reconciler: allow custom comparison function (#4726) --- pkg/grpc/databroker/changeset.go | 4 ++-- pkg/grpc/databroker/reconciler.go | 5 ++++- pkg/grpc/databroker/recordset.go | 12 +++++++----- pkg/grpc/databroker/recordset_test.go | 7 ++++++- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/pkg/grpc/databroker/changeset.go b/pkg/grpc/databroker/changeset.go index 20b6889f8..a5a0b4211 100644 --- a/pkg/grpc/databroker/changeset.go +++ b/pkg/grpc/databroker/changeset.go @@ -9,13 +9,13 @@ import ( // GetChangeSet returns list of changes between the current and target record sets, // that may be applied to the databroker to bring it to the target state. -func GetChangeSet(current, target RecordSetBundle) []*Record { +func GetChangeSet(current, target RecordSetBundle, cmpFn RecordCompareFn) []*Record { cs := &changeSet{now: timestamppb.Now()} for _, rec := range current.GetRemoved(target).Flatten() { cs.Remove(rec.GetType(), rec.GetId()) } - for _, rec := range current.GetModified(target).Flatten() { + for _, rec := range current.GetModified(target, cmpFn).Flatten() { cs.Upsert(rec) } for _, rec := range current.GetAdded(target).Flatten() { diff --git a/pkg/grpc/databroker/reconciler.go b/pkg/grpc/databroker/reconciler.go index af83fe475..00387dd3b 100644 --- a/pkg/grpc/databroker/reconciler.go +++ b/pkg/grpc/databroker/reconciler.go @@ -17,6 +17,7 @@ type Reconciler struct { name string client DataBrokerServiceClient currentStateBuilder StateBuilderFn + cmpFn RecordCompareFn targetStateBuilder StateBuilderFn setCurrentState func([]*Record) trigger chan struct{} @@ -58,6 +59,7 @@ func NewReconciler( currentStateBuilder StateBuilderFn, targetStateBuilder StateBuilderFn, setCurrentState func([]*Record), + cmpFn RecordCompareFn, opts ...ReconcilerOption, ) *Reconciler { return &Reconciler{ @@ -68,6 +70,7 @@ func NewReconciler( currentStateBuilder: currentStateBuilder, targetStateBuilder: targetStateBuilder, setCurrentState: setCurrentState, + cmpFn: cmpFn, } } @@ -119,7 +122,7 @@ func (r *Reconciler) reconcile(ctx context.Context) error { return fmt.Errorf("get config record sets: %w", err) } - updates := GetChangeSet(current, target) + updates := GetChangeSet(current, target, r.cmpFn) err = r.applyChanges(ctx, updates) if err != nil { diff --git a/pkg/grpc/databroker/recordset.go b/pkg/grpc/databroker/recordset.go index 1241830b9..75f56184e 100644 --- a/pkg/grpc/databroker/recordset.go +++ b/pkg/grpc/databroker/recordset.go @@ -2,7 +2,6 @@ package databroker import ( "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" ) // RecordSetBundle is an index of databroker records by type @@ -11,6 +10,9 @@ type RecordSetBundle map[string]RecordSet // RecordSet is an index of databroker records by their id. type RecordSet map[string]*Record +// RecordCompareFn is a function that compares two records. +type RecordCompareFn func(record1, record2 *Record) bool + // RecordTypes returns the types of records in the bundle. func (rsb RecordSetBundle) RecordTypes() []string { types := make([]string, 0, len(rsb)) @@ -53,14 +55,14 @@ func (rsb RecordSetBundle) GetRemoved(other RecordSetBundle) RecordSetBundle { } // GetModified returns the records that are in both rs and other but have different data. -func (rsb RecordSetBundle) GetModified(other RecordSetBundle) RecordSetBundle { +func (rsb RecordSetBundle) GetModified(other RecordSetBundle, cmpFn RecordCompareFn) RecordSetBundle { modified := make(RecordSetBundle) for otherType, otherRS := range other { rs, ok := rsb[otherType] if !ok { continue } - m := rs.GetModified(otherRS) + m := rs.GetModified(otherRS, cmpFn) if len(m) > 0 { modified[otherType] = m } @@ -86,7 +88,7 @@ func (rs RecordSet) GetRemoved(other RecordSet) RecordSet { // GetModified returns the records that are in both rs and other but have different data. // by comparing the protobuf bytes of the payload. -func (rs RecordSet) GetModified(other RecordSet) RecordSet { +func (rs RecordSet) GetModified(other RecordSet, cmpFn RecordCompareFn) RecordSet { modified := make(RecordSet) for id, record := range other { otherRecord, ok := rs[id] @@ -94,7 +96,7 @@ func (rs RecordSet) GetModified(other RecordSet) RecordSet { continue } - if !proto.Equal(record, otherRecord) { + if !cmpFn(record, otherRecord) { modified[id] = record } } diff --git a/pkg/grpc/databroker/recordset_test.go b/pkg/grpc/databroker/recordset_test.go index f7277ff58..1a7bccd83 100644 --- a/pkg/grpc/databroker/recordset_test.go +++ b/pkg/grpc/databroker/recordset_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/protoutil" @@ -19,6 +20,10 @@ func TestRecords(t *testing.T) { } } + cmpFn := func(a, b *databroker.Record) bool { + return proto.Equal(a, b) + } + initial := make(databroker.RecordSetBundle) initial.Add(tr("1", "a", "a-1")) initial.Add(tr("2", "a", "a-2")) @@ -68,7 +73,7 @@ func TestRecords(t *testing.T) { }, }) - modified := initial.GetModified(updated) + modified := initial.GetModified(updated, cmpFn) equalJSON(modified, databroker.RecordSetBundle{ "a": databroker.RecordSet{ "1": tr("1", "a", "a-1-1"),