zero: resource bundle reconciler (#4445)

Denis Mishin 2023-08-17 13:19:51 -04:00 committed by Kenneth Jenkins
parent c0b1309e90
commit ea8762d706
17 changed files with 1559 additions and 0 deletions

go.mod | 1 +

@@ -69,6 +69,7 @@ require (
golang.org/x/net v0.17.0
golang.org/x/oauth2 v0.12.0
golang.org/x/sync v0.3.0
golang.org/x/time v0.3.0
google.golang.org/api v0.143.0
google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13
google.golang.org/grpc v1.58.3


@@ -0,0 +1,35 @@
package reconciler
import (
"bufio"
"errors"
"fmt"
"io"
"google.golang.org/protobuf/encoding/protodelim"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
var unmarshalOpts = protodelim.UnmarshalOptions{}
// ReadBundleRecords reads records in a protobuf wire format from src.
// Each record is expected to be a databroker.Record.
func ReadBundleRecords(src io.Reader) (RecordSetBundle[DatabrokerRecord], error) {
r := bufio.NewReader(src)
rsb := make(RecordSetBundle[DatabrokerRecord])
for {
record := new(databroker.Record)
err := unmarshalOpts.UnmarshalFrom(r, record)
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, fmt.Errorf("error reading protobuf record: %w", err)
}
rsb.Add(DatabrokerRecord{record})
}
return rsb, nil
}


@@ -0,0 +1,64 @@
package reconciler
import (
"fmt"
"io"
"os"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/encoding/protodelim"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/pkg/grpc/config"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func TestReadRecords(t *testing.T) {
dir := t.TempDir()
fd, err := os.CreateTemp(dir, "config")
require.NoError(t, err)
t.Cleanup(func() { _ = fd.Close() })
err = writeSampleRecords(fd)
require.NoError(t, err)
_, err = fd.Seek(0, io.SeekStart)
require.NoError(t, err)
records, err := ReadBundleRecords(fd)
require.NoError(t, err)
require.Len(t, records, 1)
}
func writeSampleRecords(dst io.Writer) error {
var marshalOpts = protodelim.MarshalOptions{
MarshalOptions: proto.MarshalOptions{
AllowPartial: false,
Deterministic: true,
UseCachedSize: false,
},
}
cfg := protoutil.NewAny(&config.Config{
Routes: []*config.Route{
{
From: "https://from.example.com",
To: []string{"https://to.example.com"},
},
},
})
rec := &databroker.Record{
Id: "config",
Type: cfg.GetTypeUrl(),
Data: cfg,
}
_, err := marshalOpts.MarshalTo(dst, rec)
if err != nil {
return fmt.Errorf("marshal: %w", err)
}
return nil
}


@@ -0,0 +1,124 @@
package reconciler
import (
"container/heap"
"sync"
)
type bundle struct {
id string
synced bool
priority int
}
type bundleHeap []bundle
func (h bundleHeap) Len() int { return len(h) }
func (h bundleHeap) Less(i, j int) bool {
// If one is synced and the other is not, the unsynced one comes first
if h[i].synced != h[j].synced {
return !h[i].synced
}
// Otherwise, the one with the lower priority comes first
return h[i].priority < h[j].priority
}
func (h bundleHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *bundleHeap) Push(x interface{}) {
item := x.(bundle)
*h = append(*h, item)
}
func (h *bundleHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
// BundleQueue is a priority queue of bundles to sync.
type BundleQueue struct {
sync.Mutex
bundles bundleHeap
counter int // to assign priorities based on order of insertion
}
// Set sets the bundles to be synced. This will reset the sync status of all bundles.
func (b *BundleQueue) Set(bundles []string) {
b.Lock()
defer b.Unlock()
b.bundles = make(bundleHeap, len(bundles))
b.counter = len(bundles)
for i, id := range bundles {
b.bundles[i] = bundle{
id: id,
synced: false,
priority: i,
}
}
heap.Init(&b.bundles)
}
// MarkForSync marks the bundle with the given ID for syncing.
func (b *BundleQueue) MarkForSync(id string) {
b.Lock()
defer b.Unlock()
for i, bundle := range b.bundles {
if bundle.id == id {
b.bundles[i].synced = false
heap.Fix(&b.bundles, i)
return
}
}
newBundle := bundle{id: id, synced: false, priority: b.counter}
heap.Push(&b.bundles, newBundle)
b.counter++
}
// MarkForSyncLater marks the bundle with the given ID for syncing later (after all other bundles).
func (b *BundleQueue) MarkForSyncLater(id string) {
b.Lock()
defer b.Unlock()
for i, bundle := range b.bundles {
if bundle.id != id {
continue
}
// Increase the counter first to ensure that this bundle has the highest (last) priority.
b.counter++
b.bundles[i].synced = false
b.bundles[i].priority = b.counter
heap.Fix(&b.bundles, i)
return
}
}
// GetNextBundleToSync returns the ID of the next bundle to sync and whether there is one.
func (b *BundleQueue) GetNextBundleToSync() (string, bool) {
b.Lock()
defer b.Unlock()
if len(b.bundles) == 0 {
return "", false
}
// Check the top bundle without popping
if b.bundles[0].synced {
return "", false
}
// Mark the top bundle as synced and push it to the end
id := b.bundles[0].id
b.bundles[0].synced = true
b.bundles[0].priority = b.counter
heap.Fix(&b.bundles, 0)
b.counter++
return id, true
}


@@ -0,0 +1,95 @@
package reconciler_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/zero/reconciler"
)
func TestQueueSet(t *testing.T) {
t.Parallel()
b := &reconciler.BundleQueue{}
b.Set([]string{"bundle1", "bundle2"})
id1, ok1 := b.GetNextBundleToSync()
id2, ok2 := b.GetNextBundleToSync()
assert.True(t, ok1, "Expected bundle1 to be set")
assert.Equal(t, "bundle1", id1)
assert.True(t, ok2, "Expected bundle2 to be set")
assert.Equal(t, "bundle2", id2)
id3, ok3 := b.GetNextBundleToSync()
assert.False(t, ok3, "Expected no more bundles to sync")
assert.Empty(t, id3)
}
func TestQueueMarkForSync(t *testing.T) {
t.Parallel()
b := &reconciler.BundleQueue{}
b.Set([]string{"bundle1", "bundle2"})
b.MarkForSync("bundle2")
id1, ok1 := b.GetNextBundleToSync()
assert.True(t, ok1, "Expected bundle1 to be marked for sync")
assert.Equal(t, "bundle1", id1)
b.MarkForSync("bundle3")
id2, ok2 := b.GetNextBundleToSync()
id3, ok3 := b.GetNextBundleToSync()
assert.True(t, ok2, "Expected bundle2 to be marked for sync")
assert.Equal(t, "bundle2", id2)
assert.True(t, ok3, "Expected bundle3 to be marked for sync")
assert.Equal(t, "bundle3", id3)
}
func TestQueueMarkForSyncLater(t *testing.T) {
t.Parallel()
b := &reconciler.BundleQueue{}
b.Set([]string{"bundle1", "bundle2", "bundle3"})
id1, ok1 := b.GetNextBundleToSync()
b.MarkForSyncLater("bundle1")
id2, ok2 := b.GetNextBundleToSync()
id3, ok3 := b.GetNextBundleToSync()
id4, ok4 := b.GetNextBundleToSync()
id5, ok5 := b.GetNextBundleToSync()
assert.True(t, ok1, "Expected bundle1 to be marked for sync")
assert.Equal(t, "bundle1", id1)
assert.True(t, ok2, "Expected bundle2 to be marked for sync")
assert.Equal(t, "bundle2", id2)
assert.True(t, ok3, "Expected bundle3 to be marked for sync")
assert.Equal(t, "bundle3", id3)
assert.True(t, ok4, "Expected bundle1 to be marked for sync")
assert.Equal(t, "bundle1", id4)
assert.False(t, ok5, "Expected no more bundles to sync")
assert.Empty(t, id5)
}
func TestQueueGetNextBundleToSync(t *testing.T) {
t.Parallel()
b := &reconciler.BundleQueue{}
b.Set([]string{"bundle1", "bundle2"})
id1, ok1 := b.GetNextBundleToSync()
id2, ok2 := b.GetNextBundleToSync()
id3, ok3 := b.GetNextBundleToSync()
assert.True(t, ok1, "Expected bundle1 to be retrieved for sync")
assert.Equal(t, "bundle1", id1)
assert.True(t, ok2, "Expected bundle2 to be retrieved for sync")
assert.Equal(t, "bundle2", id2)
require.False(t, ok3, "Expected no more bundles to sync")
assert.Empty(t, id3)
}


@@ -0,0 +1,110 @@
// Package reconciler syncs the state of resource bundles between the cloud and the databroker.
package reconciler
import (
"net/http"
"os"
"time"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
sdk "github.com/pomerium/zero-sdk"
)
// reconcilerConfig contains the configuration for the resource bundles reconciler.
type reconcilerConfig struct {
api *sdk.API
databrokerClient databroker.DataBrokerServiceClient
databrokerRPS int
tmpDir string
httpClient *http.Client
checkForUpdateIntervalWhenDisconnected time.Duration
checkForUpdateIntervalWhenConnected time.Duration
syncBackoffMaxInterval time.Duration
}
// Option configures the resource bundles reconciler
type Option func(*reconcilerConfig)
// WithTemporaryDirectory configures the resource bundles client to use a temporary directory for
// downloading files.
func WithTemporaryDirectory(path string) Option {
return func(cfg *reconcilerConfig) {
cfg.tmpDir = path
}
}
// WithAPI configures the cluster api client.
func WithAPI(client *sdk.API) Option {
return func(cfg *reconcilerConfig) {
cfg.api = client
}
}
// WithDataBrokerClient configures the databroker client.
func WithDataBrokerClient(client databroker.DataBrokerServiceClient) Option {
return func(cfg *reconcilerConfig) {
cfg.databrokerClient = client
}
}
// WithDownloadHTTPClient configures the http client used for downloading files.
func WithDownloadHTTPClient(client *http.Client) Option {
return func(cfg *reconcilerConfig) {
cfg.httpClient = client
}
}
// WithDatabrokerRPSLimit configures the maximum number of requests per second to the databroker.
func WithDatabrokerRPSLimit(rps int) Option {
return func(cfg *reconcilerConfig) {
cfg.databrokerRPS = rps
}
}
// WithCheckForUpdateIntervalWhenDisconnected configures the interval at which the reconciler will check
// for updates when disconnected from the cloud.
func WithCheckForUpdateIntervalWhenDisconnected(interval time.Duration) Option {
return func(cfg *reconcilerConfig) {
cfg.checkForUpdateIntervalWhenDisconnected = interval
}
}
// WithCheckForUpdateIntervalWhenConnected configures the interval at which the reconciler will check
// for updates when connected to the cloud.
func WithCheckForUpdateIntervalWhenConnected(interval time.Duration) Option {
return func(cfg *reconcilerConfig) {
cfg.checkForUpdateIntervalWhenConnected = interval
}
}
// WithSyncBackoffMaxInterval configures the maximum interval between sync attempts.
func WithSyncBackoffMaxInterval(interval time.Duration) Option {
return func(cfg *reconcilerConfig) {
cfg.syncBackoffMaxInterval = interval
}
}
func newConfig(opts ...Option) *reconcilerConfig {
cfg := &reconcilerConfig{}
for _, opt := range []Option{
WithTemporaryDirectory(os.TempDir()),
WithDownloadHTTPClient(http.DefaultClient),
WithDatabrokerRPSLimit(1_000),
WithCheckForUpdateIntervalWhenDisconnected(time.Minute * 5),
WithCheckForUpdateIntervalWhenConnected(time.Hour),
WithSyncBackoffMaxInterval(time.Minute),
} {
opt(cfg)
}
for _, opt := range opts {
opt(cfg)
}
return cfg
}
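
The options above follow the functional-options pattern: newConfig applies the defaults first and then the caller's options on top, so anything not set explicitly falls back to the values listed in newConfig. A minimal, hypothetical usage sketch of Run (defined further down in this commit) from elsewhere inside the pomerium module (the package is internal), assuming an already constructed cluster API client and databroker client:

package example

import (
    "context"
    "time"

    "github.com/pomerium/pomerium/internal/zero/reconciler"
    "github.com/pomerium/pomerium/pkg/grpc/databroker"
    sdk "github.com/pomerium/zero-sdk"
)

// runReconciler is a hypothetical caller, not part of this commit.
func runReconciler(ctx context.Context, api *sdk.API, client databroker.DataBrokerServiceClient) error {
    return reconciler.Run(ctx,
        reconciler.WithAPI(api),
        reconciler.WithDataBrokerClient(client),
        // optional override; the defaults are 5 minutes when disconnected and 1 hour when connected
        reconciler.WithCheckForUpdateIntervalWhenConnected(30*time.Minute),
    )
}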


@@ -0,0 +1,81 @@
package reconciler
import (
"context"
"errors"
"fmt"
"io"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// DatabrokerRecord is a wrapper around a databroker record.
type DatabrokerRecord struct {
V *databroker.Record
}
var _ Record[DatabrokerRecord] = DatabrokerRecord{}
// GetID returns the databroker record's ID.
func (r DatabrokerRecord) GetID() string {
return r.V.GetId()
}
// GetType returns the databroker record's type.
func (r DatabrokerRecord) GetType() string {
return r.V.GetType()
}
// Equal returns true if the databroker records are equal.
func (r DatabrokerRecord) Equal(other DatabrokerRecord) bool {
return r.V.Type == other.V.Type &&
r.V.Id == other.V.Id &&
proto.Equal(r.V.Data, other.V.Data)
}
// GetDatabrokerRecords gets all databroker records of the given types.
func GetDatabrokerRecords(
ctx context.Context,
client databroker.DataBrokerServiceClient,
types []string,
) (RecordSetBundle[DatabrokerRecord], error) {
rsb := make(RecordSetBundle[DatabrokerRecord])
for _, typ := range types {
recs, err := getDatabrokerRecords(ctx, client, typ)
if err != nil {
return nil, fmt.Errorf("get databroker records for type %s: %w", typ, err)
}
rsb[typ] = recs
}
return rsb, nil
}
func getDatabrokerRecords(
ctx context.Context,
client databroker.DataBrokerServiceClient,
typ string,
) (RecordSet[DatabrokerRecord], error) {
stream, err := client.SyncLatest(ctx, &databroker.SyncLatestRequest{Type: typ})
if err != nil {
return nil, fmt.Errorf("sync latest databroker: %w", err)
}
recordSet := make(RecordSet[DatabrokerRecord])
for {
res, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return nil, fmt.Errorf("receive databroker record: %w", err)
}
if record := res.GetRecord(); record != nil {
recordSet[record.GetId()] = DatabrokerRecord{record}
}
}
return recordSet, nil
}


@@ -0,0 +1,53 @@
package reconciler
import (
"context"
"fmt"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// DatabrokerChangeSet is a set of databroker changes.
type DatabrokerChangeSet struct {
now *timestamppb.Timestamp
updates []*databroker.Record
}
// NewDatabrokerChangeSet creates a new databroker change set.
func NewDatabrokerChangeSet() *DatabrokerChangeSet {
return &DatabrokerChangeSet{
now: timestamppb.Now(),
}
}
// Remove adds a record deletion to the change set, marking the record as deleted at the change set's timestamp.
func (cs *DatabrokerChangeSet) Remove(typ string, id string) {
cs.updates = append(cs.updates, &databroker.Record{
Type: typ,
Id: id,
DeletedAt: cs.now,
})
}
// Upsert adds a record to the change set.
func (cs *DatabrokerChangeSet) Upsert(record *databroker.Record) {
cs.updates = append(cs.updates, &databroker.Record{
Type: record.Type,
Id: record.Id,
Data: record.Data,
})
}
// ApplyChanges applies the changes to the databroker.
func ApplyChanges(ctx context.Context, client databroker.DataBrokerServiceClient, changes *DatabrokerChangeSet) error {
updates := databroker.OptimumPutRequestsFromRecords(changes.updates)
for _, req := range updates {
_, err := client.Put(ctx, req)
if err != nil {
return fmt.Errorf("put databroker record: %w", err)
}
}
return nil
}
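
A quick sketch of how the change set is meant to be used (a hypothetical caller, using protoutil.NewAnyString from this repository for the payload): Upsert stages new or changed records, Remove stages a deletion by writing a record whose DeletedAt is the change set's creation timestamp, and ApplyChanges batches everything into databroker Put requests.

package example

import (
    "context"

    "github.com/pomerium/pomerium/internal/zero/reconciler"
    "github.com/pomerium/pomerium/pkg/grpc/databroker"
    "github.com/pomerium/pomerium/pkg/protoutil"
)

// applyExample is a hypothetical caller, not part of this commit.
func applyExample(ctx context.Context, client databroker.DataBrokerServiceClient) error {
    cs := reconciler.NewDatabrokerChangeSet()
    cs.Upsert(&databroker.Record{
        Type: "example.com/Widget", // hypothetical record type
        Id:   "widget-1",
        Data: protoutil.NewAnyString("hello"),
    })
    cs.Remove("example.com/Widget", "widget-2") // tombstone for a record that should go away
    return reconciler.ApplyChanges(ctx, client, cs)
}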


@@ -0,0 +1,168 @@
package reconciler
import (
"context"
"errors"
"fmt"
"github.com/hashicorp/go-multierror"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
zero_sdk "github.com/pomerium/zero-sdk"
)
// BundleCacheEntry is a cache entry for a bundle
// that is kept in the databroker to avoid downloading
// the same bundle multiple times.
//
// By using the ETag and LastModified headers, we do not need to
// keep caches of the bundles themselves, which can be large.
//
// This also works across multiple instances, because the databroker
// database serves as a shared cache.
type BundleCacheEntry struct {
zero_sdk.DownloadConditional
RecordTypes []string
}
const (
bundleCacheEntryRecordType = "pomerium.io/BundleCacheEntry"
)
var (
// ErrBundleCacheEntryNotFound is returned when a bundle cache entry is not found
ErrBundleCacheEntryNotFound = errors.New("bundle cache entry not found")
)
// GetBundleCacheEntry gets a bundle cache entry from the databroker
func (c *service) GetBundleCacheEntry(ctx context.Context, id string) (*BundleCacheEntry, error) {
record, err := c.config.databrokerClient.Get(ctx, &databroker.GetRequest{
Type: bundleCacheEntryRecordType,
Id: id,
})
if err != nil && status.Code(err) == codes.NotFound {
return nil, ErrBundleCacheEntryNotFound
} else if err != nil {
return nil, fmt.Errorf("get bundle cache entry: %w", err)
}
var dst BundleCacheEntry
data := record.GetRecord().GetData()
err = dst.FromAny(data)
if err != nil {
log.Ctx(ctx).Error().Err(err).
Str("bundle-id", id).
Str("data", protojson.Format(data)).
Msg("could not unmarshal bundle cache entry")
// allow the entry to be overwritten by the update process
return nil, ErrBundleCacheEntryNotFound
}
return &dst, nil
}
// SetBundleCacheEntry sets a bundle cache entry in the databroker
func (c *service) SetBundleCacheEntry(ctx context.Context, id string, src BundleCacheEntry) error {
val, err := src.ToAny()
if err != nil {
return fmt.Errorf("marshal bundle cache entry: %w", err)
}
_, err = c.config.databrokerClient.Put(ctx, &databroker.PutRequest{
Records: []*databroker.Record{
{
Type: bundleCacheEntryRecordType,
Id: id,
Data: val,
},
},
})
if err != nil {
return fmt.Errorf("set bundle cache entry: %w", err)
}
return nil
}
// ToAny marshals a BundleCacheEntry into an anypb.Any
func (r *BundleCacheEntry) ToAny() (*anypb.Any, error) {
err := r.Validate()
if err != nil {
return nil, fmt.Errorf("validate: %w", err)
}
types := make([]*structpb.Value, 0, len(r.RecordTypes))
for _, t := range r.RecordTypes {
types = append(types, structpb.NewStringValue(t))
}
return protoutil.NewAny(&structpb.Struct{
Fields: map[string]*structpb.Value{
"etag": structpb.NewStringValue(r.ETag),
"last_modified": structpb.NewStringValue(r.LastModified),
"record_types": structpb.NewListValue(&structpb.ListValue{Values: types}),
},
}), nil
}
// FromAny unmarshals an anypb.Any into a BundleCacheEntry
func (r *BundleCacheEntry) FromAny(any *anypb.Any) error {
var s structpb.Struct
err := any.UnmarshalTo(&s)
if err != nil {
return fmt.Errorf("unmarshal struct: %w", err)
}
r.ETag = s.GetFields()["etag"].GetStringValue()
r.LastModified = s.GetFields()["last_modified"].GetStringValue()
for _, v := range s.GetFields()["record_types"].GetListValue().GetValues() {
r.RecordTypes = append(r.RecordTypes, v.GetStringValue())
}
err = r.Validate()
if err != nil {
return fmt.Errorf("validate: %w", err)
}
return nil
}
// Validate validates a BundleCacheEntry
func (r *BundleCacheEntry) Validate() error {
var errs *multierror.Error
if len(r.RecordTypes) == 0 {
errs = multierror.Append(errs, errors.New("record_types is required"))
}
if err := r.DownloadConditional.Validate(); err != nil {
errs = multierror.Append(errs, err)
}
return errs.ErrorOrNil()
}
// GetDownloadConditional returns conditional download information
func (r *BundleCacheEntry) GetDownloadConditional() *zero_sdk.DownloadConditional {
if r == nil {
return nil
}
cond := r.DownloadConditional
return &cond
}
// GetRecordTypes returns the record types
func (r *BundleCacheEntry) GetRecordTypes() []string {
if r == nil {
return nil
}
return r.RecordTypes
}
// Equals returns true if the two cache entries are equal
func (r *BundleCacheEntry) Equals(other *BundleCacheEntry) bool {
return r != nil && other != nil &&
r.ETag == other.ETag && r.LastModified == other.LastModified
}
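
The cache entry only stores the (ETag, Last-Modified) pair; the actual short-circuiting presumably happens inside the zero-sdk download call via standard HTTP conditional requests. As a generic illustration of that mechanism (standard library only, not the SDK's implementation):

package example

import (
    "context"
    "io"
    "net/http"
)

// conditionalGet shows the general ETag / Last-Modified handshake: the previously
// stored values are echoed back as If-None-Match / If-Modified-Since, and a
// 304 Not Modified response means the cached state is still current and the
// body does not need to be downloaded again. When notModified is false, the
// caller is responsible for closing body.
func conditionalGet(ctx context.Context, client *http.Client, url, etag, lastModified string) (notModified bool, body io.ReadCloser, err error) {
    req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
    if err != nil {
        return false, nil, err
    }
    if etag != "" {
        req.Header.Set("If-None-Match", etag)
    }
    if lastModified != "" {
        req.Header.Set("If-Modified-Since", lastModified)
    }
    resp, err := client.Do(req)
    if err != nil {
        return false, nil, err
    }
    if resp.StatusCode == http.StatusNotModified {
        resp.Body.Close()
        return true, nil, nil
    }
    return false, resp.Body, nil
}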


@@ -0,0 +1,29 @@
package reconciler_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/zero/reconciler"
zero_sdk "github.com/pomerium/zero-sdk"
)
func TestCacheEntryProto(t *testing.T) {
t.Parallel()
original := reconciler.BundleCacheEntry{
DownloadConditional: zero_sdk.DownloadConditional{
ETag: "etag value",
LastModified: "2009-02-13 18:31:30 -0500 EST",
},
RecordTypes: []string{"one", "two"},
}
originalProto, err := original.ToAny()
require.NoError(t, err)
var unmarshaled reconciler.BundleCacheEntry
err = unmarshaled.FromAny(originalProto)
require.NoError(t, err)
assert.True(t, original.Equals(&unmarshaled))
}


@@ -0,0 +1,34 @@
package reconciler
import (
"context"
"fmt"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// Reconcile reconciles the target and current record sets with the databroker.
func Reconcile(
ctx context.Context,
client databroker.DataBrokerServiceClient,
target, current RecordSetBundle[DatabrokerRecord],
) error {
updates := NewDatabrokerChangeSet()
for _, rec := range current.GetRemoved(target).Flatten() {
updates.Remove(rec.GetType(), rec.GetID())
}
for _, rec := range current.GetModified(target).Flatten() {
updates.Upsert(rec.V)
}
for _, rec := range current.GetAdded(target).Flatten() {
updates.Upsert(rec.V)
}
err := ApplyChanges(ctx, client, updates)
if err != nil {
return fmt.Errorf("apply databroker changes: %w", err)
}
return nil
}


@@ -0,0 +1,196 @@
package reconciler_test
import (
"context"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"
"google.golang.org/protobuf/types/known/wrapperspb"
databroker_int "github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/internal/zero/reconciler"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func newDatabroker(t *testing.T) (context.Context, databroker.DataBrokerServiceClient) {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
gs := grpc.NewServer()
srv := databroker_int.New()
databroker.RegisterDataBrokerServiceServer(gs, srv)
lis := bufconn.Listen(1)
t.Cleanup(func() {
lis.Close()
gs.Stop()
})
go func() { _ = gs.Serve(lis) }()
conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(func(context.Context, string) (conn net.Conn, e error) {
return lis.Dial()
}), grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
t.Cleanup(func() { _ = conn.Close() })
return ctx, databroker.NewDataBrokerServiceClient(conn)
}
func newRecordBundle(records []testRecord) reconciler.RecordSetBundle[reconciler.DatabrokerRecord] {
bundle := make(reconciler.RecordSetBundle[reconciler.DatabrokerRecord])
for _, r := range records {
bundle.Add(newRecord(r))
}
return bundle
}
func newRecord(r testRecord) reconciler.DatabrokerRecord {
return reconciler.DatabrokerRecord{
V: &databroker.Record{
Type: r.Type,
Id: r.ID,
Data: protoutil.NewAnyString(r.Val),
}}
}
func assertBundle(t *testing.T, want []testRecord, got reconciler.RecordSetBundle[reconciler.DatabrokerRecord]) {
t.Helper()
for _, wantRecord := range want {
gotRecord, ok := got.Get(wantRecord.Type, wantRecord.ID)
if assert.True(t, ok, "record %s/%s not found", wantRecord.Type, wantRecord.ID) {
assertRecord(t, wantRecord, gotRecord)
}
}
assert.Len(t, got.Flatten(), len(want))
}
func assertRecord(t *testing.T, want testRecord, got reconciler.DatabrokerRecord) {
t.Helper()
var val wrapperspb.StringValue
err := got.V.Data.UnmarshalTo(&val)
require.NoError(t, err)
assert.Equal(t, want.Type, got.V.Type)
assert.Equal(t, want.ID, got.V.Id)
assert.Equal(t, want.Val, val.Value)
}
func TestHelpers(t *testing.T) {
want := []testRecord{
{"type1", "id1", "value1"},
{"type1", "id2", "value2"},
}
bundle := newRecordBundle(want)
assertBundle(t, want, bundle)
}
func wantRemoved(want, current []string) []string {
wantM := make(map[string]struct{}, len(want))
for _, w := range want {
wantM[w] = struct{}{}
}
var toRemove []string
for _, c := range current {
if _, ok := wantM[c]; !ok {
toRemove = append(toRemove, c)
}
}
return toRemove
}
func reconcile(
ctx context.Context,
t *testing.T,
client databroker.DataBrokerServiceClient,
want []testRecord,
current reconciler.RecordSetBundle[reconciler.DatabrokerRecord],
) reconciler.RecordSetBundle[reconciler.DatabrokerRecord] {
t.Helper()
wantBundle := newRecordBundle(want)
err := reconciler.Reconcile(ctx, client, wantBundle, current)
require.NoError(t, err)
got, err := reconciler.GetDatabrokerRecords(ctx, client, wantBundle.RecordTypes())
require.NoError(t, err)
assertBundle(t, want, got)
res, err := reconciler.GetDatabrokerRecords(ctx, client, wantRemoved(wantBundle.RecordTypes(), current.RecordTypes()))
require.NoError(t, err)
assert.Empty(t, res.Flatten())
return got
}
func TestReconcile(t *testing.T) {
t.Parallel()
ctx, client := newDatabroker(t)
err := reconciler.Reconcile(ctx, client, nil, nil)
require.NoError(t, err)
var current reconciler.RecordSetBundle[reconciler.DatabrokerRecord]
for _, tc := range []struct {
name string
want []testRecord
}{
{"empty", nil},
{"initial", []testRecord{
{"type1", "id1", "value1"},
{"type1", "id2", "value2"},
}},
{"add one", []testRecord{
{"type1", "id1", "value1"},
{"type1", "id2", "value2"},
{"type1", "id3", "value3"},
}},
{"update one", []testRecord{
{"type1", "id1", "value1"},
{"type1", "id2", "value2-updated"},
{"type1", "id3", "value3"},
}},
{"delete one", []testRecord{
{"type1", "id1", "value1"},
{"type1", "id3", "value3"},
}},
{"delete all", nil},
{"multiple types", []testRecord{
{"type1", "id1", "value1"},
{"type1", "id2", "value2"},
{"type2", "id1", "value1"},
{"type2", "id2", "value2"},
}},
{"multiple types update", []testRecord{
{"type1", "id1", "value1"},
{"type1", "id2", "value2-updated"},
{"type2", "id1", "value1"},
{"type2", "id2", "value2-updated"},
}},
{"multiple types delete", []testRecord{
{"type1", "id1", "value1"},
{"type2", "id1", "value1"},
}},
{"multiple types delete one type, add one value", []testRecord{
{"type1", "id1", "value1"},
{"type1", "id4", "value4"},
}},
} {
t.Run(tc.name, func(t *testing.T) {
current = reconcile(ctx, t, client, tc.want, current)
})
}
}


@@ -0,0 +1,132 @@
package reconciler
// RecordSetBundle is an index of databroker records by type
type RecordSetBundle[T Record[T]] map[string]RecordSet[T]
// RecordSet is an index of databroker records by their id.
type RecordSet[T Record[T]] map[string]T
// Record is a record that can be indexed by type and ID and compared for equality.
type Record[T any] interface {
GetID() string
GetType() string
Equal(other T) bool
}
// RecordTypes returns the types of records in the bundle.
func (rsb RecordSetBundle[T]) RecordTypes() []string {
types := make([]string, 0, len(rsb))
for typ := range rsb {
types = append(types, typ)
}
return types
}
// Add adds a record to the bundle.
func (rsb RecordSetBundle[T]) Add(record T) {
rs, ok := rsb[record.GetType()]
if !ok {
rs = make(RecordSet[T])
rsb[record.GetType()] = rs
}
rs[record.GetID()] = record
}
// GetAdded returns the records that are in other but not in rsb.
func (rsb RecordSetBundle[T]) GetAdded(other RecordSetBundle[T]) RecordSetBundle[T] {
added := make(RecordSetBundle[T])
for otherType, otherRS := range other {
rs, ok := rsb[otherType]
if !ok {
added[otherType] = otherRS
continue
}
rss := rs.GetAdded(other[otherType])
if len(rss) > 0 {
added[otherType] = rss
}
}
return added
}
// GetRemoved returns the records that are in rsb but not in other.
func (rsb RecordSetBundle[T]) GetRemoved(other RecordSetBundle[T]) RecordSetBundle[T] {
return other.GetAdded(rsb)
}
// GetModified returns the records that are in both rsb and other but have different data.
func (rsb RecordSetBundle[T]) GetModified(other RecordSetBundle[T]) RecordSetBundle[T] {
modified := make(RecordSetBundle[T])
for otherType, otherRS := range other {
rs, ok := rsb[otherType]
if !ok {
continue
}
m := rs.GetModified(otherRS)
if len(m) > 0 {
modified[otherType] = m
}
}
return modified
}
// GetAdded returns the records that are in other but not in rs.
func (rs RecordSet[T]) GetAdded(other RecordSet[T]) RecordSet[T] {
added := make(RecordSet[T])
for id, record := range other {
if _, ok := rs[id]; !ok {
added[id] = record
}
}
return added
}
// GetRemoved returns the records that are in rs but not in other.
func (rs RecordSet[T]) GetRemoved(other RecordSet[T]) RecordSet[T] {
return other.GetAdded(rs)
}
// GetModified returns the records that are in both rs and other but have different data,
// as determined by the record's Equal method.
func (rs RecordSet[T]) GetModified(other RecordSet[T]) RecordSet[T] {
modified := make(RecordSet[T])
for id, record := range other {
otherRecord, ok := rs[id]
if !ok {
continue
}
if !record.Equal(otherRecord) {
modified[id] = record
}
}
return modified
}
// Flatten returns all records in the set.
func (rs RecordSet[T]) Flatten() []T {
records := make([]T, 0, len(rs))
for _, record := range rs {
records = append(records, record)
}
return records
}
// Flatten returns all records in the bundle.
func (rsb RecordSetBundle[T]) Flatten() []T {
records := make([]T, 0)
for _, rs := range rsb {
records = append(records, rs.Flatten()...)
}
return records
}
// Get returns a record by type and id.
func (rsb RecordSetBundle[T]) Get(typeName, id string) (record T, ok bool) {
rs, ok := rsb[typeName]
if !ok {
return
}
record, ok = rs[id]
return
}


@@ -0,0 +1,77 @@
package reconciler_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/zero/reconciler"
)
type testRecord struct {
Type string
ID string
Val string
}
func (r testRecord) GetID() string {
return r.ID
}
func (r testRecord) GetType() string {
return r.Type
}
func (r testRecord) Equal(other testRecord) bool {
return r.ID == other.ID && r.Type == other.Type && r.Val == other.Val
}
func TestRecords(t *testing.T) {
initial := make(reconciler.RecordSetBundle[testRecord])
initial.Add(testRecord{ID: "1", Type: "a", Val: "a-1"})
initial.Add(testRecord{ID: "2", Type: "a", Val: "a-2"})
initial.Add(testRecord{ID: "1", Type: "b", Val: "b-1"})
// test record types
assert.ElementsMatch(t, []string{"a", "b"}, initial.RecordTypes())
// test added, deleted and modified
updated := make(reconciler.RecordSetBundle[testRecord])
updated.Add(testRecord{ID: "1", Type: "a", Val: "a-1-1"})
updated.Add(testRecord{ID: "3", Type: "a", Val: "a-3"})
updated.Add(testRecord{ID: "1", Type: "b", Val: "b-1"})
updated.Add(testRecord{ID: "2", Type: "b", Val: "b-2"})
updated.Add(testRecord{ID: "1", Type: "c", Val: "c-1"})
assert.ElementsMatch(t, []string{"a", "b", "c"}, updated.RecordTypes())
added := initial.GetAdded(updated)
assert.Equal(t,
reconciler.RecordSetBundle[testRecord]{
"a": reconciler.RecordSet[testRecord]{
"3": {ID: "3", Type: "a", Val: "a-3"},
},
"b": reconciler.RecordSet[testRecord]{
"2": {ID: "2", Type: "b", Val: "b-2"},
},
"c": reconciler.RecordSet[testRecord]{
"1": {ID: "1", Type: "c", Val: "c-1"},
},
}, added)
removed := initial.GetRemoved(updated)
assert.Equal(t,
reconciler.RecordSetBundle[testRecord]{
"a": reconciler.RecordSet[testRecord]{
"2": {ID: "2", Type: "a", Val: "a-2"},
},
}, removed)
modified := initial.GetModified(updated)
assert.Equal(t,
reconciler.RecordSetBundle[testRecord]{
"a": reconciler.RecordSet[testRecord]{
"1": {ID: "1", Type: "a", Val: "a-1-1"},
},
}, modified)
}


@@ -0,0 +1,89 @@
package reconciler
/*
 * This is the main control loop for the reconciler service.
 */
import (
"context"
"time"
"golang.org/x/sync/errgroup"
"golang.org/x/time/rate"
"github.com/pomerium/pomerium/internal/atomicutil"
connect_mux "github.com/pomerium/zero-sdk/connect-mux"
)
type service struct {
config *reconcilerConfig
databrokerRateLimit *rate.Limiter
bundles BundleQueue
fullSyncRequest chan struct{}
bundleSyncRequest chan struct{}
periodicUpdateInterval atomicutil.Value[time.Duration]
}
// Run creates a new bundle updater client
// that runs until the context is canceled or a fatal error occurs.
func Run(ctx context.Context, opts ...Option) error {
config := newConfig(opts...)
c := &service{
config: config,
databrokerRateLimit: rate.NewLimiter(rate.Limit(config.databrokerRPS), 1),
fullSyncRequest: make(chan struct{}, 1),
bundleSyncRequest: make(chan struct{}, 1),
}
c.periodicUpdateInterval.Store(config.checkForUpdateIntervalWhenDisconnected)
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error { return c.watchUpdates(ctx) })
eg.Go(func() error { return c.SyncLoop(ctx) })
return eg.Wait()
}
// watchUpdates subscribes to cloud events and requests syncs accordingly.
// The current flow is a simple sequential download-and-reconcile;
// it may later be optimized by splitting the download and reconciliation processes,
// as we get more resource bundles beyond the config.
func (c *service) watchUpdates(ctx context.Context) error {
return c.config.api.Watch(ctx,
connect_mux.WithOnConnected(func(ctx context.Context) {
c.triggerFullUpdate(true)
}),
connect_mux.WithOnDisconnected(func(_ context.Context) {
c.triggerFullUpdate(false)
}),
connect_mux.WithOnBundleUpdated(func(_ context.Context, key string) {
c.triggerBundleUpdate(key)
}),
)
}
func (c *service) triggerBundleUpdate(id string) {
c.periodicUpdateInterval.Store(c.config.checkForUpdateIntervalWhenConnected)
c.bundles.MarkForSync(id)
select {
case c.fullSyncRequest <- struct{}{}:
default:
}
}
func (c *service) triggerFullUpdate(connected bool) {
timeout := c.config.checkForUpdateIntervalWhenDisconnected
if connected {
timeout = c.config.checkForUpdateIntervalWhenConnected
}
c.periodicUpdateInterval.Store(timeout)
select {
case c.fullSyncRequest <- struct{}{}:
default:
}
}
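
The non-blocking send onto a channel with capacity 1 is what keeps triggers cheap: any burst of events that arrives before SyncLoop gets around to running collapses into a single pending request. A minimal standalone sketch of the idiom:

package main

import "fmt"

func main() {
    requests := make(chan struct{}, 1)

    trigger := func() {
        select {
        case requests <- struct{}{}:
        default: // a request is already pending; coalesce
        }
    }

    trigger()
    trigger()
    trigger() // three triggers, at most one pending request

    pending := 0
    for {
        select {
        case <-requests:
            pending++
        default:
            fmt.Println("pending requests:", pending) // prints 1
            return
        }
    }
}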


@@ -0,0 +1,231 @@
package reconciler
/*
 * Sync keeps bundles in sync between their cloud source and the databroker.
 *
 * A full sync fetches the list of bundles from the API, walks it, and syncs
 * each bundle in turn. It also removes any databroker records that are not
 * part of the current bundles.
 *
 * Watch events from the API trigger a sync of the affected bundle.
 */
import (
"context"
"errors"
"fmt"
"io"
"time"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/retry"
)
// SyncLoop runs the reconciliation loop, synchronizing bundles between their cloud source and the databroker.
func (c *service) SyncLoop(ctx context.Context) error {
ticker := time.NewTicker(c.periodicUpdateInterval.Load())
defer ticker.Stop()
for {
dur := c.periodicUpdateInterval.Load()
ticker.Reset(dur)
select {
case <-ctx.Done():
return ctx.Err()
case <-c.bundleSyncRequest:
log.Ctx(ctx).Info().Msg("bundle sync triggered")
err := c.syncBundles(ctx)
if err != nil {
return fmt.Errorf("reconciler: sync bundles: %w", err)
}
case <-c.fullSyncRequest:
log.Ctx(ctx).Info().Msg("full sync triggered")
err := c.syncAll(ctx)
if err != nil {
return fmt.Errorf("reconciler: sync all: %w", err)
}
case <-ticker.C:
log.Ctx(ctx).Info().Msg("periodic sync triggered")
err := c.syncAll(ctx)
if err != nil {
return fmt.Errorf("reconciler: sync all: %w", err)
}
}
}
}
func (c *service) syncAll(ctx context.Context) error {
err := c.syncBundleList(ctx)
if err != nil {
return fmt.Errorf("sync bundle list: %w", err)
}
err = c.syncBundles(ctx)
if err != nil {
return fmt.Errorf("sync bundles: %w", err)
}
return nil
}
// syncBundleList refreshes the list of bundles to sync from the cloud API.
func (c *service) syncBundleList(ctx context.Context) error {
// refresh bundle list,
// ignoring other signals while we're retrying
return retry.Retry(ctx,
"refresh bundle list", c.refreshBundleList,
retry.WithWatch("refresh bundle list", c.fullSyncRequest, nil),
retry.WithWatch("bundle update", c.bundleSyncRequest, nil),
)
}
// syncBundles retries until there are no more bundles to sync.
// It also refreshes the bundle list if a full sync request arrives while syncing.
func (c *service) syncBundles(ctx context.Context) error {
return retry.Retry(ctx,
"sync bundles", c.trySyncBundles,
retry.WithWatch("refresh bundle list", c.fullSyncRequest, c.refreshBundleList),
retry.WithWatch("bundle update", c.bundleSyncRequest, nil),
)
}
// trySyncBundles tries to sync all bundles in the queue.
// It returns nil once all bundles have been synced successfully.
func (c *service) trySyncBundles(ctx context.Context) error {
for {
id, ok := c.bundles.GetNextBundleToSync()
if !ok { // no more bundles to sync
return nil
}
err := c.syncBundle(ctx, id)
if err != nil {
c.bundles.MarkForSyncLater(id)
return fmt.Errorf("sync bundle %s: %w", id, err)
}
}
}
// syncBundle syncs a single bundle to the databroker.
// The databroker stores the last synced bundle state as an (ETag, Last-Modified) tuple,
// persisted only after all records have been successfully synced.
// That lets us ignore a bundle whose state has not changed without
// re-checking every record between the bundle and the databroker.
func (c *service) syncBundle(ctx context.Context, key string) error {
cached, err := c.GetBundleCacheEntry(ctx, key)
if err != nil && !errors.Is(err, ErrBundleCacheEntryNotFound) {
return fmt.Errorf("get bundle cache entry: %w", err)
}
// Downloading is much faster than syncing to the databroker,
// so rather than piping, we download to a temp file and then sync it to the databroker.
fd, err := c.GetTmpFile(key)
if err != nil {
return fmt.Errorf("get tmp file: %w", err)
}
defer func() {
if err := fd.Close(); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("close tmp file")
}
}()
conditional := cached.GetDownloadConditional()
log.Ctx(ctx).Debug().Str("id", key).Any("conditional", conditional).Msg("downloading bundle")
result, err := c.config.api.DownloadClusterResourceBundle(ctx, fd, key, conditional)
if err != nil {
return fmt.Errorf("download bundle: %w", err)
}
if result.NotModified {
log.Ctx(ctx).Debug().Str("bundle", key).Msg("bundle not changed")
return nil
}
log.Ctx(ctx).Debug().Str("bundle", key).
Interface("cached-entry", cached).
Interface("current-entry", result.DownloadConditional).
Msg("bundle updated")
_, err = fd.Seek(0, io.SeekStart)
if err != nil {
return fmt.Errorf("seek to start: %w", err)
}
bundleRecordTypes, err := c.syncBundleToDatabroker(ctx, fd, cached.GetRecordTypes())
if err != nil {
return fmt.Errorf("apply bundle to databroker: %w", err)
}
current := BundleCacheEntry{
DownloadConditional: *result.DownloadConditional,
RecordTypes: bundleRecordTypes,
}
log.Ctx(ctx).Info().
Str("bundle", key).
Strs("record_types", bundleRecordTypes).
Str("etag", current.ETag).
Str("last_modified", current.LastModified).
Msg("bundle synced")
err = c.SetBundleCacheEntry(ctx, key, current)
if err != nil {
return fmt.Errorf("set bundle cache entry: %w", err)
}
return nil
}
func strUnion(a, b []string) []string {
m := make(map[string]struct{}, len(a)+len(b))
for _, s := range a {
m[s] = struct{}{}
}
for _, s := range b {
m[s] = struct{}{}
}
out := make([]string, 0, len(m))
for s := range m {
out = append(out, s)
}
return out
}
func (c *service) syncBundleToDatabroker(ctx context.Context, src io.Reader, currentRecordTypes []string) ([]string, error) {
bundleRecords, err := ReadBundleRecords(src)
if err != nil {
return nil, fmt.Errorf("read bundle records: %w", err)
}
databrokerRecords, err := GetDatabrokerRecords(ctx,
c.config.databrokerClient,
strUnion(bundleRecords.RecordTypes(), currentRecordTypes),
)
if err != nil {
return nil, fmt.Errorf("get databroker records: %w", err)
}
err = Reconcile(ctx, c.config.databrokerClient, bundleRecords, databrokerRecords)
if err != nil {
return nil, fmt.Errorf("reconcile databroker records: %w", err)
}
return bundleRecords.RecordTypes(), nil
}
func (c *service) refreshBundleList(ctx context.Context) error {
resp, err := c.config.api.GetClusterResourceBundles(ctx)
if err != nil {
return fmt.Errorf("get bundles: %w", err)
}
ids := make([]string, 0, len(resp.Bundles))
for _, v := range resp.Bundles {
ids = append(ids, v.Id)
}
c.bundles.Set(ids)
return nil
}
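
One detail worth calling out is strUnion in syncBundleToDatabroker: reconciliation runs against the union of the record types in the new bundle and the record types remembered in the cache entry. If a type disappears from the bundle entirely, only the union makes its old records visible so they can be removed. A small illustration using the generic record sets from this commit and a throwaway record type (similar to the test helpers above):

package main

import (
    "fmt"

    "github.com/pomerium/pomerium/internal/zero/reconciler"
)

type rec struct{ typ, id, val string }

func (r rec) GetID() string        { return r.id }
func (r rec) GetType() string      { return r.typ }
func (r rec) Equal(other rec) bool { return r == other }

func main() {
    // what the databroker currently holds (the previous bundle had types "a" and "b")
    current := make(reconciler.RecordSetBundle[rec])
    current.Add(rec{"a", "1", "x"})
    current.Add(rec{"b", "1", "y"})

    // the new bundle only contains type "a"
    target := make(reconciler.RecordSetBundle[rec])
    target.Add(rec{"a", "1", "x"})

    // GetRemoved reports the stale "b" record only because the databroker
    // snapshot was built over both types; querying just the new bundle's
    // types would never have loaded it.
    for _, r := range current.GetRemoved(target).Flatten() {
        fmt.Println("would delete:", r.GetType(), r.GetID()) // would delete: b 1
    }
}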


@@ -0,0 +1,40 @@
package reconciler
import (
"fmt"
"io"
"os"
"github.com/hashicorp/go-multierror"
)
// ReadWriteSeekCloser is a file that can be read, written, seeked, and closed.
type ReadWriteSeekCloser interface {
io.ReadWriteSeeker
io.Closer
}
// GetTmpFile returns a temporary file for the reconciler to use.
// TODO: encrypt contents to ensure encryption at rest
func (c *service) GetTmpFile(key string) (ReadWriteSeekCloser, error) {
fd, err := os.CreateTemp(c.config.tmpDir, fmt.Sprintf("pomerium-bundle-%s", key))
if err != nil {
return nil, fmt.Errorf("create temp file: %w", err)
}
return &tmpFile{File: fd}, nil
}
type tmpFile struct {
*os.File
}
func (f *tmpFile) Close() error {
var errs *multierror.Error
if err := f.File.Close(); err != nil {
errs = multierror.Append(errs, err)
}
if err := os.Remove(f.File.Name()); err != nil {
errs = multierror.Append(errs, err)
}
return errs.ErrorOrNil()
}
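
Note that tmpFile ties removal to Close, so a single deferred Close (as in syncBundle above) both closes the handle and deletes the file from tmpDir. A short, hypothetical in-package sketch of the intended usage:

// downloadScratch is a hypothetical in-package helper, not part of this commit.
func (c *service) downloadScratch(key string) error {
    fd, err := c.GetTmpFile(key)
    if err != nil {
        return err
    }
    // Close both closes the handle and removes the file from disk.
    defer func() { _ = fd.Close() }()
    // ... write the downloaded bundle into fd, seek back to the start, read it ...
    return nil
}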