mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-07 13:22:43 +02:00
reconciler: allow custom comparison function (#4726)
This commit is contained in:
parent
ab7b66691d
commit
cc6592b6fd
4 changed files with 19 additions and 9 deletions
|
@ -9,13 +9,13 @@ import (
|
||||||
|
|
||||||
// GetChangeSet returns list of changes between the current and target record sets,
|
// GetChangeSet returns list of changes between the current and target record sets,
|
||||||
// that may be applied to the databroker to bring it to the target state.
|
// that may be applied to the databroker to bring it to the target state.
|
||||||
func GetChangeSet(current, target RecordSetBundle) []*Record {
|
func GetChangeSet(current, target RecordSetBundle, cmpFn RecordCompareFn) []*Record {
|
||||||
cs := &changeSet{now: timestamppb.Now()}
|
cs := &changeSet{now: timestamppb.Now()}
|
||||||
|
|
||||||
for _, rec := range current.GetRemoved(target).Flatten() {
|
for _, rec := range current.GetRemoved(target).Flatten() {
|
||||||
cs.Remove(rec.GetType(), rec.GetId())
|
cs.Remove(rec.GetType(), rec.GetId())
|
||||||
}
|
}
|
||||||
for _, rec := range current.GetModified(target).Flatten() {
|
for _, rec := range current.GetModified(target, cmpFn).Flatten() {
|
||||||
cs.Upsert(rec)
|
cs.Upsert(rec)
|
||||||
}
|
}
|
||||||
for _, rec := range current.GetAdded(target).Flatten() {
|
for _, rec := range current.GetAdded(target).Flatten() {
|
||||||
|
|
|
@ -17,6 +17,7 @@ type Reconciler struct {
|
||||||
name string
|
name string
|
||||||
client DataBrokerServiceClient
|
client DataBrokerServiceClient
|
||||||
currentStateBuilder StateBuilderFn
|
currentStateBuilder StateBuilderFn
|
||||||
|
cmpFn RecordCompareFn
|
||||||
targetStateBuilder StateBuilderFn
|
targetStateBuilder StateBuilderFn
|
||||||
setCurrentState func([]*Record)
|
setCurrentState func([]*Record)
|
||||||
trigger chan struct{}
|
trigger chan struct{}
|
||||||
|
@ -58,6 +59,7 @@ func NewReconciler(
|
||||||
currentStateBuilder StateBuilderFn,
|
currentStateBuilder StateBuilderFn,
|
||||||
targetStateBuilder StateBuilderFn,
|
targetStateBuilder StateBuilderFn,
|
||||||
setCurrentState func([]*Record),
|
setCurrentState func([]*Record),
|
||||||
|
cmpFn RecordCompareFn,
|
||||||
opts ...ReconcilerOption,
|
opts ...ReconcilerOption,
|
||||||
) *Reconciler {
|
) *Reconciler {
|
||||||
return &Reconciler{
|
return &Reconciler{
|
||||||
|
@ -68,6 +70,7 @@ func NewReconciler(
|
||||||
currentStateBuilder: currentStateBuilder,
|
currentStateBuilder: currentStateBuilder,
|
||||||
targetStateBuilder: targetStateBuilder,
|
targetStateBuilder: targetStateBuilder,
|
||||||
setCurrentState: setCurrentState,
|
setCurrentState: setCurrentState,
|
||||||
|
cmpFn: cmpFn,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,7 +122,7 @@ func (r *Reconciler) reconcile(ctx context.Context) error {
|
||||||
return fmt.Errorf("get config record sets: %w", err)
|
return fmt.Errorf("get config record sets: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
updates := GetChangeSet(current, target)
|
updates := GetChangeSet(current, target, r.cmpFn)
|
||||||
|
|
||||||
err = r.applyChanges(ctx, updates)
|
err = r.applyChanges(ctx, updates)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -2,7 +2,6 @@ package databroker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// RecordSetBundle is an index of databroker records by type
|
// RecordSetBundle is an index of databroker records by type
|
||||||
|
@ -11,6 +10,9 @@ type RecordSetBundle map[string]RecordSet
|
||||||
// RecordSet is an index of databroker records by their id.
|
// RecordSet is an index of databroker records by their id.
|
||||||
type RecordSet map[string]*Record
|
type RecordSet map[string]*Record
|
||||||
|
|
||||||
|
// RecordCompareFn is a function that compares two records.
|
||||||
|
type RecordCompareFn func(record1, record2 *Record) bool
|
||||||
|
|
||||||
// RecordTypes returns the types of records in the bundle.
|
// RecordTypes returns the types of records in the bundle.
|
||||||
func (rsb RecordSetBundle) RecordTypes() []string {
|
func (rsb RecordSetBundle) RecordTypes() []string {
|
||||||
types := make([]string, 0, len(rsb))
|
types := make([]string, 0, len(rsb))
|
||||||
|
@ -53,14 +55,14 @@ func (rsb RecordSetBundle) GetRemoved(other RecordSetBundle) RecordSetBundle {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModified returns the records that are in both rs and other but have different data.
|
// GetModified returns the records that are in both rs and other but have different data.
|
||||||
func (rsb RecordSetBundle) GetModified(other RecordSetBundle) RecordSetBundle {
|
func (rsb RecordSetBundle) GetModified(other RecordSetBundle, cmpFn RecordCompareFn) RecordSetBundle {
|
||||||
modified := make(RecordSetBundle)
|
modified := make(RecordSetBundle)
|
||||||
for otherType, otherRS := range other {
|
for otherType, otherRS := range other {
|
||||||
rs, ok := rsb[otherType]
|
rs, ok := rsb[otherType]
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
m := rs.GetModified(otherRS)
|
m := rs.GetModified(otherRS, cmpFn)
|
||||||
if len(m) > 0 {
|
if len(m) > 0 {
|
||||||
modified[otherType] = m
|
modified[otherType] = m
|
||||||
}
|
}
|
||||||
|
@ -86,7 +88,7 @@ func (rs RecordSet) GetRemoved(other RecordSet) RecordSet {
|
||||||
|
|
||||||
// GetModified returns the records that are in both rs and other but have different data.
|
// GetModified returns the records that are in both rs and other but have different data.
|
||||||
// by comparing the protobuf bytes of the payload.
|
// by comparing the protobuf bytes of the payload.
|
||||||
func (rs RecordSet) GetModified(other RecordSet) RecordSet {
|
func (rs RecordSet) GetModified(other RecordSet, cmpFn RecordCompareFn) RecordSet {
|
||||||
modified := make(RecordSet)
|
modified := make(RecordSet)
|
||||||
for id, record := range other {
|
for id, record := range other {
|
||||||
otherRecord, ok := rs[id]
|
otherRecord, ok := rs[id]
|
||||||
|
@ -94,7 +96,7 @@ func (rs RecordSet) GetModified(other RecordSet) RecordSet {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !proto.Equal(record, otherRecord) {
|
if !cmpFn(record, otherRecord) {
|
||||||
modified[id] = record
|
modified[id] = record
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
@ -19,6 +20,10 @@ func TestRecords(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cmpFn := func(a, b *databroker.Record) bool {
|
||||||
|
return proto.Equal(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
initial := make(databroker.RecordSetBundle)
|
initial := make(databroker.RecordSetBundle)
|
||||||
initial.Add(tr("1", "a", "a-1"))
|
initial.Add(tr("1", "a", "a-1"))
|
||||||
initial.Add(tr("2", "a", "a-2"))
|
initial.Add(tr("2", "a", "a-2"))
|
||||||
|
@ -68,7 +73,7 @@ func TestRecords(t *testing.T) {
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
modified := initial.GetModified(updated)
|
modified := initial.GetModified(updated, cmpFn)
|
||||||
equalJSON(modified, databroker.RecordSetBundle{
|
equalJSON(modified, databroker.RecordSetBundle{
|
||||||
"a": databroker.RecordSet{
|
"a": databroker.RecordSet{
|
||||||
"1": tr("1", "a", "a-1-1"),
|
"1": tr("1", "a", "a-1-1"),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue