reconciler: allow custom comparison function (#4726)

This commit is contained in:
Denis Mishin 2023-11-08 20:11:49 -05:00 committed by GitHub
parent ab7b66691d
commit cc6592b6fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 19 additions and 9 deletions

View file

@ -9,13 +9,13 @@ import (
// GetChangeSet returns list of changes between the current and target record sets, // GetChangeSet returns list of changes between the current and target record sets,
// that may be applied to the databroker to bring it to the target state. // that may be applied to the databroker to bring it to the target state.
func GetChangeSet(current, target RecordSetBundle) []*Record { func GetChangeSet(current, target RecordSetBundle, cmpFn RecordCompareFn) []*Record {
cs := &changeSet{now: timestamppb.Now()} cs := &changeSet{now: timestamppb.Now()}
for _, rec := range current.GetRemoved(target).Flatten() { for _, rec := range current.GetRemoved(target).Flatten() {
cs.Remove(rec.GetType(), rec.GetId()) cs.Remove(rec.GetType(), rec.GetId())
} }
for _, rec := range current.GetModified(target).Flatten() { for _, rec := range current.GetModified(target, cmpFn).Flatten() {
cs.Upsert(rec) cs.Upsert(rec)
} }
for _, rec := range current.GetAdded(target).Flatten() { for _, rec := range current.GetAdded(target).Flatten() {

View file

@ -17,6 +17,7 @@ type Reconciler struct {
name string name string
client DataBrokerServiceClient client DataBrokerServiceClient
currentStateBuilder StateBuilderFn currentStateBuilder StateBuilderFn
cmpFn RecordCompareFn
targetStateBuilder StateBuilderFn targetStateBuilder StateBuilderFn
setCurrentState func([]*Record) setCurrentState func([]*Record)
trigger chan struct{} trigger chan struct{}
@ -58,6 +59,7 @@ func NewReconciler(
currentStateBuilder StateBuilderFn, currentStateBuilder StateBuilderFn,
targetStateBuilder StateBuilderFn, targetStateBuilder StateBuilderFn,
setCurrentState func([]*Record), setCurrentState func([]*Record),
cmpFn RecordCompareFn,
opts ...ReconcilerOption, opts ...ReconcilerOption,
) *Reconciler { ) *Reconciler {
return &Reconciler{ return &Reconciler{
@ -68,6 +70,7 @@ func NewReconciler(
currentStateBuilder: currentStateBuilder, currentStateBuilder: currentStateBuilder,
targetStateBuilder: targetStateBuilder, targetStateBuilder: targetStateBuilder,
setCurrentState: setCurrentState, setCurrentState: setCurrentState,
cmpFn: cmpFn,
} }
} }
@ -119,7 +122,7 @@ func (r *Reconciler) reconcile(ctx context.Context) error {
return fmt.Errorf("get config record sets: %w", err) return fmt.Errorf("get config record sets: %w", err)
} }
updates := GetChangeSet(current, target) updates := GetChangeSet(current, target, r.cmpFn)
err = r.applyChanges(ctx, updates) err = r.applyChanges(ctx, updates)
if err != nil { if err != nil {

View file

@ -2,7 +2,6 @@ package databroker
import ( import (
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
) )
// RecordSetBundle is an index of databroker records by type // RecordSetBundle is an index of databroker records by type
@ -11,6 +10,9 @@ type RecordSetBundle map[string]RecordSet
// RecordSet is an index of databroker records by their id. // RecordSet is an index of databroker records by their id.
type RecordSet map[string]*Record type RecordSet map[string]*Record
// RecordCompareFn is a function that compares two records.
type RecordCompareFn func(record1, record2 *Record) bool
// RecordTypes returns the types of records in the bundle. // RecordTypes returns the types of records in the bundle.
func (rsb RecordSetBundle) RecordTypes() []string { func (rsb RecordSetBundle) RecordTypes() []string {
types := make([]string, 0, len(rsb)) types := make([]string, 0, len(rsb))
@ -53,14 +55,14 @@ func (rsb RecordSetBundle) GetRemoved(other RecordSetBundle) RecordSetBundle {
} }
// GetModified returns the records that are in both rs and other but have different data. // GetModified returns the records that are in both rs and other but have different data.
func (rsb RecordSetBundle) GetModified(other RecordSetBundle) RecordSetBundle { func (rsb RecordSetBundle) GetModified(other RecordSetBundle, cmpFn RecordCompareFn) RecordSetBundle {
modified := make(RecordSetBundle) modified := make(RecordSetBundle)
for otherType, otherRS := range other { for otherType, otherRS := range other {
rs, ok := rsb[otherType] rs, ok := rsb[otherType]
if !ok { if !ok {
continue continue
} }
m := rs.GetModified(otherRS) m := rs.GetModified(otherRS, cmpFn)
if len(m) > 0 { if len(m) > 0 {
modified[otherType] = m modified[otherType] = m
} }
@ -86,7 +88,7 @@ func (rs RecordSet) GetRemoved(other RecordSet) RecordSet {
// GetModified returns the records that are in both rs and other but have different data. // GetModified returns the records that are in both rs and other but have different data.
// by comparing the protobuf bytes of the payload. // by comparing the protobuf bytes of the payload.
func (rs RecordSet) GetModified(other RecordSet) RecordSet { func (rs RecordSet) GetModified(other RecordSet, cmpFn RecordCompareFn) RecordSet {
modified := make(RecordSet) modified := make(RecordSet)
for id, record := range other { for id, record := range other {
otherRecord, ok := rs[id] otherRecord, ok := rs[id]
@ -94,7 +96,7 @@ func (rs RecordSet) GetModified(other RecordSet) RecordSet {
continue continue
} }
if !proto.Equal(record, otherRecord) { if !cmpFn(record, otherRecord) {
modified[id] = record modified[id] = record
} }
} }

View file

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil" "github.com/pomerium/pomerium/pkg/protoutil"
@ -19,6 +20,10 @@ func TestRecords(t *testing.T) {
} }
} }
cmpFn := func(a, b *databroker.Record) bool {
return proto.Equal(a, b)
}
initial := make(databroker.RecordSetBundle) initial := make(databroker.RecordSetBundle)
initial.Add(tr("1", "a", "a-1")) initial.Add(tr("1", "a", "a-1"))
initial.Add(tr("2", "a", "a-2")) initial.Add(tr("2", "a", "a-2"))
@ -68,7 +73,7 @@ func TestRecords(t *testing.T) {
}, },
}) })
modified := initial.GetModified(updated) modified := initial.GetModified(updated, cmpFn)
equalJSON(modified, databroker.RecordSetBundle{ equalJSON(modified, databroker.RecordSetBundle{
"a": databroker.RecordSet{ "a": databroker.RecordSet{
"1": tr("1", "a", "a-1-1"), "1": tr("1", "a", "a-1-1"),