zero/bundle-reconciler: better code reuse (#4758)

This commit is contained in:
Denis Mishin 2023-11-21 14:32:52 -05:00 committed by GitHub
parent 14b13bb791
commit 7e2532f644
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 31 additions and 474 deletions

View file

@ -15,9 +15,9 @@ var unmarshalOpts = protodelim.UnmarshalOptions{}
// ReadBundleRecords reads records in a protobuf wire format from src.
// Each record is expected to be a databroker.Record.
func ReadBundleRecords(src io.Reader) (RecordSetBundle[DatabrokerRecord], error) {
func ReadBundleRecords(src io.Reader) (databroker.RecordSetBundle, error) {
r := bufio.NewReader(src)
rsb := make(RecordSetBundle[DatabrokerRecord])
rsb := make(databroker.RecordSetBundle)
for {
record := new(databroker.Record)
err := unmarshalOpts.UnmarshalFrom(r, record)
@ -28,7 +28,7 @@ func ReadBundleRecords(src io.Reader) (RecordSetBundle[DatabrokerRecord], error)
return nil, fmt.Errorf("error reading protobuf record: %w", err)
}
rsb.Add(DatabrokerRecord{record})
rsb.Add(record)
}
return rsb, nil

View file

@ -11,28 +11,11 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// DatabrokerRecord is a wrapper around a databroker record.
type DatabrokerRecord struct {
V *databroker.Record
}
var _ Record[DatabrokerRecord] = DatabrokerRecord{}
// GetID returns the databroker record's ID.
func (r DatabrokerRecord) GetID() string {
return r.V.GetId()
}
// GetType returns the databroker record's type.
func (r DatabrokerRecord) GetType() string {
return r.V.GetType()
}
// Equal returns true if the databroker records are equal.
func (r DatabrokerRecord) Equal(other DatabrokerRecord) bool {
return r.V.Type == other.V.Type &&
r.V.Id == other.V.Id &&
proto.Equal(r.V.Data, other.V.Data)
// EqualRecord returns true if the databroker records are equal.
func EqualRecord(a, b *databroker.Record) bool {
return a.Type == b.Type &&
a.Id == b.Id &&
proto.Equal(a.Data, b.Data)
}
// GetDatabrokerRecords gets all databroker records of the given types.
@ -40,8 +23,8 @@ func GetDatabrokerRecords(
ctx context.Context,
client databroker.DataBrokerServiceClient,
types []string,
) (RecordSetBundle[DatabrokerRecord], error) {
rsb := make(RecordSetBundle[DatabrokerRecord])
) (databroker.RecordSetBundle, error) {
rsb := make(databroker.RecordSetBundle)
for _, typ := range types {
recs, err := getDatabrokerRecords(ctx, client, typ)
@ -58,13 +41,13 @@ func getDatabrokerRecords(
ctx context.Context,
client databroker.DataBrokerServiceClient,
typ string,
) (RecordSet[DatabrokerRecord], error) {
) (databroker.RecordSet, error) {
stream, err := client.SyncLatest(ctx, &databroker.SyncLatestRequest{Type: typ})
if err != nil {
return nil, fmt.Errorf("sync latest databroker: %w", err)
}
recordSet := make(RecordSet[DatabrokerRecord])
recordSet := make(databroker.RecordSet)
for {
res, err := stream.Recv()
if errors.Is(err, io.EOF) {
@ -74,7 +57,7 @@ func getDatabrokerRecords(
}
if record := res.GetRecord(); record != nil {
recordSet[record.GetId()] = DatabrokerRecord{record}
recordSet[record.GetId()] = record
}
}
return recordSet, nil

View file

@ -1,34 +0,0 @@
package reconciler
import (
"context"
"fmt"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// Reconcile reconciles the target and current record sets with the databroker.
func Reconcile(
ctx context.Context,
client databroker.DataBrokerServiceClient,
target, current RecordSetBundle[DatabrokerRecord],
) error {
updates := NewDatabrokerChangeSet()
for _, rec := range current.GetRemoved(target).Flatten() {
updates.Remove(rec.GetType(), rec.GetID())
}
for _, rec := range current.GetModified(target).Flatten() {
updates.Upsert(rec.V)
}
for _, rec := range current.GetAdded(target).Flatten() {
updates.Upsert(rec.V)
}
err := ApplyChanges(ctx, client, updates)
if err != nil {
return fmt.Errorf("apply databroker changes: %w", err)
}
return nil
}

View file

@ -1,196 +0,0 @@
package reconciler_test
import (
"context"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"
"google.golang.org/protobuf/types/known/wrapperspb"
databroker_int "github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/internal/zero/reconciler"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func newDatabroker(t *testing.T) (context.Context, databroker.DataBrokerServiceClient) {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
gs := grpc.NewServer()
srv := databroker_int.New()
databroker.RegisterDataBrokerServiceServer(gs, srv)
lis := bufconn.Listen(1)
t.Cleanup(func() {
lis.Close()
gs.Stop()
})
go func() { _ = gs.Serve(lis) }()
conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(func(context.Context, string) (conn net.Conn, e error) {
return lis.Dial()
}), grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
t.Cleanup(func() { _ = conn.Close() })
return ctx, databroker.NewDataBrokerServiceClient(conn)
}
func newRecordBundle(records []testRecord) reconciler.RecordSetBundle[reconciler.DatabrokerRecord] {
bundle := make(reconciler.RecordSetBundle[reconciler.DatabrokerRecord])
for _, r := range records {
bundle.Add(newRecord(r))
}
return bundle
}
func newRecord(r testRecord) reconciler.DatabrokerRecord {
return reconciler.DatabrokerRecord{
V: &databroker.Record{
Type: r.Type,
Id: r.ID,
Data: protoutil.NewAnyString(r.Val),
}}
}
func assertBundle(t *testing.T, want []testRecord, got reconciler.RecordSetBundle[reconciler.DatabrokerRecord]) {
t.Helper()
for _, wantRecord := range want {
gotRecord, ok := got.Get(wantRecord.Type, wantRecord.ID)
if assert.True(t, ok, "record %s/%s not found", wantRecord.Type, wantRecord.ID) {
assertRecord(t, wantRecord, gotRecord)
}
}
assert.Len(t, got.Flatten(), len(want))
}
func assertRecord(t *testing.T, want testRecord, got reconciler.DatabrokerRecord) {
t.Helper()
var val wrapperspb.StringValue
err := got.V.Data.UnmarshalTo(&val)
require.NoError(t, err)
assert.Equal(t, want.Type, got.V.Type)
assert.Equal(t, want.ID, got.V.Id)
assert.Equal(t, want.Val, val.Value)
}
func TestHelpers(t *testing.T) {
want := []testRecord{
{"type1", "id1", "value1"},
{"type1", "id2", "value2"},
}
bundle := newRecordBundle(want)
assertBundle(t, want, bundle)
}
func wantRemoved(want, current []string) []string {
wantM := make(map[string]struct{}, len(want))
for _, w := range want {
wantM[w] = struct{}{}
}
var toRemove []string
for _, c := range current {
if _, ok := wantM[c]; !ok {
toRemove = append(toRemove, c)
}
}
return toRemove
}
func reconcile(
ctx context.Context,
t *testing.T,
client databroker.DataBrokerServiceClient,
want []testRecord,
current reconciler.RecordSetBundle[reconciler.DatabrokerRecord],
) reconciler.RecordSetBundle[reconciler.DatabrokerRecord] {
t.Helper()
wantBundle := newRecordBundle(want)
err := reconciler.Reconcile(ctx, client, wantBundle, current)
require.NoError(t, err)
got, err := reconciler.GetDatabrokerRecords(ctx, client, wantBundle.RecordTypes())
require.NoError(t, err)
assertBundle(t, want, got)
res, err := reconciler.GetDatabrokerRecords(ctx, client, wantRemoved(wantBundle.RecordTypes(), current.RecordTypes()))
require.NoError(t, err)
assert.Empty(t, res.Flatten())
return got
}
func TestReconcile(t *testing.T) {
t.Parallel()
ctx, client := newDatabroker(t)
err := reconciler.Reconcile(ctx, client, nil, nil)
require.NoError(t, err)
var current reconciler.RecordSetBundle[reconciler.DatabrokerRecord]
for _, tc := range []struct {
name string
want []testRecord
}{
{"empty", nil},
{"initial", []testRecord{
{"type1", "id1", "value1"},
{"type1", "id2", "value2"},
}},
{"add one", []testRecord{
{"type1", "id1", "value1"},
{"type1", "id2", "value2"},
{"type1", "id3", "value3"},
}},
{"update one", []testRecord{
{"type1", "id1", "value1"},
{"type1", "id2", "value2-updated"},
{"type1", "id3", "value3"},
}},
{"delete one", []testRecord{
{"type1", "id1", "value1"},
{"type1", "id3", "value3"},
}},
{"delete all", nil},
{"multiple types", []testRecord{
{"type1", "id1", "value1"},
{"type1", "id2", "value2"},
{"type2", "id1", "value1"},
{"type2", "id2", "value2"},
}},
{"multiple types update", []testRecord{
{"type1", "id1", "value1"},
{"type1", "id2", "value2-updated"},
{"type2", "id1", "value1"},
{"type2", "id2", "value2-updated"},
}},
{"multiple types delete", []testRecord{
{"type1", "id1", "value1"},
{"type2", "id1", "value1"},
}},
{"multiple types delete one type, add one value", []testRecord{
{"type1", "id1", "value1"},
{"type1", "id4", "value4"},
}},
} {
t.Run(tc.name, func(t *testing.T) {
current = reconcile(ctx, t, client, tc.want, current)
})
}
}

View file

@ -1,132 +0,0 @@
package reconciler
// RecordSetBundle is an index of databroker records by type
type RecordSetBundle[T Record[T]] map[string]RecordSet[T]
// RecordSet is an index of databroker records by their id.
type RecordSet[T Record[T]] map[string]T
// Record is a record
type Record[T any] interface {
GetID() string
GetType() string
Equal(other T) bool
}
// RecordTypes returns the types of records in the bundle.
func (rsb RecordSetBundle[T]) RecordTypes() []string {
types := make([]string, 0, len(rsb))
for typ := range rsb {
types = append(types, typ)
}
return types
}
// Add adds a record to the bundle.
func (rsb RecordSetBundle[T]) Add(record T) {
rs, ok := rsb[record.GetType()]
if !ok {
rs = make(RecordSet[T])
rsb[record.GetType()] = rs
}
rs[record.GetID()] = record
}
// GetAdded returns the records that are in other but not in rsb.
func (rsb RecordSetBundle[T]) GetAdded(other RecordSetBundle[T]) RecordSetBundle[T] {
added := make(RecordSetBundle[T])
for otherType, otherRS := range other {
rs, ok := rsb[otherType]
if !ok {
added[otherType] = otherRS
continue
}
rss := rs.GetAdded(other[otherType])
if len(rss) > 0 {
added[otherType] = rss
}
}
return added
}
// GetRemoved returns the records that are in rs but not in other.
func (rsb RecordSetBundle[T]) GetRemoved(other RecordSetBundle[T]) RecordSetBundle[T] {
return other.GetAdded(rsb)
}
// GetModified returns the records that are in both rs and other but have different data.
func (rsb RecordSetBundle[T]) GetModified(other RecordSetBundle[T]) RecordSetBundle[T] {
modified := make(RecordSetBundle[T])
for otherType, otherRS := range other {
rs, ok := rsb[otherType]
if !ok {
continue
}
m := rs.GetModified(otherRS)
if len(m) > 0 {
modified[otherType] = m
}
}
return modified
}
// GetAdded returns the records that are in other but not in rs.
func (rs RecordSet[T]) GetAdded(other RecordSet[T]) RecordSet[T] {
added := make(RecordSet[T])
for id, record := range other {
if _, ok := rs[id]; !ok {
added[id] = record
}
}
return added
}
// GetRemoved returns the records that are in rs but not in other.
func (rs RecordSet[T]) GetRemoved(other RecordSet[T]) RecordSet[T] {
return other.GetAdded(rs)
}
// GetModified returns the records that are in both rs and other but have different data.
// by comparing the protobuf bytes of the payload.
func (rs RecordSet[T]) GetModified(other RecordSet[T]) RecordSet[T] {
modified := make(RecordSet[T])
for id, record := range other {
otherRecord, ok := rs[id]
if !ok {
continue
}
if !record.Equal(otherRecord) {
modified[id] = record
}
}
return modified
}
// Flatten returns all records in the set.
func (rs RecordSet[T]) Flatten() []T {
records := make([]T, 0, len(rs))
for _, record := range rs {
records = append(records, record)
}
return records
}
// Flatten returns all records in the bundle.
func (rsb RecordSetBundle[T]) Flatten() []T {
records := make([]T, 0)
for _, rs := range rsb {
records = append(records, rs.Flatten()...)
}
return records
}
// Get returns a record by type and id.
func (rsb RecordSetBundle[T]) Get(typeName, id string) (record T, ok bool) {
rs, ok := rsb[typeName]
if !ok {
return
}
record, ok = rs[id]
return
}

View file

@ -1,77 +0,0 @@
package reconciler_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/zero/reconciler"
)
type testRecord struct {
Type string
ID string
Val string
}
func (r testRecord) GetID() string {
return r.ID
}
func (r testRecord) GetType() string {
return r.Type
}
func (r testRecord) Equal(other testRecord) bool {
return r.ID == other.ID && r.Type == other.Type && r.Val == other.Val
}
func TestRecords(t *testing.T) {
initial := make(reconciler.RecordSetBundle[testRecord])
initial.Add(testRecord{ID: "1", Type: "a", Val: "a-1"})
initial.Add(testRecord{ID: "2", Type: "a", Val: "a-2"})
initial.Add(testRecord{ID: "1", Type: "b", Val: "b-1"})
// test record types
assert.ElementsMatch(t, []string{"a", "b"}, initial.RecordTypes())
// test added, deleted and modified
updated := make(reconciler.RecordSetBundle[testRecord])
updated.Add(testRecord{ID: "1", Type: "a", Val: "a-1-1"})
updated.Add(testRecord{ID: "3", Type: "a", Val: "a-3"})
updated.Add(testRecord{ID: "1", Type: "b", Val: "b-1"})
updated.Add(testRecord{ID: "2", Type: "b", Val: "b-2"})
updated.Add(testRecord{ID: "1", Type: "c", Val: "c-1"})
assert.ElementsMatch(t, []string{"a", "b", "c"}, updated.RecordTypes())
added := initial.GetAdded(updated)
assert.Equal(t,
reconciler.RecordSetBundle[testRecord]{
"a": reconciler.RecordSet[testRecord]{
"3": {ID: "3", Type: "a", Val: "a-3"},
},
"b": reconciler.RecordSet[testRecord]{
"2": {ID: "2", Type: "b", Val: "b-2"},
},
"c": reconciler.RecordSet[testRecord]{
"1": {ID: "1", Type: "c", Val: "c-1"},
},
}, added)
removed := initial.GetRemoved(updated)
assert.Equal(t,
reconciler.RecordSetBundle[testRecord]{
"a": reconciler.RecordSet[testRecord]{
"2": {ID: "2", Type: "a", Val: "a-2"},
},
}, removed)
modified := initial.GetModified(updated)
assert.Equal(t,
reconciler.RecordSetBundle[testRecord]{
"a": reconciler.RecordSet[testRecord]{
"1": {ID: "1", Type: "a", Val: "a-1-1"},
},
}, modified)
}

View file

@ -20,6 +20,7 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/retry"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// Sync synchronizes the bundles between their cloud source and the databroker.
@ -154,7 +155,7 @@ func (c *service) syncBundle(ctx context.Context, key string) error {
return fmt.Errorf("seek to start: %w", err)
}
bundleRecordTypes, err := c.syncBundleToDatabroker(ctx, fd, cached.GetRecordTypes())
bundleRecordTypes, err := c.syncBundleToDatabroker(ctx, key, fd, cached.GetRecordTypes())
if err != nil {
c.ReportBundleAppliedFailure(ctx, key, BundleStatusFailureDatabrokerError, err)
return fmt.Errorf("apply bundle to databroker: %w", err)
@ -198,7 +199,7 @@ func strUnion(a, b []string) []string {
return out
}
func (c *service) syncBundleToDatabroker(ctx context.Context, src io.Reader, currentRecordTypes []string) ([]string, error) {
func (c *service) syncBundleToDatabroker(ctx context.Context, key string, src io.Reader, currentRecordTypes []string) ([]string, error) {
bundleRecords, err := ReadBundleRecords(src)
if err != nil {
return nil, fmt.Errorf("read bundle records: %w", err)
@ -212,7 +213,18 @@ func (c *service) syncBundleToDatabroker(ctx context.Context, src io.Reader, cur
return nil, fmt.Errorf("get databroker records: %w", err)
}
err = Reconcile(ctx, c.config.databrokerClient, bundleRecords, databrokerRecords)
err = databroker.NewReconciler(
fmt.Sprintf("bundle-%s", key),
c.config.databrokerClient,
func(ctx context.Context) (databroker.RecordSetBundle, error) {
return databrokerRecords, nil
},
func(ctx context.Context) (databroker.RecordSetBundle, error) {
return bundleRecords, nil
},
func(_ []*databroker.Record) {},
EqualRecord,
).Reconcile(ctx)
if err != nil {
return nil, fmt.Errorf("reconcile databroker records: %w", err)
}

View file

@ -103,7 +103,7 @@ func (r *Reconciler) RunLeased(ctx context.Context) error {
func (r *Reconciler) reconcileLoop(ctx context.Context) error {
for {
err := r.reconcile(ctx)
err := r.Reconcile(ctx)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("reconcile")
}
@ -116,7 +116,8 @@ func (r *Reconciler) reconcileLoop(ctx context.Context) error {
}
}
func (r *Reconciler) reconcile(ctx context.Context) error {
// Reconcile brings databroker state in line with the target state.
func (r *Reconciler) Reconcile(ctx context.Context) error {
current, target, err := r.getRecordSets(ctx)
if err != nil {
return fmt.Errorf("get config record sets: %w", err)