Merge branch 'main' into wasaga/pomerium-disable-validation

This commit is contained in:
Denis Mishin 2023-11-03 16:40:16 -04:00 committed by GitHub
commit d7611b1331
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 452 additions and 0 deletions

View file

@ -0,0 +1,66 @@
package databroker
import (
"context"
"fmt"
"google.golang.org/protobuf/types/known/timestamppb"
)
// GetChangeSet returns list of changes between the current and target record sets,
// that may be applied to the databroker to bring it to the target state.
func GetChangeSet(current, target RecordSetBundle) []*Record {
cs := &changeSet{now: timestamppb.Now()}
for _, rec := range current.GetRemoved(target).Flatten() {
cs.Remove(rec.GetType(), rec.GetId())
}
for _, rec := range current.GetModified(target).Flatten() {
cs.Upsert(rec)
}
for _, rec := range current.GetAdded(target).Flatten() {
cs.Upsert(rec)
}
return cs.updates
}
// changeSet is a set of databroker changes.
type changeSet struct {
now *timestamppb.Timestamp
updates []*Record
}
// Remove adds a record to the change set.
func (cs *changeSet) Remove(typ string, id string) {
cs.updates = append(cs.updates, &Record{
Type: typ,
Id: id,
DeletedAt: cs.now,
})
}
// Upsert adds a record to the change set.
func (cs *changeSet) Upsert(record *Record) {
cs.updates = append(cs.updates, &Record{
Type: record.Type,
Id: record.Id,
Data: record.Data,
})
}
// PutMulti puts the records into the databroker in batches.
func PutMulti(ctx context.Context, client DataBrokerServiceClient, records ...*Record) error {
if len(records) == 0 {
return nil
}
updates := OptimumPutRequestsFromRecords(records)
for _, req := range updates {
_, err := client.Put(ctx, req)
if err != nil {
return fmt.Errorf("put databroker record: %w", err)
}
}
return nil
}

View file

@ -0,0 +1,169 @@
package databroker
import (
"context"
"fmt"
"time"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
"github.com/pomerium/pomerium/internal/log"
)
// Reconciler reconciles the target and current record sets with the databroker.
type Reconciler struct {
reconcilerConfig
name string
client DataBrokerServiceClient
currentStateBuilder StateBuilderFn
targetStateBuilder StateBuilderFn
setCurrentState func([]*Record)
trigger chan struct{}
}
type reconcilerConfig struct {
interval time.Duration
}
// ReconcilerOption is an option for a reconciler.
type ReconcilerOption func(*reconcilerConfig)
// WithInterval sets the interval for the reconciler.
func WithInterval(interval time.Duration) ReconcilerOption {
return func(c *reconcilerConfig) {
c.interval = interval
}
}
func getReconcilerConfig(options ...ReconcilerOption) reconcilerConfig {
options = append([]ReconcilerOption{
WithInterval(time.Minute),
}, options...)
var c reconcilerConfig
for _, option := range options {
option(&c)
}
return c
}
// StateBuilderFn is a function that builds a record set bundle
type StateBuilderFn func(ctx context.Context) (RecordSetBundle, error)
// NewReconciler creates a new reconciler
func NewReconciler(
// name must be unique across pomerium ecosystem
name string,
client DataBrokerServiceClient,
currentStateBuilder StateBuilderFn,
targetStateBuilder StateBuilderFn,
setCurrentState func([]*Record),
opts ...ReconcilerOption,
) *Reconciler {
return &Reconciler{
name: fmt.Sprintf("%s-reconciler", name),
reconcilerConfig: getReconcilerConfig(opts...),
trigger: make(chan struct{}, 1),
client: client,
currentStateBuilder: currentStateBuilder,
targetStateBuilder: targetStateBuilder,
setCurrentState: setCurrentState,
}
}
// TriggerSync triggers a sync
func (r *Reconciler) TriggerSync() {
select {
case r.trigger <- struct{}{}:
default:
}
}
// Run runs the reconciler
func (r *Reconciler) Run(ctx context.Context) error {
leaser := NewLeaser(r.name, r.interval, r)
return leaser.Run(ctx)
}
// GetDataBrokerServiceClient implements the LeaseHandler interface.
func (r *Reconciler) GetDataBrokerServiceClient() DataBrokerServiceClient {
return r.client
}
// RunLeased implements the LeaseHandler interface.
func (r *Reconciler) RunLeased(ctx context.Context) error {
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", r.name)
})
return r.reconcileLoop(ctx)
}
func (r *Reconciler) reconcileLoop(ctx context.Context) error {
for {
err := r.reconcile(ctx)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("reconcile")
}
select {
case <-ctx.Done():
return ctx.Err()
case <-r.trigger:
}
}
}
func (r *Reconciler) reconcile(ctx context.Context) error {
current, target, err := r.getRecordSets(ctx)
if err != nil {
return fmt.Errorf("get config record sets: %w", err)
}
updates := GetChangeSet(current, target)
err = r.applyChanges(ctx, updates)
if err != nil {
return fmt.Errorf("apply config change set: %w", err)
}
r.setCurrentState(updates)
return nil
}
func (r *Reconciler) getRecordSets(ctx context.Context) (
current RecordSetBundle,
target RecordSetBundle,
_ error,
) {
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
var err error
current, err = r.currentStateBuilder(ctx)
if err != nil {
return fmt.Errorf("build current config record set: %w", err)
}
return nil
})
eg.Go(func() error {
var err error
target, err = r.targetStateBuilder(ctx)
if err != nil {
return fmt.Errorf("build target config record set: %w", err)
}
return nil
})
err := eg.Wait()
if err != nil {
return nil, nil, err
}
return current, target, nil
}
func (r *Reconciler) applyChanges(ctx context.Context, updates []*Record) error {
err := PutMulti(ctx, r.client, updates...)
if err != nil {
return fmt.Errorf("apply databroker changes: %w", err)
}
return nil
}

View file

@ -0,0 +1,140 @@
package databroker
import (
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)
// RecordSetBundle is an index of databroker records by type
type RecordSetBundle map[string]RecordSet
// RecordSet is an index of databroker records by their id.
type RecordSet map[string]*Record
// RecordTypes returns the types of records in the bundle.
func (rsb RecordSetBundle) RecordTypes() []string {
types := make([]string, 0, len(rsb))
for typ := range rsb {
types = append(types, typ)
}
return types
}
// Add adds a record to the bundle.
func (rsb RecordSetBundle) Add(record *Record) {
rs, ok := rsb[record.GetType()]
if !ok {
rs = make(RecordSet)
rsb[record.GetType()] = rs
}
rs[record.GetId()] = record
}
// GetAdded returns the records that are in other but not in rsb.
func (rsb RecordSetBundle) GetAdded(other RecordSetBundle) RecordSetBundle {
added := make(RecordSetBundle)
for otherType, otherRS := range other {
rs, ok := rsb[otherType]
if !ok {
added[otherType] = otherRS
continue
}
rss := rs.GetAdded(other[otherType])
if len(rss) > 0 {
added[otherType] = rss
}
}
return added
}
// GetRemoved returns the records that are in rs but not in other.
func (rsb RecordSetBundle) GetRemoved(other RecordSetBundle) RecordSetBundle {
return other.GetAdded(rsb)
}
// GetModified returns the records that are in both rs and other but have different data.
func (rsb RecordSetBundle) GetModified(other RecordSetBundle) RecordSetBundle {
modified := make(RecordSetBundle)
for otherType, otherRS := range other {
rs, ok := rsb[otherType]
if !ok {
continue
}
m := rs.GetModified(otherRS)
if len(m) > 0 {
modified[otherType] = m
}
}
return modified
}
// GetAdded returns the records that are in other but not in rs.
func (rs RecordSet) GetAdded(other RecordSet) RecordSet {
added := make(RecordSet)
for id, record := range other {
if _, ok := rs[id]; !ok {
added[id] = record
}
}
return added
}
// GetRemoved returns the records that are in rs but not in other.
func (rs RecordSet) GetRemoved(other RecordSet) RecordSet {
return other.GetAdded(rs)
}
// GetModified returns the records that are in both rs and other but have different data.
// by comparing the protobuf bytes of the payload.
func (rs RecordSet) GetModified(other RecordSet) RecordSet {
modified := make(RecordSet)
for id, record := range other {
otherRecord, ok := rs[id]
if !ok {
continue
}
if !proto.Equal(record, otherRecord) {
modified[id] = record
}
}
return modified
}
// Flatten returns all records in the set.
func (rs RecordSet) Flatten() []*Record {
records := make([]*Record, 0, len(rs))
for _, record := range rs {
records = append(records, record)
}
return records
}
// Flatten returns all records in the bundle.
func (rsb RecordSetBundle) Flatten() []*Record {
records := make([]*Record, 0)
for _, rs := range rsb {
records = append(records, rs.Flatten()...)
}
return records
}
// Get returns a record by type and id.
func (rsb RecordSetBundle) Get(typeName, id string) (record *Record, ok bool) {
rs, ok := rsb[typeName]
if !ok {
return
}
record, ok = rs[id]
return
}
// MarshalJSON marshals the record to JSON.
func (r *Record) MarshalJSON() ([]byte, error) {
return protojson.Marshal(r)
}
// UnmarshalJSON unmarshals the record from JSON.
func (r *Record) UnmarshalJSON(data []byte) error {
return protojson.Unmarshal(data, r)
}

View file

@ -0,0 +1,77 @@
package databroker_test
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func TestRecords(t *testing.T) {
tr := func(id, typ, val string) *databroker.Record {
return &databroker.Record{
Id: id,
Type: typ,
Data: protoutil.NewAnyString(val),
}
}
initial := make(databroker.RecordSetBundle)
initial.Add(tr("1", "a", "a-1"))
initial.Add(tr("2", "a", "a-2"))
initial.Add(tr("1", "b", "b-1"))
// test record types
assert.ElementsMatch(t, []string{"a", "b"}, initial.RecordTypes())
// test added, deleted and modified
updated := make(databroker.RecordSetBundle)
updated.Add(tr("1", "a", "a-1-1"))
updated.Add(tr("3", "a", "a-3"))
updated.Add(tr("1", "b", "b-1"))
updated.Add(tr("2", "b", "b-2"))
updated.Add(tr("1", "c", "c-1"))
assert.ElementsMatch(t, []string{"a", "b", "c"}, updated.RecordTypes())
equalJSON := func(a, b databroker.RecordSetBundle) {
t.Helper()
var txt [2]string
for i, x := range [2]databroker.RecordSetBundle{a, b} {
data, err := json.Marshal(x)
assert.NoError(t, err)
txt[i] = string(data)
}
assert.JSONEq(t, txt[0], txt[1])
}
added := initial.GetAdded(updated)
equalJSON(added, databroker.RecordSetBundle{
"a": databroker.RecordSet{
"3": tr("3", "a", "a-3"),
},
"b": databroker.RecordSet{
"2": tr("2", "b", "b-2"),
},
"c": databroker.RecordSet{
"1": tr("1", "c", "c-1"),
},
})
removed := initial.GetRemoved(updated)
equalJSON(removed, databroker.RecordSetBundle{
"a": databroker.RecordSet{
"2": tr("2", "a", "a-2"),
},
})
modified := initial.GetModified(updated)
equalJSON(modified, databroker.RecordSetBundle{
"a": databroker.RecordSet{
"1": tr("1", "a", "a-1-1"),
},
})
}