From 3b65049d2f2ee74fa7d08b9de511f036ad197a76 Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Thu, 17 Aug 2023 13:19:51 -0400 Subject: [PATCH] zero: resource bundle reconciler (#4445) --- go.mod | 1 + go.sum | 1 + internal/zero/reconciler/bundles_format.go | 35 +++ .../zero/reconciler/bundles_format_test.go | 64 +++++ internal/zero/reconciler/bundles_queue.go | 124 ++++++++++ .../zero/reconciler/bundles_queue_test.go | 95 +++++++ internal/zero/reconciler/config.go | 110 +++++++++ internal/zero/reconciler/databroker.go | 81 ++++++ .../zero/reconciler/databroker_changeset.go | 53 ++++ internal/zero/reconciler/download_cache.go | 168 +++++++++++++ .../zero/reconciler/download_cache_test.go | 29 +++ internal/zero/reconciler/reconciler.go | 34 +++ internal/zero/reconciler/reconciler_test.go | 196 +++++++++++++++ internal/zero/reconciler/records.go | 132 ++++++++++ internal/zero/reconciler/records_test.go | 77 ++++++ internal/zero/reconciler/service.go | 89 +++++++ internal/zero/reconciler/sync.go | 231 ++++++++++++++++++ internal/zero/reconciler/tmpfile.go | 40 +++ 18 files changed, 1560 insertions(+) create mode 100644 internal/zero/reconciler/bundles_format.go create mode 100644 internal/zero/reconciler/bundles_format_test.go create mode 100644 internal/zero/reconciler/bundles_queue.go create mode 100644 internal/zero/reconciler/bundles_queue_test.go create mode 100644 internal/zero/reconciler/config.go create mode 100644 internal/zero/reconciler/databroker.go create mode 100644 internal/zero/reconciler/databroker_changeset.go create mode 100644 internal/zero/reconciler/download_cache.go create mode 100644 internal/zero/reconciler/download_cache_test.go create mode 100644 internal/zero/reconciler/reconciler.go create mode 100644 internal/zero/reconciler/reconciler_test.go create mode 100644 internal/zero/reconciler/records.go create mode 100644 internal/zero/reconciler/records_test.go create mode 100644 internal/zero/reconciler/service.go create mode 100644 internal/zero/reconciler/sync.go create mode 100644 internal/zero/reconciler/tmpfile.go diff --git a/go.mod b/go.mod index a4d6f769e..2773e5ac0 100644 --- a/go.mod +++ b/go.mod @@ -72,6 +72,7 @@ require ( golang.org/x/net v0.12.0 golang.org/x/oauth2 v0.10.0 golang.org/x/sync v0.3.0 + golang.org/x/time v0.3.0 google.golang.org/api v0.134.0 google.golang.org/genproto/googleapis/rpc v0.0.0-20230720185612-659f7aaaa771 google.golang.org/grpc v1.57.0 diff --git a/go.sum b/go.sum index 5863fe447..5ba32e781 100644 --- a/go.sum +++ b/go.sum @@ -1110,6 +1110,7 @@ golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/zero/reconciler/bundles_format.go b/internal/zero/reconciler/bundles_format.go new file mode 100644 index 000000000..001a806f6 --- /dev/null +++ b/internal/zero/reconciler/bundles_format.go @@ -0,0 +1,35 @@ +package reconciler + +import ( + "bufio" + "errors" + "fmt" + "io" + + "google.golang.org/protobuf/encoding/protodelim" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +var unmarshalOpts = protodelim.UnmarshalOptions{} + +// ReadBundleRecords reads records in a protobuf wire format from src. +// Each record is expected to be a databroker.Record. +func ReadBundleRecords(src io.Reader) (RecordSetBundle[DatabrokerRecord], error) { + r := bufio.NewReader(src) + rsb := make(RecordSetBundle[DatabrokerRecord]) + for { + record := new(databroker.Record) + err := unmarshalOpts.UnmarshalFrom(r, record) + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, fmt.Errorf("error reading protobuf record: %w", err) + } + + rsb.Add(DatabrokerRecord{record}) + } + + return rsb, nil +} diff --git a/internal/zero/reconciler/bundles_format_test.go b/internal/zero/reconciler/bundles_format_test.go new file mode 100644 index 000000000..0af272a63 --- /dev/null +++ b/internal/zero/reconciler/bundles_format_test.go @@ -0,0 +1,64 @@ +package reconciler + +import ( + "fmt" + "io" + "os" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protodelim" + "google.golang.org/protobuf/proto" + + "github.com/pomerium/pomerium/pkg/grpc/config" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +func TestReadRecords(t *testing.T) { + + dir := t.TempDir() + fd, err := os.CreateTemp(dir, "config") + require.NoError(t, err) + t.Cleanup(func() { _ = fd.Close() }) + + err = writeSampleRecords(fd) + require.NoError(t, err) + + _, err = fd.Seek(0, io.SeekStart) + require.NoError(t, err) + + records, err := ReadBundleRecords(fd) + require.NoError(t, err) + require.Len(t, records, 1) +} + +func writeSampleRecords(dst io.Writer) error { + var marshalOpts = protodelim.MarshalOptions{ + MarshalOptions: proto.MarshalOptions{ + AllowPartial: false, + Deterministic: true, + UseCachedSize: false, + }, + } + + cfg := protoutil.NewAny(&config.Config{ + Routes: []*config.Route{ + { + From: "https://from.example.com", + To: []string{"https://to.example.com"}, + }, + }, + }) + rec := &databroker.Record{ + Id: "config", + Type: cfg.GetTypeUrl(), + Data: cfg, + } + _, err := marshalOpts.MarshalTo(dst, rec) + if err != nil { + return fmt.Errorf("marshal: %w", err) + } + + return nil +} diff --git a/internal/zero/reconciler/bundles_queue.go b/internal/zero/reconciler/bundles_queue.go new file mode 100644 index 000000000..8ed3f6b7b --- /dev/null +++ b/internal/zero/reconciler/bundles_queue.go @@ -0,0 +1,124 @@ +package reconciler + +import ( + "container/heap" + "sync" +) + +type bundle struct { + id string + synced bool + priority int +} + +type bundleHeap []bundle + +func (h bundleHeap) Len() int { return len(h) } +func (h bundleHeap) Less(i, j int) bool { + // If one is synced and the other is not, the unsynced one comes first + if h[i].synced != h[j].synced { + return !h[i].synced + } + // Otherwise, the one with the lower priority comes first + return h[i].priority < h[j].priority +} + +func (h bundleHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func (h *bundleHeap) Push(x interface{}) { + item := x.(bundle) + *h = append(*h, item) +} + +func (h *bundleHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} + +// BundleQueue is a priority queue of bundles to sync. +type BundleQueue struct { + sync.Mutex + bundles bundleHeap + counter int // to assign priorities based on order of insertion +} + +// Set sets the bundles to be synced. This will reset the sync status of all bundles. +func (b *BundleQueue) Set(bundles []string) { + b.Lock() + defer b.Unlock() + + b.bundles = make(bundleHeap, len(bundles)) + b.counter = len(bundles) + for i, id := range bundles { + b.bundles[i] = bundle{ + id: id, + synced: false, + priority: i, + } + } + heap.Init(&b.bundles) +} + +// MarkForSync marks the bundle with the given ID for syncing. +func (b *BundleQueue) MarkForSync(id string) { + b.Lock() + defer b.Unlock() + + for i, bundle := range b.bundles { + if bundle.id == id { + b.bundles[i].synced = false + heap.Fix(&b.bundles, i) + return + } + } + + newBundle := bundle{id: id, synced: false, priority: b.counter} + heap.Push(&b.bundles, newBundle) + b.counter++ +} + +// MarkForSyncLater marks the bundle with the given ID for syncing later (after all other bundles). +func (b *BundleQueue) MarkForSyncLater(id string) { + b.Lock() + defer b.Unlock() + + for i, bundle := range b.bundles { + if bundle.id != id { + continue + } + + // Increase the counter first to ensure that this bundle has the highest (last) priority. + b.counter++ + b.bundles[i].synced = false + b.bundles[i].priority = b.counter + heap.Fix(&b.bundles, i) + return + } +} + +// GetNextBundleToSync returns the ID of the next bundle to sync and whether there is one. +func (b *BundleQueue) GetNextBundleToSync() (string, bool) { + b.Lock() + defer b.Unlock() + + if len(b.bundles) == 0 { + return "", false + } + + // Check the top bundle without popping + if b.bundles[0].synced { + return "", false + } + + // Mark the top bundle as synced and push it to the end + id := b.bundles[0].id + b.bundles[0].synced = true + b.bundles[0].priority = b.counter + heap.Fix(&b.bundles, 0) + b.counter++ + + return id, true +} diff --git a/internal/zero/reconciler/bundles_queue_test.go b/internal/zero/reconciler/bundles_queue_test.go new file mode 100644 index 000000000..dd2190292 --- /dev/null +++ b/internal/zero/reconciler/bundles_queue_test.go @@ -0,0 +1,95 @@ +package reconciler_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/internal/zero/reconciler" +) + +func TestQueueSet(t *testing.T) { + t.Parallel() + + b := &reconciler.BundleQueue{} + b.Set([]string{"bundle1", "bundle2"}) + + id1, ok1 := b.GetNextBundleToSync() + id2, ok2 := b.GetNextBundleToSync() + + assert.True(t, ok1, "Expected bundle1 to be set") + assert.Equal(t, "bundle1", id1) + assert.True(t, ok2, "Expected bundle2 to be set") + assert.Equal(t, "bundle2", id2) + + id3, ok3 := b.GetNextBundleToSync() + assert.False(t, ok3, "Expected no more bundles to sync") + assert.Empty(t, id3) +} + +func TestQueueMarkForSync(t *testing.T) { + t.Parallel() + + b := &reconciler.BundleQueue{} + b.Set([]string{"bundle1", "bundle2"}) + + b.MarkForSync("bundle2") + id1, ok1 := b.GetNextBundleToSync() + + assert.True(t, ok1, "Expected bundle1 to be marked for sync") + assert.Equal(t, "bundle1", id1) + + b.MarkForSync("bundle3") + id2, ok2 := b.GetNextBundleToSync() + id3, ok3 := b.GetNextBundleToSync() + + assert.True(t, ok2, "Expected bundle2 to be marked for sync") + assert.Equal(t, "bundle2", id2) + assert.True(t, ok3, "Expected bundle3 to be marked for sync") + assert.Equal(t, "bundle3", id3) +} + +func TestQueueMarkForSyncLater(t *testing.T) { + t.Parallel() + + b := &reconciler.BundleQueue{} + b.Set([]string{"bundle1", "bundle2", "bundle3"}) + + id1, ok1 := b.GetNextBundleToSync() + b.MarkForSyncLater("bundle1") + id2, ok2 := b.GetNextBundleToSync() + id3, ok3 := b.GetNextBundleToSync() + id4, ok4 := b.GetNextBundleToSync() + id5, ok5 := b.GetNextBundleToSync() + + assert.True(t, ok1, "Expected bundle1 to be marked for sync") + assert.Equal(t, "bundle1", id1) + assert.True(t, ok2, "Expected bundle2 to be marked for sync") + assert.Equal(t, "bundle2", id2) + assert.True(t, ok3, "Expected bundle3 to be marked for sync") + assert.Equal(t, "bundle3", id3) + assert.True(t, ok4, "Expected bundle1 to be marked for sync") + assert.Equal(t, "bundle1", id4) + assert.False(t, ok5, "Expected no more bundles to sync") + assert.Empty(t, id5) + +} + +func TestQueueGetNextBundleToSync(t *testing.T) { + t.Parallel() + + b := &reconciler.BundleQueue{} + b.Set([]string{"bundle1", "bundle2"}) + + id1, ok1 := b.GetNextBundleToSync() + id2, ok2 := b.GetNextBundleToSync() + id3, ok3 := b.GetNextBundleToSync() + + assert.True(t, ok1, "Expected bundle1 to be retrieved for sync") + assert.Equal(t, "bundle1", id1) + assert.True(t, ok2, "Expected bundle2 to be retrieved for sync") + assert.Equal(t, "bundle2", id2) + require.False(t, ok3, "Expected no more bundles to sync") + assert.Empty(t, id3) +} diff --git a/internal/zero/reconciler/config.go b/internal/zero/reconciler/config.go new file mode 100644 index 000000000..d9f48546f --- /dev/null +++ b/internal/zero/reconciler/config.go @@ -0,0 +1,110 @@ +// Package reconciler syncs the state of resource bundles between the cloud and the databroker. +package reconciler + +import ( + "net/http" + "os" + "time" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" + sdk "github.com/pomerium/zero-sdk" +) + +// reconcilerConfig contains the configuration for the resource bundles reconciler. +type reconcilerConfig struct { + api *sdk.API + + databrokerClient databroker.DataBrokerServiceClient + databrokerRPS int + + tmpDir string + + httpClient *http.Client + + checkForUpdateIntervalWhenDisconnected time.Duration + checkForUpdateIntervalWhenConnected time.Duration + + syncBackoffMaxInterval time.Duration +} + +// Option configures the resource bundles reconciler +type Option func(*reconcilerConfig) + +// WithTemporaryDirectory configures the resource bundles client to use a temporary directory for +// downloading files. +func WithTemporaryDirectory(path string) Option { + return func(cfg *reconcilerConfig) { + cfg.tmpDir = path + } +} + +// WithAPI configures the cluster api client. +func WithAPI(client *sdk.API) Option { + return func(cfg *reconcilerConfig) { + cfg.api = client + } +} + +// WithDataBrokerClient configures the databroker client. +func WithDataBrokerClient(client databroker.DataBrokerServiceClient) Option { + return func(cfg *reconcilerConfig) { + cfg.databrokerClient = client + } +} + +// WithDownloadHTTPClient configures the http client used for downloading files. +func WithDownloadHTTPClient(client *http.Client) Option { + return func(cfg *reconcilerConfig) { + cfg.httpClient = client + } +} + +// WithDatabrokerRPSLimit configures the maximum number of requests per second to the databroker. +func WithDatabrokerRPSLimit(rps int) Option { + return func(cfg *reconcilerConfig) { + cfg.databrokerRPS = rps + } +} + +// WithCheckForUpdateIntervalWhenDisconnected configures the interval at which the reconciler will check +// for updates when disconnected from the cloud. +func WithCheckForUpdateIntervalWhenDisconnected(interval time.Duration) Option { + return func(cfg *reconcilerConfig) { + cfg.checkForUpdateIntervalWhenDisconnected = interval + } +} + +// WithCheckForUpdateIntervalWhenConnected configures the interval at which the reconciler will check +// for updates when connected to the cloud. +func WithCheckForUpdateIntervalWhenConnected(interval time.Duration) Option { + return func(cfg *reconcilerConfig) { + cfg.checkForUpdateIntervalWhenConnected = interval + } +} + +// WithSyncBackoffMaxInterval configures the maximum interval between sync attempts. +func WithSyncBackoffMaxInterval(interval time.Duration) Option { + return func(cfg *reconcilerConfig) { + cfg.syncBackoffMaxInterval = interval + } +} + +func newConfig(opts ...Option) *reconcilerConfig { + cfg := &reconcilerConfig{} + for _, opt := range []Option{ + WithTemporaryDirectory(os.TempDir()), + WithDownloadHTTPClient(http.DefaultClient), + WithDatabrokerRPSLimit(1_000), + WithCheckForUpdateIntervalWhenDisconnected(time.Minute * 5), + WithCheckForUpdateIntervalWhenConnected(time.Hour), + WithSyncBackoffMaxInterval(time.Minute), + } { + opt(cfg) + } + + for _, opt := range opts { + opt(cfg) + } + + return cfg +} diff --git a/internal/zero/reconciler/databroker.go b/internal/zero/reconciler/databroker.go new file mode 100644 index 000000000..f38f0c2bd --- /dev/null +++ b/internal/zero/reconciler/databroker.go @@ -0,0 +1,81 @@ +package reconciler + +import ( + "context" + "errors" + "fmt" + "io" + + "google.golang.org/protobuf/proto" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +// DatabrokerRecord is a wrapper around a databroker record. +type DatabrokerRecord struct { + V *databroker.Record +} + +var _ Record[DatabrokerRecord] = DatabrokerRecord{} + +// GetID returns the databroker record's ID. +func (r DatabrokerRecord) GetID() string { + return r.V.GetId() +} + +// GetType returns the databroker record's type. +func (r DatabrokerRecord) GetType() string { + return r.V.GetType() +} + +// Equal returns true if the databroker records are equal. +func (r DatabrokerRecord) Equal(other DatabrokerRecord) bool { + return r.V.Type == other.V.Type && + r.V.Id == other.V.Id && + proto.Equal(r.V.Data, other.V.Data) +} + +// GetDatabrokerRecords gets all databroker records of the given types. +func GetDatabrokerRecords( + ctx context.Context, + client databroker.DataBrokerServiceClient, + types []string, +) (RecordSetBundle[DatabrokerRecord], error) { + rsb := make(RecordSetBundle[DatabrokerRecord]) + + for _, typ := range types { + recs, err := getDatabrokerRecords(ctx, client, typ) + if err != nil { + return nil, fmt.Errorf("get databroker records for type %s: %w", typ, err) + } + rsb[typ] = recs + } + + return rsb, nil +} + +func getDatabrokerRecords( + ctx context.Context, + client databroker.DataBrokerServiceClient, + typ string, +) (RecordSet[DatabrokerRecord], error) { + stream, err := client.SyncLatest(ctx, &databroker.SyncLatestRequest{Type: typ}) + if err != nil { + return nil, fmt.Errorf("sync latest databroker: %w", err) + } + + recordSet := make(RecordSet[DatabrokerRecord]) + for { + res, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return nil, fmt.Errorf("receive databroker record: %w", err) + } + + if record := res.GetRecord(); record != nil { + recordSet[record.GetId()] = DatabrokerRecord{record} + } + } + return recordSet, nil +} diff --git a/internal/zero/reconciler/databroker_changeset.go b/internal/zero/reconciler/databroker_changeset.go new file mode 100644 index 000000000..1da210916 --- /dev/null +++ b/internal/zero/reconciler/databroker_changeset.go @@ -0,0 +1,53 @@ +package reconciler + +import ( + "context" + "fmt" + + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +// DatabrokerChangeSet is a set of databroker changes. +type DatabrokerChangeSet struct { + now *timestamppb.Timestamp + updates []*databroker.Record +} + +// NewDatabrokerChangeSet creates a new databroker change set. +func NewDatabrokerChangeSet() *DatabrokerChangeSet { + return &DatabrokerChangeSet{ + now: timestamppb.Now(), + } +} + +// Remove adds a record to the change set. +func (cs *DatabrokerChangeSet) Remove(typ string, id string) { + cs.updates = append(cs.updates, &databroker.Record{ + Type: typ, + Id: id, + DeletedAt: cs.now, + }) +} + +// Upsert adds a record to the change set. +func (cs *DatabrokerChangeSet) Upsert(record *databroker.Record) { + cs.updates = append(cs.updates, &databroker.Record{ + Type: record.Type, + Id: record.Id, + Data: record.Data, + }) +} + +// ApplyChanges applies the changes to the databroker. +func ApplyChanges(ctx context.Context, client databroker.DataBrokerServiceClient, changes *DatabrokerChangeSet) error { + updates := databroker.OptimumPutRequestsFromRecords(changes.updates) + for _, req := range updates { + _, err := client.Put(ctx, req) + if err != nil { + return fmt.Errorf("put databroker record: %w", err) + } + } + return nil +} diff --git a/internal/zero/reconciler/download_cache.go b/internal/zero/reconciler/download_cache.go new file mode 100644 index 000000000..2a709965b --- /dev/null +++ b/internal/zero/reconciler/download_cache.go @@ -0,0 +1,168 @@ +package reconciler + +import ( + "context" + "errors" + "fmt" + + "github.com/hashicorp/go-multierror" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/protoutil" + zero_sdk "github.com/pomerium/zero-sdk" +) + +// BundleCacheEntry is a cache entry for a bundle +// that is kept in the databroker to avoid downloading +// the same bundle multiple times. +// +// by using the ETag and LastModified headers, we do not need to +// keep caches of the bundles themselves, which can be large. +// +// also it works in case of multiple instances, as it uses +// the databroker database as a shared cache. +type BundleCacheEntry struct { + zero_sdk.DownloadConditional + RecordTypes []string +} + +const ( + bundleCacheEntryRecordType = "pomerium.io/BundleCacheEntry" +) + +var ( + // ErrBundleCacheEntryNotFound is returned when a bundle cache entry is not found + ErrBundleCacheEntryNotFound = errors.New("bundle cache entry not found") +) + +// GetBundleCacheEntry gets a bundle cache entry from the databroker +func (c *service) GetBundleCacheEntry(ctx context.Context, id string) (*BundleCacheEntry, error) { + record, err := c.config.databrokerClient.Get(ctx, &databroker.GetRequest{ + Type: bundleCacheEntryRecordType, + Id: id, + }) + if err != nil && status.Code(err) == codes.NotFound { + return nil, ErrBundleCacheEntryNotFound + } else if err != nil { + return nil, fmt.Errorf("get bundle cache entry: %w", err) + } + + var dst BundleCacheEntry + data := record.GetRecord().GetData() + err = dst.FromAny(data) + if err != nil { + log.Ctx(ctx).Error().Err(err). + Str("bundle-id", id). + Str("data", protojson.Format(data)). + Msg("could not unmarshal bundle cache entry") + // we would allow it to be overwritten by the update process + return nil, ErrBundleCacheEntryNotFound + } + + return &dst, nil +} + +// SetBundleCacheEntry sets a bundle cache entry in the databroker +func (c *service) SetBundleCacheEntry(ctx context.Context, id string, src BundleCacheEntry) error { + val, err := src.ToAny() + if err != nil { + return fmt.Errorf("marshal bundle cache entry: %w", err) + } + _, err = c.config.databrokerClient.Put(ctx, &databroker.PutRequest{ + Records: []*databroker.Record{ + { + Type: bundleCacheEntryRecordType, + Id: id, + Data: val, + }, + }, + }) + if err != nil { + return fmt.Errorf("set bundle cache entry: %w", err) + } + return nil +} + +// ToAny marshals a BundleCacheEntry into an anypb.Any +func (r *BundleCacheEntry) ToAny() (*anypb.Any, error) { + err := r.Validate() + if err != nil { + return nil, fmt.Errorf("validate: %w", err) + } + + types := make([]*structpb.Value, 0, len(r.RecordTypes)) + for _, t := range r.RecordTypes { + types = append(types, structpb.NewStringValue(t)) + } + + return protoutil.NewAny(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "etag": structpb.NewStringValue(r.ETag), + "last_modified": structpb.NewStringValue(r.LastModified), + "record_types": structpb.NewListValue(&structpb.ListValue{Values: types}), + }, + }), nil +} + +// FromAny unmarshals an anypb.Any into a BundleCacheEntry +func (r *BundleCacheEntry) FromAny(any *anypb.Any) error { + var s structpb.Struct + err := any.UnmarshalTo(&s) + if err != nil { + return fmt.Errorf("unmarshal struct: %w", err) + } + + r.ETag = s.GetFields()["etag"].GetStringValue() + r.LastModified = s.GetFields()["last_modified"].GetStringValue() + + for _, v := range s.GetFields()["record_types"].GetListValue().GetValues() { + r.RecordTypes = append(r.RecordTypes, v.GetStringValue()) + } + + err = r.Validate() + if err != nil { + return fmt.Errorf("validate: %w", err) + } + return nil +} + +// Validate validates a BundleCacheEntry +func (r *BundleCacheEntry) Validate() error { + var errs *multierror.Error + if len(r.RecordTypes) == 0 { + errs = multierror.Append(errs, errors.New("record_types is required")) + } + if err := r.DownloadConditional.Validate(); err != nil { + errs = multierror.Append(errs, err) + } + return errs.ErrorOrNil() +} + +// GetDownloadConditional returns conditional download information +func (r *BundleCacheEntry) GetDownloadConditional() *zero_sdk.DownloadConditional { + if r == nil { + return nil + } + cond := r.DownloadConditional + return &cond +} + +// GetRecordTypes returns the record types +func (r *BundleCacheEntry) GetRecordTypes() []string { + if r == nil { + return nil + } + return r.RecordTypes +} + +// Equals returns true if the two cache entries are equal +func (r *BundleCacheEntry) Equals(other *BundleCacheEntry) bool { + return r != nil && other != nil && + r.ETag == other.ETag && r.LastModified == other.LastModified +} diff --git a/internal/zero/reconciler/download_cache_test.go b/internal/zero/reconciler/download_cache_test.go new file mode 100644 index 000000000..f0d608a98 --- /dev/null +++ b/internal/zero/reconciler/download_cache_test.go @@ -0,0 +1,29 @@ +package reconciler_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/internal/zero/reconciler" + zero_sdk "github.com/pomerium/zero-sdk" +) + +func TestCacheEntryProto(t *testing.T) { + t.Parallel() + + original := reconciler.BundleCacheEntry{ + DownloadConditional: zero_sdk.DownloadConditional{ + ETag: "etag value", + LastModified: "2009-02-13 18:31:30 -0500 EST", + }, + RecordTypes: []string{"one", "two"}, + } + originalProto, err := original.ToAny() + require.NoError(t, err) + var unmarshaled reconciler.BundleCacheEntry + err = unmarshaled.FromAny(originalProto) + require.NoError(t, err) + assert.True(t, original.Equals(&unmarshaled)) +} diff --git a/internal/zero/reconciler/reconciler.go b/internal/zero/reconciler/reconciler.go new file mode 100644 index 000000000..aa1b237d4 --- /dev/null +++ b/internal/zero/reconciler/reconciler.go @@ -0,0 +1,34 @@ +package reconciler + +import ( + "context" + "fmt" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +// Reconcile reconciles the target and current record sets with the databroker. +func Reconcile( + ctx context.Context, + client databroker.DataBrokerServiceClient, + target, current RecordSetBundle[DatabrokerRecord], +) error { + updates := NewDatabrokerChangeSet() + + for _, rec := range current.GetRemoved(target).Flatten() { + updates.Remove(rec.GetType(), rec.GetID()) + } + for _, rec := range current.GetModified(target).Flatten() { + updates.Upsert(rec.V) + } + for _, rec := range current.GetAdded(target).Flatten() { + updates.Upsert(rec.V) + } + + err := ApplyChanges(ctx, client, updates) + if err != nil { + return fmt.Errorf("apply databroker changes: %w", err) + } + + return nil +} diff --git a/internal/zero/reconciler/reconciler_test.go b/internal/zero/reconciler/reconciler_test.go new file mode 100644 index 000000000..9935c58e5 --- /dev/null +++ b/internal/zero/reconciler/reconciler_test.go @@ -0,0 +1,196 @@ +package reconciler_test + +import ( + "context" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + "google.golang.org/protobuf/types/known/wrapperspb" + + databroker_int "github.com/pomerium/pomerium/internal/databroker" + "github.com/pomerium/pomerium/internal/zero/reconciler" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +func newDatabroker(t *testing.T) (context.Context, databroker.DataBrokerServiceClient) { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + gs := grpc.NewServer() + srv := databroker_int.New() + + databroker.RegisterDataBrokerServiceServer(gs, srv) + + lis := bufconn.Listen(1) + t.Cleanup(func() { + lis.Close() + gs.Stop() + }) + + go func() { _ = gs.Serve(lis) }() + + conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(func(context.Context, string) (conn net.Conn, e error) { + return lis.Dial() + }), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + t.Cleanup(func() { _ = conn.Close() }) + + return ctx, databroker.NewDataBrokerServiceClient(conn) +} + +func newRecordBundle(records []testRecord) reconciler.RecordSetBundle[reconciler.DatabrokerRecord] { + bundle := make(reconciler.RecordSetBundle[reconciler.DatabrokerRecord]) + for _, r := range records { + bundle.Add(newRecord(r)) + } + return bundle +} + +func newRecord(r testRecord) reconciler.DatabrokerRecord { + return reconciler.DatabrokerRecord{ + V: &databroker.Record{ + Type: r.Type, + Id: r.ID, + Data: protoutil.NewAnyString(r.Val), + }} +} + +func assertBundle(t *testing.T, want []testRecord, got reconciler.RecordSetBundle[reconciler.DatabrokerRecord]) { + t.Helper() + + for _, wantRecord := range want { + gotRecord, ok := got.Get(wantRecord.Type, wantRecord.ID) + if assert.True(t, ok, "record %s/%s not found", wantRecord.Type, wantRecord.ID) { + assertRecord(t, wantRecord, gotRecord) + } + } + assert.Len(t, got.Flatten(), len(want)) +} + +func assertRecord(t *testing.T, want testRecord, got reconciler.DatabrokerRecord) { + t.Helper() + + var val wrapperspb.StringValue + err := got.V.Data.UnmarshalTo(&val) + require.NoError(t, err) + + assert.Equal(t, want.Type, got.V.Type) + assert.Equal(t, want.ID, got.V.Id) + assert.Equal(t, want.Val, val.Value) +} + +func TestHelpers(t *testing.T) { + want := []testRecord{ + {"type1", "id1", "value1"}, + {"type1", "id2", "value2"}, + } + + bundle := newRecordBundle(want) + assertBundle(t, want, bundle) +} + +func wantRemoved(want, current []string) []string { + wantM := make(map[string]struct{}, len(want)) + for _, w := range want { + wantM[w] = struct{}{} + } + var toRemove []string + for _, c := range current { + if _, ok := wantM[c]; !ok { + toRemove = append(toRemove, c) + } + } + return toRemove +} + +func reconcile( + ctx context.Context, + t *testing.T, + client databroker.DataBrokerServiceClient, + want []testRecord, + current reconciler.RecordSetBundle[reconciler.DatabrokerRecord], +) reconciler.RecordSetBundle[reconciler.DatabrokerRecord] { + t.Helper() + + wantBundle := newRecordBundle(want) + err := reconciler.Reconcile(ctx, client, wantBundle, current) + require.NoError(t, err) + + got, err := reconciler.GetDatabrokerRecords(ctx, client, wantBundle.RecordTypes()) + require.NoError(t, err) + assertBundle(t, want, got) + + res, err := reconciler.GetDatabrokerRecords(ctx, client, wantRemoved(wantBundle.RecordTypes(), current.RecordTypes())) + require.NoError(t, err) + assert.Empty(t, res.Flatten()) + + return got +} + +func TestReconcile(t *testing.T) { + t.Parallel() + + ctx, client := newDatabroker(t) + + err := reconciler.Reconcile(ctx, client, nil, nil) + require.NoError(t, err) + + var current reconciler.RecordSetBundle[reconciler.DatabrokerRecord] + for _, tc := range []struct { + name string + want []testRecord + }{ + {"empty", nil}, + {"initial", []testRecord{ + {"type1", "id1", "value1"}, + {"type1", "id2", "value2"}, + }}, + {"add one", []testRecord{ + {"type1", "id1", "value1"}, + {"type1", "id2", "value2"}, + {"type1", "id3", "value3"}, + }}, + {"update one", []testRecord{ + {"type1", "id1", "value1"}, + {"type1", "id2", "value2-updated"}, + {"type1", "id3", "value3"}, + }}, + {"delete one", []testRecord{ + {"type1", "id1", "value1"}, + {"type1", "id3", "value3"}, + }}, + {"delete all", nil}, + {"multiple types", []testRecord{ + {"type1", "id1", "value1"}, + {"type1", "id2", "value2"}, + {"type2", "id1", "value1"}, + {"type2", "id2", "value2"}, + }}, + {"multiple types update", []testRecord{ + {"type1", "id1", "value1"}, + {"type1", "id2", "value2-updated"}, + {"type2", "id1", "value1"}, + {"type2", "id2", "value2-updated"}, + }}, + {"multiple types delete", []testRecord{ + {"type1", "id1", "value1"}, + {"type2", "id1", "value1"}, + }}, + {"multiple types delete one type, add one value", []testRecord{ + {"type1", "id1", "value1"}, + {"type1", "id4", "value4"}, + }}, + } { + t.Run(tc.name, func(t *testing.T) { + current = reconcile(ctx, t, client, tc.want, current) + }) + } +} diff --git a/internal/zero/reconciler/records.go b/internal/zero/reconciler/records.go new file mode 100644 index 000000000..b6094a615 --- /dev/null +++ b/internal/zero/reconciler/records.go @@ -0,0 +1,132 @@ +package reconciler + +// RecordSetBundle is an index of databroker records by type +type RecordSetBundle[T Record[T]] map[string]RecordSet[T] + +// RecordSet is an index of databroker records by their id. +type RecordSet[T Record[T]] map[string]T + +// Record is a record +type Record[T any] interface { + GetID() string + GetType() string + Equal(other T) bool +} + +// RecordTypes returns the types of records in the bundle. +func (rsb RecordSetBundle[T]) RecordTypes() []string { + types := make([]string, 0, len(rsb)) + for typ := range rsb { + types = append(types, typ) + } + return types +} + +// Add adds a record to the bundle. +func (rsb RecordSetBundle[T]) Add(record T) { + rs, ok := rsb[record.GetType()] + if !ok { + rs = make(RecordSet[T]) + rsb[record.GetType()] = rs + } + rs[record.GetID()] = record +} + +// GetAdded returns the records that are in other but not in rsb. +func (rsb RecordSetBundle[T]) GetAdded(other RecordSetBundle[T]) RecordSetBundle[T] { + added := make(RecordSetBundle[T]) + for otherType, otherRS := range other { + rs, ok := rsb[otherType] + if !ok { + added[otherType] = otherRS + continue + } + rss := rs.GetAdded(other[otherType]) + if len(rss) > 0 { + added[otherType] = rss + } + } + return added +} + +// GetRemoved returns the records that are in rs but not in other. +func (rsb RecordSetBundle[T]) GetRemoved(other RecordSetBundle[T]) RecordSetBundle[T] { + return other.GetAdded(rsb) +} + +// GetModified returns the records that are in both rs and other but have different data. +func (rsb RecordSetBundle[T]) GetModified(other RecordSetBundle[T]) RecordSetBundle[T] { + modified := make(RecordSetBundle[T]) + for otherType, otherRS := range other { + rs, ok := rsb[otherType] + if !ok { + continue + } + m := rs.GetModified(otherRS) + if len(m) > 0 { + modified[otherType] = m + } + } + return modified +} + +// GetAdded returns the records that are in other but not in rs. +func (rs RecordSet[T]) GetAdded(other RecordSet[T]) RecordSet[T] { + added := make(RecordSet[T]) + for id, record := range other { + if _, ok := rs[id]; !ok { + added[id] = record + } + } + return added +} + +// GetRemoved returns the records that are in rs but not in other. +func (rs RecordSet[T]) GetRemoved(other RecordSet[T]) RecordSet[T] { + return other.GetAdded(rs) +} + +// GetModified returns the records that are in both rs and other but have different data. +// by comparing the protobuf bytes of the payload. +func (rs RecordSet[T]) GetModified(other RecordSet[T]) RecordSet[T] { + modified := make(RecordSet[T]) + for id, record := range other { + otherRecord, ok := rs[id] + if !ok { + continue + } + + if !record.Equal(otherRecord) { + modified[id] = record + } + } + return modified +} + +// Flatten returns all records in the set. +func (rs RecordSet[T]) Flatten() []T { + records := make([]T, 0, len(rs)) + for _, record := range rs { + records = append(records, record) + } + return records +} + +// Flatten returns all records in the bundle. +func (rsb RecordSetBundle[T]) Flatten() []T { + records := make([]T, 0) + for _, rs := range rsb { + records = append(records, rs.Flatten()...) + } + return records +} + +// Get returns a record by type and id. +func (rsb RecordSetBundle[T]) Get(typeName, id string) (record T, ok bool) { + rs, ok := rsb[typeName] + if !ok { + return + } + record, ok = rs[id] + return +} diff --git a/internal/zero/reconciler/records_test.go b/internal/zero/reconciler/records_test.go new file mode 100644 index 000000000..5d1864538 --- /dev/null +++ b/internal/zero/reconciler/records_test.go @@ -0,0 +1,77 @@ +package reconciler_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/internal/zero/reconciler" +) + +type testRecord struct { + Type string + ID string + Val string +} + +func (r testRecord) GetID() string { + return r.ID +} + +func (r testRecord) GetType() string { + return r.Type +} + +func (r testRecord) Equal(other testRecord) bool { + return r.ID == other.ID && r.Type == other.Type && r.Val == other.Val +} + +func TestRecords(t *testing.T) { + initial := make(reconciler.RecordSetBundle[testRecord]) + initial.Add(testRecord{ID: "1", Type: "a", Val: "a-1"}) + initial.Add(testRecord{ID: "2", Type: "a", Val: "a-2"}) + initial.Add(testRecord{ID: "1", Type: "b", Val: "b-1"}) + + // test record types + assert.ElementsMatch(t, []string{"a", "b"}, initial.RecordTypes()) + + // test added, deleted and modified + updated := make(reconciler.RecordSetBundle[testRecord]) + updated.Add(testRecord{ID: "1", Type: "a", Val: "a-1-1"}) + updated.Add(testRecord{ID: "3", Type: "a", Val: "a-3"}) + updated.Add(testRecord{ID: "1", Type: "b", Val: "b-1"}) + updated.Add(testRecord{ID: "2", Type: "b", Val: "b-2"}) + updated.Add(testRecord{ID: "1", Type: "c", Val: "c-1"}) + + assert.ElementsMatch(t, []string{"a", "b", "c"}, updated.RecordTypes()) + + added := initial.GetAdded(updated) + assert.Equal(t, + reconciler.RecordSetBundle[testRecord]{ + "a": reconciler.RecordSet[testRecord]{ + "3": {ID: "3", Type: "a", Val: "a-3"}, + }, + "b": reconciler.RecordSet[testRecord]{ + "2": {ID: "2", Type: "b", Val: "b-2"}, + }, + "c": reconciler.RecordSet[testRecord]{ + "1": {ID: "1", Type: "c", Val: "c-1"}, + }, + }, added) + + removed := initial.GetRemoved(updated) + assert.Equal(t, + reconciler.RecordSetBundle[testRecord]{ + "a": reconciler.RecordSet[testRecord]{ + "2": {ID: "2", Type: "a", Val: "a-2"}, + }, + }, removed) + + modified := initial.GetModified(updated) + assert.Equal(t, + reconciler.RecordSetBundle[testRecord]{ + "a": reconciler.RecordSet[testRecord]{ + "1": {ID: "1", Type: "a", Val: "a-1-1"}, + }, + }, modified) +} diff --git a/internal/zero/reconciler/service.go b/internal/zero/reconciler/service.go new file mode 100644 index 000000000..1fb951d59 --- /dev/null +++ b/internal/zero/reconciler/service.go @@ -0,0 +1,89 @@ +package reconciler + +/* + * This is a main control loop for the reconciler service. + * + */ + +import ( + "context" + "time" + + "golang.org/x/sync/errgroup" + "golang.org/x/time/rate" + + "github.com/pomerium/pomerium/internal/atomicutil" + connect_mux "github.com/pomerium/zero-sdk/connect-mux" +) + +type service struct { + config *reconcilerConfig + + databrokerRateLimit *rate.Limiter + + bundles BundleQueue + + fullSyncRequest chan struct{} + bundleSyncRequest chan struct{} + periodicUpdateInterval atomicutil.Value[time.Duration] +} + +// Run creates a new bundle updater client +// that runs until the context is canceled or a fatal error occurs. +func Run(ctx context.Context, opts ...Option) error { + config := newConfig(opts...) + + c := &service{ + config: config, + databrokerRateLimit: rate.NewLimiter(rate.Limit(config.databrokerRPS), 1), + fullSyncRequest: make(chan struct{}, 1), + } + c.periodicUpdateInterval.Store(config.checkForUpdateIntervalWhenDisconnected) + + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { return c.watchUpdates(ctx) }) + eg.Go(func() error { return c.SyncLoop(ctx) }) + + return eg.Wait() +} + +// run is a main control loop. +// it is very simple and sequential download and reconcile. +// it may be later optimized by splitting between download and reconciliation process, +// as we would get more resource bundles beyond the config. +func (c *service) watchUpdates(ctx context.Context) error { + return c.config.api.Watch(ctx, + connect_mux.WithOnConnected(func(ctx context.Context) { + c.triggerFullUpdate(true) + }), + connect_mux.WithOnDisconnected(func(_ context.Context) { + c.triggerFullUpdate(false) + }), + connect_mux.WithOnBundleUpdated(func(_ context.Context, key string) { + c.triggerBundleUpdate(key) + }), + ) +} + +func (c *service) triggerBundleUpdate(id string) { + c.periodicUpdateInterval.Store(c.config.checkForUpdateIntervalWhenConnected) + c.bundles.MarkForSync(id) + + select { + case c.fullSyncRequest <- struct{}{}: + default: + } +} + +func (c *service) triggerFullUpdate(connected bool) { + timeout := c.config.checkForUpdateIntervalWhenDisconnected + if connected { + timeout = c.config.checkForUpdateIntervalWhenConnected + } + c.periodicUpdateInterval.Store(timeout) + + select { + case c.fullSyncRequest <- struct{}{}: + default: + } +} diff --git a/internal/zero/reconciler/sync.go b/internal/zero/reconciler/sync.go new file mode 100644 index 000000000..0630d0b91 --- /dev/null +++ b/internal/zero/reconciler/sync.go @@ -0,0 +1,231 @@ +package reconciler + +/* + * Sync syncs the bundles between their cloud source and the databroker. + * + * FullSync performs a full sync of the bundles by calling the API, + * and walking the list of bundles, and calling SyncBundle on each. + * It also removes any records in the databroker that are not in the list of bundles. + * + * WatchAndSync watches the API for changes, and calls SyncBundle on each change. + * + */ + +import ( + "context" + "errors" + "fmt" + "io" + "time" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/retry" +) + +// Sync synchronizes the bundles between their cloud source and the databroker. +func (c *service) SyncLoop(ctx context.Context) error { + ticker := time.NewTicker(c.periodicUpdateInterval.Load()) + defer ticker.Stop() + + for { + dur := c.periodicUpdateInterval.Load() + ticker.Reset(dur) + + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.bundleSyncRequest: + log.Ctx(ctx).Info().Msg("bundle sync triggered") + err := c.syncBundles(ctx) + if err != nil { + return fmt.Errorf("reconciler: sync bundles: %w", err) + } + case <-c.fullSyncRequest: + log.Ctx(ctx).Info().Msg("full sync triggered") + err := c.syncAll(ctx) + if err != nil { + return fmt.Errorf("reconciler: sync all: %w", err) + } + case <-ticker.C: + log.Ctx(ctx).Info().Msg("periodic sync triggered") + err := c.syncAll(ctx) + if err != nil { + return fmt.Errorf("reconciler: sync all: %w", err) + } + } + } +} + +func (c *service) syncAll(ctx context.Context) error { + err := c.syncBundleList(ctx) + if err != nil { + return fmt.Errorf("sync bundle list: %w", err) + } + + err = c.syncBundles(ctx) + if err != nil { + return fmt.Errorf("sync bundles: %w", err) + } + + return nil +} + +// trySyncAllBundles tries to sync all bundles in the queue. +func (c *service) syncBundleList(ctx context.Context) error { + // refresh bundle list, + // ignoring other signals while we're retrying + return retry.Retry(ctx, + "refresh bundle list", c.refreshBundleList, + retry.WithWatch("refresh bundle list", c.fullSyncRequest, nil), + retry.WithWatch("bundle update", c.bundleSyncRequest, nil), + ) +} + +// syncBundles retries until there are no more bundles to sync. +// updates bundle list if the full bundle update request arrives. +func (c *service) syncBundles(ctx context.Context) error { + return retry.Retry(ctx, + "sync bundles", c.trySyncBundles, + retry.WithWatch("refresh bundle list", c.fullSyncRequest, c.refreshBundleList), + retry.WithWatch("bundle update", c.bundleSyncRequest, nil), + ) +} + +// trySyncAllBundles tries to sync all bundles in the queue +// it returns nil if all bundles were synced successfully +func (c *service) trySyncBundles(ctx context.Context) error { + for { + id, ok := c.bundles.GetNextBundleToSync() + if !ok { // no more bundles to sync + return nil + } + + err := c.syncBundle(ctx, id) + if err != nil { + c.bundles.MarkForSyncLater(id) + return fmt.Errorf("sync bundle %s: %w", id, err) + } + } +} + +// syncBundle syncs the bundle to the databroker. +// Databroker holds last synced bundle state in form of a (etag, last-modified) tuple. +// This is only persisted in the databroker after all records are successfully synced. +// That allows us to ignore any changes based on the same bundle state, without need to re-check all records between bundle and databroker. +func (c *service) syncBundle(ctx context.Context, key string) error { + cached, err := c.GetBundleCacheEntry(ctx, key) + if err != nil && !errors.Is(err, ErrBundleCacheEntryNotFound) { + return fmt.Errorf("get bundle cache entry: %w", err) + } + + // download is much faster compared to databroker sync, + // so we don't use pipe but rather download to a temp file and then sync it to databroker + fd, err := c.GetTmpFile(key) + if err != nil { + return fmt.Errorf("get tmp file: %w", err) + } + defer func() { + if err := fd.Close(); err != nil { + log.Ctx(ctx).Error().Err(err).Msg("close tmp file") + } + }() + + conditional := cached.GetDownloadConditional() + log.Ctx(ctx).Debug().Str("id", key).Any("conditional", conditional).Msg("downloading bundle") + + result, err := c.config.api.DownloadClusterResourceBundle(ctx, fd, key, conditional) + if err != nil { + return fmt.Errorf("download bundle: %w", err) + } + + if result.NotModified { + log.Ctx(ctx).Debug().Str("bundle", key).Msg("bundle not changed") + return nil + } + + log.Ctx(ctx).Debug().Str("bundle", key). + Interface("cached-entry", cached). + Interface("current-entry", result.DownloadConditional). + Msg("bundle updated") + + _, err = fd.Seek(0, io.SeekStart) + if err != nil { + return fmt.Errorf("seek to start: %w", err) + } + + bundleRecordTypes, err := c.syncBundleToDatabroker(ctx, fd, cached.GetRecordTypes()) + if err != nil { + return fmt.Errorf("apply bundle to databroker: %w", err) + } + current := BundleCacheEntry{ + DownloadConditional: *result.DownloadConditional, + RecordTypes: bundleRecordTypes, + } + + log.Ctx(ctx).Info(). + Str("bundle", key). + Strs("record_types", bundleRecordTypes). + Str("etag", current.ETag). + Str("last_modified", current.LastModified). + Msg("bundle synced") + + err = c.SetBundleCacheEntry(ctx, key, current) + if err != nil { + return fmt.Errorf("set bundle cache entry: %w", err) + } + + return nil +} + +func strUnion(a, b []string) []string { + m := make(map[string]struct{}, len(a)+len(b)) + for _, s := range a { + m[s] = struct{}{} + } + for _, s := range b { + m[s] = struct{}{} + } + + out := make([]string, 0, len(m)) + for s := range m { + out = append(out, s) + } + return out +} + +func (c *service) syncBundleToDatabroker(ctx context.Context, src io.Reader, currentRecordTypes []string) ([]string, error) { + bundleRecords, err := ReadBundleRecords(src) + if err != nil { + return nil, fmt.Errorf("read bundle records: %w", err) + } + + databrokerRecords, err := GetDatabrokerRecords(ctx, + c.config.databrokerClient, + strUnion(bundleRecords.RecordTypes(), currentRecordTypes), + ) + if err != nil { + return nil, fmt.Errorf("get databroker records: %w", err) + } + + err = Reconcile(ctx, c.config.databrokerClient, bundleRecords, databrokerRecords) + if err != nil { + return nil, fmt.Errorf("reconcile databroker records: %w", err) + } + + return bundleRecords.RecordTypes(), nil +} + +func (c *service) refreshBundleList(ctx context.Context) error { + resp, err := c.config.api.GetClusterResourceBundles(ctx) + if err != nil { + return fmt.Errorf("get bundles: %w", err) + } + + ids := make([]string, 0, len(resp.Bundles)) + for _, v := range resp.Bundles { + ids = append(ids, v.Id) + } + + c.bundles.Set(ids) + return nil +} diff --git a/internal/zero/reconciler/tmpfile.go b/internal/zero/reconciler/tmpfile.go new file mode 100644 index 000000000..527186bd6 --- /dev/null +++ b/internal/zero/reconciler/tmpfile.go @@ -0,0 +1,40 @@ +package reconciler + +import ( + "fmt" + "io" + "os" + + "github.com/hashicorp/go-multierror" +) + +// ReadWriteSeekCloser is a file that can be read, written, seeked, and closed. +type ReadWriteSeekCloser interface { + io.ReadWriteSeeker + io.Closer +} + +// GetTmpFile returns a temporary file for the reconciler to use. +// TODO: encrypt contents to ensure encryption at rest +func (c *service) GetTmpFile(key string) (ReadWriteSeekCloser, error) { + fd, err := os.CreateTemp(c.config.tmpDir, fmt.Sprintf("pomerium-bundle-%s", key)) + if err != nil { + return nil, fmt.Errorf("create temp file: %w", err) + } + return &tmpFile{File: fd}, nil +} + +type tmpFile struct { + *os.File +} + +func (f *tmpFile) Close() error { + var errs *multierror.Error + if err := f.File.Close(); err != nil { + errs = multierror.Append(errs, err) + } + if err := os.Remove(f.File.Name()); err != nil { + errs = multierror.Append(errs, err) + } + return errs.ErrorOrNil() +}