mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-07 11:58:12 +02:00
Merge branch 'main' into wasaga/pomerium-disable-validation
This commit is contained in:
commit
d7611b1331
4 changed files with 452 additions and 0 deletions
66
pkg/grpc/databroker/changeset.go
Normal file
66
pkg/grpc/databroker/changeset.go
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
package databroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetChangeSet returns list of changes between the current and target record sets,
|
||||||
|
// that may be applied to the databroker to bring it to the target state.
|
||||||
|
func GetChangeSet(current, target RecordSetBundle) []*Record {
|
||||||
|
cs := &changeSet{now: timestamppb.Now()}
|
||||||
|
|
||||||
|
for _, rec := range current.GetRemoved(target).Flatten() {
|
||||||
|
cs.Remove(rec.GetType(), rec.GetId())
|
||||||
|
}
|
||||||
|
for _, rec := range current.GetModified(target).Flatten() {
|
||||||
|
cs.Upsert(rec)
|
||||||
|
}
|
||||||
|
for _, rec := range current.GetAdded(target).Flatten() {
|
||||||
|
cs.Upsert(rec)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cs.updates
|
||||||
|
}
|
||||||
|
|
||||||
|
// changeSet is a set of databroker changes.
|
||||||
|
type changeSet struct {
|
||||||
|
now *timestamppb.Timestamp
|
||||||
|
updates []*Record
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove adds a record to the change set.
|
||||||
|
func (cs *changeSet) Remove(typ string, id string) {
|
||||||
|
cs.updates = append(cs.updates, &Record{
|
||||||
|
Type: typ,
|
||||||
|
Id: id,
|
||||||
|
DeletedAt: cs.now,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upsert adds a record to the change set.
|
||||||
|
func (cs *changeSet) Upsert(record *Record) {
|
||||||
|
cs.updates = append(cs.updates, &Record{
|
||||||
|
Type: record.Type,
|
||||||
|
Id: record.Id,
|
||||||
|
Data: record.Data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutMulti puts the records into the databroker in batches.
|
||||||
|
func PutMulti(ctx context.Context, client DataBrokerServiceClient, records ...*Record) error {
|
||||||
|
if len(records) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
updates := OptimumPutRequestsFromRecords(records)
|
||||||
|
for _, req := range updates {
|
||||||
|
_, err := client.Put(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("put databroker record: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
169
pkg/grpc/databroker/reconciler.go
Normal file
169
pkg/grpc/databroker/reconciler.go
Normal file
|
@ -0,0 +1,169 @@
|
||||||
|
package databroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Reconciler reconciles the target and current record sets with the databroker.
|
||||||
|
type Reconciler struct {
|
||||||
|
reconcilerConfig
|
||||||
|
name string
|
||||||
|
client DataBrokerServiceClient
|
||||||
|
currentStateBuilder StateBuilderFn
|
||||||
|
targetStateBuilder StateBuilderFn
|
||||||
|
setCurrentState func([]*Record)
|
||||||
|
trigger chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type reconcilerConfig struct {
|
||||||
|
interval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReconcilerOption is an option for a reconciler.
|
||||||
|
type ReconcilerOption func(*reconcilerConfig)
|
||||||
|
|
||||||
|
// WithInterval sets the interval for the reconciler.
|
||||||
|
func WithInterval(interval time.Duration) ReconcilerOption {
|
||||||
|
return func(c *reconcilerConfig) {
|
||||||
|
c.interval = interval
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getReconcilerConfig(options ...ReconcilerOption) reconcilerConfig {
|
||||||
|
options = append([]ReconcilerOption{
|
||||||
|
WithInterval(time.Minute),
|
||||||
|
}, options...)
|
||||||
|
var c reconcilerConfig
|
||||||
|
for _, option := range options {
|
||||||
|
option(&c)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// StateBuilderFn is a function that builds a record set bundle
|
||||||
|
type StateBuilderFn func(ctx context.Context) (RecordSetBundle, error)
|
||||||
|
|
||||||
|
// NewReconciler creates a new reconciler
|
||||||
|
func NewReconciler(
|
||||||
|
// name must be unique across pomerium ecosystem
|
||||||
|
name string,
|
||||||
|
client DataBrokerServiceClient,
|
||||||
|
currentStateBuilder StateBuilderFn,
|
||||||
|
targetStateBuilder StateBuilderFn,
|
||||||
|
setCurrentState func([]*Record),
|
||||||
|
opts ...ReconcilerOption,
|
||||||
|
) *Reconciler {
|
||||||
|
return &Reconciler{
|
||||||
|
name: fmt.Sprintf("%s-reconciler", name),
|
||||||
|
reconcilerConfig: getReconcilerConfig(opts...),
|
||||||
|
trigger: make(chan struct{}, 1),
|
||||||
|
client: client,
|
||||||
|
currentStateBuilder: currentStateBuilder,
|
||||||
|
targetStateBuilder: targetStateBuilder,
|
||||||
|
setCurrentState: setCurrentState,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TriggerSync triggers a sync
|
||||||
|
func (r *Reconciler) TriggerSync() {
|
||||||
|
select {
|
||||||
|
case r.trigger <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run runs the reconciler
|
||||||
|
func (r *Reconciler) Run(ctx context.Context) error {
|
||||||
|
leaser := NewLeaser(r.name, r.interval, r)
|
||||||
|
return leaser.Run(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDataBrokerServiceClient implements the LeaseHandler interface.
|
||||||
|
func (r *Reconciler) GetDataBrokerServiceClient() DataBrokerServiceClient {
|
||||||
|
return r.client
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunLeased implements the LeaseHandler interface.
|
||||||
|
func (r *Reconciler) RunLeased(ctx context.Context) error {
|
||||||
|
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||||
|
return c.Str("service", r.name)
|
||||||
|
})
|
||||||
|
return r.reconcileLoop(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Reconciler) reconcileLoop(ctx context.Context) error {
|
||||||
|
for {
|
||||||
|
err := r.reconcile(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Ctx(ctx).Error().Err(err).Msg("reconcile")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-r.trigger:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Reconciler) reconcile(ctx context.Context) error {
|
||||||
|
current, target, err := r.getRecordSets(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get config record sets: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updates := GetChangeSet(current, target)
|
||||||
|
|
||||||
|
err = r.applyChanges(ctx, updates)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply config change set: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.setCurrentState(updates)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Reconciler) getRecordSets(ctx context.Context) (
|
||||||
|
current RecordSetBundle,
|
||||||
|
target RecordSetBundle,
|
||||||
|
_ error,
|
||||||
|
) {
|
||||||
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
|
eg.Go(func() error {
|
||||||
|
var err error
|
||||||
|
current, err = r.currentStateBuilder(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("build current config record set: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
eg.Go(func() error {
|
||||||
|
var err error
|
||||||
|
target, err = r.targetStateBuilder(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("build target config record set: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
err := eg.Wait()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return current, target, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Reconciler) applyChanges(ctx context.Context, updates []*Record) error {
|
||||||
|
err := PutMulti(ctx, r.client, updates...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply databroker changes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
140
pkg/grpc/databroker/recordset.go
Normal file
140
pkg/grpc/databroker/recordset.go
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
package databroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RecordSetBundle is an index of databroker records by type
|
||||||
|
type RecordSetBundle map[string]RecordSet
|
||||||
|
|
||||||
|
// RecordSet is an index of databroker records by their id.
|
||||||
|
type RecordSet map[string]*Record
|
||||||
|
|
||||||
|
// RecordTypes returns the types of records in the bundle.
|
||||||
|
func (rsb RecordSetBundle) RecordTypes() []string {
|
||||||
|
types := make([]string, 0, len(rsb))
|
||||||
|
for typ := range rsb {
|
||||||
|
types = append(types, typ)
|
||||||
|
}
|
||||||
|
return types
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds a record to the bundle.
|
||||||
|
func (rsb RecordSetBundle) Add(record *Record) {
|
||||||
|
rs, ok := rsb[record.GetType()]
|
||||||
|
if !ok {
|
||||||
|
rs = make(RecordSet)
|
||||||
|
rsb[record.GetType()] = rs
|
||||||
|
}
|
||||||
|
rs[record.GetId()] = record
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAdded returns the records that are in other but not in rsb.
|
||||||
|
func (rsb RecordSetBundle) GetAdded(other RecordSetBundle) RecordSetBundle {
|
||||||
|
added := make(RecordSetBundle)
|
||||||
|
for otherType, otherRS := range other {
|
||||||
|
rs, ok := rsb[otherType]
|
||||||
|
if !ok {
|
||||||
|
added[otherType] = otherRS
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rss := rs.GetAdded(other[otherType])
|
||||||
|
if len(rss) > 0 {
|
||||||
|
added[otherType] = rss
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return added
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRemoved returns the records that are in rs but not in other.
|
||||||
|
func (rsb RecordSetBundle) GetRemoved(other RecordSetBundle) RecordSetBundle {
|
||||||
|
return other.GetAdded(rsb)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModified returns the records that are in both rs and other but have different data.
|
||||||
|
func (rsb RecordSetBundle) GetModified(other RecordSetBundle) RecordSetBundle {
|
||||||
|
modified := make(RecordSetBundle)
|
||||||
|
for otherType, otherRS := range other {
|
||||||
|
rs, ok := rsb[otherType]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m := rs.GetModified(otherRS)
|
||||||
|
if len(m) > 0 {
|
||||||
|
modified[otherType] = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return modified
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAdded returns the records that are in other but not in rs.
|
||||||
|
func (rs RecordSet) GetAdded(other RecordSet) RecordSet {
|
||||||
|
added := make(RecordSet)
|
||||||
|
for id, record := range other {
|
||||||
|
if _, ok := rs[id]; !ok {
|
||||||
|
added[id] = record
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return added
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRemoved returns the records that are in rs but not in other.
|
||||||
|
func (rs RecordSet) GetRemoved(other RecordSet) RecordSet {
|
||||||
|
return other.GetAdded(rs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModified returns the records that are in both rs and other but have different data.
|
||||||
|
// by comparing the protobuf bytes of the payload.
|
||||||
|
func (rs RecordSet) GetModified(other RecordSet) RecordSet {
|
||||||
|
modified := make(RecordSet)
|
||||||
|
for id, record := range other {
|
||||||
|
otherRecord, ok := rs[id]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !proto.Equal(record, otherRecord) {
|
||||||
|
modified[id] = record
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return modified
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flatten returns all records in the set.
|
||||||
|
func (rs RecordSet) Flatten() []*Record {
|
||||||
|
records := make([]*Record, 0, len(rs))
|
||||||
|
for _, record := range rs {
|
||||||
|
records = append(records, record)
|
||||||
|
}
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flatten returns all records in the bundle.
|
||||||
|
func (rsb RecordSetBundle) Flatten() []*Record {
|
||||||
|
records := make([]*Record, 0)
|
||||||
|
for _, rs := range rsb {
|
||||||
|
records = append(records, rs.Flatten()...)
|
||||||
|
}
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a record by type and id.
|
||||||
|
func (rsb RecordSetBundle) Get(typeName, id string) (record *Record, ok bool) {
|
||||||
|
rs, ok := rsb[typeName]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
record, ok = rs[id]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON marshals the record to JSON.
|
||||||
|
func (r *Record) MarshalJSON() ([]byte, error) {
|
||||||
|
return protojson.Marshal(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON unmarshals the record from JSON.
|
||||||
|
func (r *Record) UnmarshalJSON(data []byte) error {
|
||||||
|
return protojson.Unmarshal(data, r)
|
||||||
|
}
|
77
pkg/grpc/databroker/recordset_test.go
Normal file
77
pkg/grpc/databroker/recordset_test.go
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
package databroker_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRecords(t *testing.T) {
|
||||||
|
tr := func(id, typ, val string) *databroker.Record {
|
||||||
|
return &databroker.Record{
|
||||||
|
Id: id,
|
||||||
|
Type: typ,
|
||||||
|
Data: protoutil.NewAnyString(val),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
initial := make(databroker.RecordSetBundle)
|
||||||
|
initial.Add(tr("1", "a", "a-1"))
|
||||||
|
initial.Add(tr("2", "a", "a-2"))
|
||||||
|
initial.Add(tr("1", "b", "b-1"))
|
||||||
|
|
||||||
|
// test record types
|
||||||
|
assert.ElementsMatch(t, []string{"a", "b"}, initial.RecordTypes())
|
||||||
|
|
||||||
|
// test added, deleted and modified
|
||||||
|
updated := make(databroker.RecordSetBundle)
|
||||||
|
updated.Add(tr("1", "a", "a-1-1"))
|
||||||
|
updated.Add(tr("3", "a", "a-3"))
|
||||||
|
updated.Add(tr("1", "b", "b-1"))
|
||||||
|
updated.Add(tr("2", "b", "b-2"))
|
||||||
|
updated.Add(tr("1", "c", "c-1"))
|
||||||
|
|
||||||
|
assert.ElementsMatch(t, []string{"a", "b", "c"}, updated.RecordTypes())
|
||||||
|
|
||||||
|
equalJSON := func(a, b databroker.RecordSetBundle) {
|
||||||
|
t.Helper()
|
||||||
|
var txt [2]string
|
||||||
|
for i, x := range [2]databroker.RecordSetBundle{a, b} {
|
||||||
|
data, err := json.Marshal(x)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
txt[i] = string(data)
|
||||||
|
}
|
||||||
|
assert.JSONEq(t, txt[0], txt[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
added := initial.GetAdded(updated)
|
||||||
|
equalJSON(added, databroker.RecordSetBundle{
|
||||||
|
"a": databroker.RecordSet{
|
||||||
|
"3": tr("3", "a", "a-3"),
|
||||||
|
},
|
||||||
|
"b": databroker.RecordSet{
|
||||||
|
"2": tr("2", "b", "b-2"),
|
||||||
|
},
|
||||||
|
"c": databroker.RecordSet{
|
||||||
|
"1": tr("1", "c", "c-1"),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
removed := initial.GetRemoved(updated)
|
||||||
|
equalJSON(removed, databroker.RecordSetBundle{
|
||||||
|
"a": databroker.RecordSet{
|
||||||
|
"2": tr("2", "a", "a-2"),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
modified := initial.GetModified(updated)
|
||||||
|
equalJSON(modified, databroker.RecordSetBundle{
|
||||||
|
"a": databroker.RecordSet{
|
||||||
|
"1": tr("1", "a", "a-1-1"),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue