mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-01 10:22:43 +02:00
zero: resource bundle reconciler (#4445)
This commit is contained in:
parent
c0b1309e90
commit
ea8762d706
17 changed files with 1559 additions and 0 deletions
1
go.mod
1
go.mod
|
@ -69,6 +69,7 @@ require (
|
|||
golang.org/x/net v0.17.0
|
||||
golang.org/x/oauth2 v0.12.0
|
||||
golang.org/x/sync v0.3.0
|
||||
golang.org/x/time v0.3.0
|
||||
google.golang.org/api v0.143.0
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13
|
||||
google.golang.org/grpc v1.58.3
|
||||
|
|
35
internal/zero/reconciler/bundles_format.go
Normal file
35
internal/zero/reconciler/bundles_format.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package reconciler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protodelim"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
var unmarshalOpts = protodelim.UnmarshalOptions{}
|
||||
|
||||
// ReadBundleRecords reads records in a protobuf wire format from src.
|
||||
// Each record is expected to be a databroker.Record.
|
||||
func ReadBundleRecords(src io.Reader) (RecordSetBundle[DatabrokerRecord], error) {
|
||||
r := bufio.NewReader(src)
|
||||
rsb := make(RecordSetBundle[DatabrokerRecord])
|
||||
for {
|
||||
record := new(databroker.Record)
|
||||
err := unmarshalOpts.UnmarshalFrom(r, record)
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading protobuf record: %w", err)
|
||||
}
|
||||
|
||||
rsb.Add(DatabrokerRecord{record})
|
||||
}
|
||||
|
||||
return rsb, nil
|
||||
}
|
64
internal/zero/reconciler/bundles_format_test.go
Normal file
64
internal/zero/reconciler/bundles_format_test.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
package reconciler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/encoding/protodelim"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/config"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
func TestReadRecords(t *testing.T) {
|
||||
|
||||
dir := t.TempDir()
|
||||
fd, err := os.CreateTemp(dir, "config")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = fd.Close() })
|
||||
|
||||
err = writeSampleRecords(fd)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = fd.Seek(0, io.SeekStart)
|
||||
require.NoError(t, err)
|
||||
|
||||
records, err := ReadBundleRecords(fd)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, records, 1)
|
||||
}
|
||||
|
||||
func writeSampleRecords(dst io.Writer) error {
|
||||
var marshalOpts = protodelim.MarshalOptions{
|
||||
MarshalOptions: proto.MarshalOptions{
|
||||
AllowPartial: false,
|
||||
Deterministic: true,
|
||||
UseCachedSize: false,
|
||||
},
|
||||
}
|
||||
|
||||
cfg := protoutil.NewAny(&config.Config{
|
||||
Routes: []*config.Route{
|
||||
{
|
||||
From: "https://from.example.com",
|
||||
To: []string{"https://to.example.com"},
|
||||
},
|
||||
},
|
||||
})
|
||||
rec := &databroker.Record{
|
||||
Id: "config",
|
||||
Type: cfg.GetTypeUrl(),
|
||||
Data: cfg,
|
||||
}
|
||||
_, err := marshalOpts.MarshalTo(dst, rec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
124
internal/zero/reconciler/bundles_queue.go
Normal file
124
internal/zero/reconciler/bundles_queue.go
Normal file
|
@ -0,0 +1,124 @@
|
|||
package reconciler
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type bundle struct {
|
||||
id string
|
||||
synced bool
|
||||
priority int
|
||||
}
|
||||
|
||||
type bundleHeap []bundle
|
||||
|
||||
func (h bundleHeap) Len() int { return len(h) }
|
||||
func (h bundleHeap) Less(i, j int) bool {
|
||||
// If one is synced and the other is not, the unsynced one comes first
|
||||
if h[i].synced != h[j].synced {
|
||||
return !h[i].synced
|
||||
}
|
||||
// Otherwise, the one with the lower priority comes first
|
||||
return h[i].priority < h[j].priority
|
||||
}
|
||||
|
||||
func (h bundleHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||
|
||||
func (h *bundleHeap) Push(x interface{}) {
|
||||
item := x.(bundle)
|
||||
*h = append(*h, item)
|
||||
}
|
||||
|
||||
func (h *bundleHeap) Pop() interface{} {
|
||||
old := *h
|
||||
n := len(old)
|
||||
x := old[n-1]
|
||||
*h = old[0 : n-1]
|
||||
return x
|
||||
}
|
||||
|
||||
// BundleQueue is a priority queue of bundles to sync.
|
||||
type BundleQueue struct {
|
||||
sync.Mutex
|
||||
bundles bundleHeap
|
||||
counter int // to assign priorities based on order of insertion
|
||||
}
|
||||
|
||||
// Set sets the bundles to be synced. This will reset the sync status of all bundles.
|
||||
func (b *BundleQueue) Set(bundles []string) {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
b.bundles = make(bundleHeap, len(bundles))
|
||||
b.counter = len(bundles)
|
||||
for i, id := range bundles {
|
||||
b.bundles[i] = bundle{
|
||||
id: id,
|
||||
synced: false,
|
||||
priority: i,
|
||||
}
|
||||
}
|
||||
heap.Init(&b.bundles)
|
||||
}
|
||||
|
||||
// MarkForSync marks the bundle with the given ID for syncing.
|
||||
func (b *BundleQueue) MarkForSync(id string) {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
for i, bundle := range b.bundles {
|
||||
if bundle.id == id {
|
||||
b.bundles[i].synced = false
|
||||
heap.Fix(&b.bundles, i)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
newBundle := bundle{id: id, synced: false, priority: b.counter}
|
||||
heap.Push(&b.bundles, newBundle)
|
||||
b.counter++
|
||||
}
|
||||
|
||||
// MarkForSyncLater marks the bundle with the given ID for syncing later (after all other bundles).
|
||||
func (b *BundleQueue) MarkForSyncLater(id string) {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
for i, bundle := range b.bundles {
|
||||
if bundle.id != id {
|
||||
continue
|
||||
}
|
||||
|
||||
// Increase the counter first to ensure that this bundle has the highest (last) priority.
|
||||
b.counter++
|
||||
b.bundles[i].synced = false
|
||||
b.bundles[i].priority = b.counter
|
||||
heap.Fix(&b.bundles, i)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// GetNextBundleToSync returns the ID of the next bundle to sync and whether there is one.
|
||||
func (b *BundleQueue) GetNextBundleToSync() (string, bool) {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
if len(b.bundles) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Check the top bundle without popping
|
||||
if b.bundles[0].synced {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Mark the top bundle as synced and push it to the end
|
||||
id := b.bundles[0].id
|
||||
b.bundles[0].synced = true
|
||||
b.bundles[0].priority = b.counter
|
||||
heap.Fix(&b.bundles, 0)
|
||||
b.counter++
|
||||
|
||||
return id, true
|
||||
}
|
95
internal/zero/reconciler/bundles_queue_test.go
Normal file
95
internal/zero/reconciler/bundles_queue_test.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package reconciler_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/zero/reconciler"
|
||||
)
|
||||
|
||||
func TestQueueSet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &reconciler.BundleQueue{}
|
||||
b.Set([]string{"bundle1", "bundle2"})
|
||||
|
||||
id1, ok1 := b.GetNextBundleToSync()
|
||||
id2, ok2 := b.GetNextBundleToSync()
|
||||
|
||||
assert.True(t, ok1, "Expected bundle1 to be set")
|
||||
assert.Equal(t, "bundle1", id1)
|
||||
assert.True(t, ok2, "Expected bundle2 to be set")
|
||||
assert.Equal(t, "bundle2", id2)
|
||||
|
||||
id3, ok3 := b.GetNextBundleToSync()
|
||||
assert.False(t, ok3, "Expected no more bundles to sync")
|
||||
assert.Empty(t, id3)
|
||||
}
|
||||
|
||||
func TestQueueMarkForSync(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &reconciler.BundleQueue{}
|
||||
b.Set([]string{"bundle1", "bundle2"})
|
||||
|
||||
b.MarkForSync("bundle2")
|
||||
id1, ok1 := b.GetNextBundleToSync()
|
||||
|
||||
assert.True(t, ok1, "Expected bundle1 to be marked for sync")
|
||||
assert.Equal(t, "bundle1", id1)
|
||||
|
||||
b.MarkForSync("bundle3")
|
||||
id2, ok2 := b.GetNextBundleToSync()
|
||||
id3, ok3 := b.GetNextBundleToSync()
|
||||
|
||||
assert.True(t, ok2, "Expected bundle2 to be marked for sync")
|
||||
assert.Equal(t, "bundle2", id2)
|
||||
assert.True(t, ok3, "Expected bundle3 to be marked for sync")
|
||||
assert.Equal(t, "bundle3", id3)
|
||||
}
|
||||
|
||||
func TestQueueMarkForSyncLater(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &reconciler.BundleQueue{}
|
||||
b.Set([]string{"bundle1", "bundle2", "bundle3"})
|
||||
|
||||
id1, ok1 := b.GetNextBundleToSync()
|
||||
b.MarkForSyncLater("bundle1")
|
||||
id2, ok2 := b.GetNextBundleToSync()
|
||||
id3, ok3 := b.GetNextBundleToSync()
|
||||
id4, ok4 := b.GetNextBundleToSync()
|
||||
id5, ok5 := b.GetNextBundleToSync()
|
||||
|
||||
assert.True(t, ok1, "Expected bundle1 to be marked for sync")
|
||||
assert.Equal(t, "bundle1", id1)
|
||||
assert.True(t, ok2, "Expected bundle2 to be marked for sync")
|
||||
assert.Equal(t, "bundle2", id2)
|
||||
assert.True(t, ok3, "Expected bundle3 to be marked for sync")
|
||||
assert.Equal(t, "bundle3", id3)
|
||||
assert.True(t, ok4, "Expected bundle1 to be marked for sync")
|
||||
assert.Equal(t, "bundle1", id4)
|
||||
assert.False(t, ok5, "Expected no more bundles to sync")
|
||||
assert.Empty(t, id5)
|
||||
|
||||
}
|
||||
|
||||
func TestQueueGetNextBundleToSync(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &reconciler.BundleQueue{}
|
||||
b.Set([]string{"bundle1", "bundle2"})
|
||||
|
||||
id1, ok1 := b.GetNextBundleToSync()
|
||||
id2, ok2 := b.GetNextBundleToSync()
|
||||
id3, ok3 := b.GetNextBundleToSync()
|
||||
|
||||
assert.True(t, ok1, "Expected bundle1 to be retrieved for sync")
|
||||
assert.Equal(t, "bundle1", id1)
|
||||
assert.True(t, ok2, "Expected bundle2 to be retrieved for sync")
|
||||
assert.Equal(t, "bundle2", id2)
|
||||
require.False(t, ok3, "Expected no more bundles to sync")
|
||||
assert.Empty(t, id3)
|
||||
}
|
110
internal/zero/reconciler/config.go
Normal file
110
internal/zero/reconciler/config.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
// Package reconciler syncs the state of resource bundles between the cloud and the databroker.
|
||||
package reconciler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
sdk "github.com/pomerium/zero-sdk"
|
||||
)
|
||||
|
||||
// reconcilerConfig contains the configuration for the resource bundles reconciler.
|
||||
type reconcilerConfig struct {
|
||||
api *sdk.API
|
||||
|
||||
databrokerClient databroker.DataBrokerServiceClient
|
||||
databrokerRPS int
|
||||
|
||||
tmpDir string
|
||||
|
||||
httpClient *http.Client
|
||||
|
||||
checkForUpdateIntervalWhenDisconnected time.Duration
|
||||
checkForUpdateIntervalWhenConnected time.Duration
|
||||
|
||||
syncBackoffMaxInterval time.Duration
|
||||
}
|
||||
|
||||
// Option configures the resource bundles reconciler
|
||||
type Option func(*reconcilerConfig)
|
||||
|
||||
// WithTemporaryDirectory configures the resource bundles client to use a temporary directory for
|
||||
// downloading files.
|
||||
func WithTemporaryDirectory(path string) Option {
|
||||
return func(cfg *reconcilerConfig) {
|
||||
cfg.tmpDir = path
|
||||
}
|
||||
}
|
||||
|
||||
// WithAPI configures the cluster api client.
|
||||
func WithAPI(client *sdk.API) Option {
|
||||
return func(cfg *reconcilerConfig) {
|
||||
cfg.api = client
|
||||
}
|
||||
}
|
||||
|
||||
// WithDataBrokerClient configures the databroker client.
|
||||
func WithDataBrokerClient(client databroker.DataBrokerServiceClient) Option {
|
||||
return func(cfg *reconcilerConfig) {
|
||||
cfg.databrokerClient = client
|
||||
}
|
||||
}
|
||||
|
||||
// WithDownloadHTTPClient configures the http client used for downloading files.
|
||||
func WithDownloadHTTPClient(client *http.Client) Option {
|
||||
return func(cfg *reconcilerConfig) {
|
||||
cfg.httpClient = client
|
||||
}
|
||||
}
|
||||
|
||||
// WithDatabrokerRPSLimit configures the maximum number of requests per second to the databroker.
|
||||
func WithDatabrokerRPSLimit(rps int) Option {
|
||||
return func(cfg *reconcilerConfig) {
|
||||
cfg.databrokerRPS = rps
|
||||
}
|
||||
}
|
||||
|
||||
// WithCheckForUpdateIntervalWhenDisconnected configures the interval at which the reconciler will check
|
||||
// for updates when disconnected from the cloud.
|
||||
func WithCheckForUpdateIntervalWhenDisconnected(interval time.Duration) Option {
|
||||
return func(cfg *reconcilerConfig) {
|
||||
cfg.checkForUpdateIntervalWhenDisconnected = interval
|
||||
}
|
||||
}
|
||||
|
||||
// WithCheckForUpdateIntervalWhenConnected configures the interval at which the reconciler will check
|
||||
// for updates when connected to the cloud.
|
||||
func WithCheckForUpdateIntervalWhenConnected(interval time.Duration) Option {
|
||||
return func(cfg *reconcilerConfig) {
|
||||
cfg.checkForUpdateIntervalWhenConnected = interval
|
||||
}
|
||||
}
|
||||
|
||||
// WithSyncBackoffMaxInterval configures the maximum interval between sync attempts.
|
||||
func WithSyncBackoffMaxInterval(interval time.Duration) Option {
|
||||
return func(cfg *reconcilerConfig) {
|
||||
cfg.syncBackoffMaxInterval = interval
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(opts ...Option) *reconcilerConfig {
|
||||
cfg := &reconcilerConfig{}
|
||||
for _, opt := range []Option{
|
||||
WithTemporaryDirectory(os.TempDir()),
|
||||
WithDownloadHTTPClient(http.DefaultClient),
|
||||
WithDatabrokerRPSLimit(1_000),
|
||||
WithCheckForUpdateIntervalWhenDisconnected(time.Minute * 5),
|
||||
WithCheckForUpdateIntervalWhenConnected(time.Hour),
|
||||
WithSyncBackoffMaxInterval(time.Minute),
|
||||
} {
|
||||
opt(cfg)
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
81
internal/zero/reconciler/databroker.go
Normal file
81
internal/zero/reconciler/databroker.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
package reconciler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
// DatabrokerRecord is a wrapper around a databroker record.
|
||||
type DatabrokerRecord struct {
|
||||
V *databroker.Record
|
||||
}
|
||||
|
||||
var _ Record[DatabrokerRecord] = DatabrokerRecord{}
|
||||
|
||||
// GetID returns the databroker record's ID.
|
||||
func (r DatabrokerRecord) GetID() string {
|
||||
return r.V.GetId()
|
||||
}
|
||||
|
||||
// GetType returns the databroker record's type.
|
||||
func (r DatabrokerRecord) GetType() string {
|
||||
return r.V.GetType()
|
||||
}
|
||||
|
||||
// Equal returns true if the databroker records are equal.
|
||||
func (r DatabrokerRecord) Equal(other DatabrokerRecord) bool {
|
||||
return r.V.Type == other.V.Type &&
|
||||
r.V.Id == other.V.Id &&
|
||||
proto.Equal(r.V.Data, other.V.Data)
|
||||
}
|
||||
|
||||
// GetDatabrokerRecords gets all databroker records of the given types.
|
||||
func GetDatabrokerRecords(
|
||||
ctx context.Context,
|
||||
client databroker.DataBrokerServiceClient,
|
||||
types []string,
|
||||
) (RecordSetBundle[DatabrokerRecord], error) {
|
||||
rsb := make(RecordSetBundle[DatabrokerRecord])
|
||||
|
||||
for _, typ := range types {
|
||||
recs, err := getDatabrokerRecords(ctx, client, typ)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get databroker records for type %s: %w", typ, err)
|
||||
}
|
||||
rsb[typ] = recs
|
||||
}
|
||||
|
||||
return rsb, nil
|
||||
}
|
||||
|
||||
func getDatabrokerRecords(
|
||||
ctx context.Context,
|
||||
client databroker.DataBrokerServiceClient,
|
||||
typ string,
|
||||
) (RecordSet[DatabrokerRecord], error) {
|
||||
stream, err := client.SyncLatest(ctx, &databroker.SyncLatestRequest{Type: typ})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sync latest databroker: %w", err)
|
||||
}
|
||||
|
||||
recordSet := make(RecordSet[DatabrokerRecord])
|
||||
for {
|
||||
res, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("receive databroker record: %w", err)
|
||||
}
|
||||
|
||||
if record := res.GetRecord(); record != nil {
|
||||
recordSet[record.GetId()] = DatabrokerRecord{record}
|
||||
}
|
||||
}
|
||||
return recordSet, nil
|
||||
}
|
53
internal/zero/reconciler/databroker_changeset.go
Normal file
53
internal/zero/reconciler/databroker_changeset.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package reconciler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
// DatabrokerChangeSet is a set of databroker changes.
|
||||
type DatabrokerChangeSet struct {
|
||||
now *timestamppb.Timestamp
|
||||
updates []*databroker.Record
|
||||
}
|
||||
|
||||
// NewDatabrokerChangeSet creates a new databroker change set.
|
||||
func NewDatabrokerChangeSet() *DatabrokerChangeSet {
|
||||
return &DatabrokerChangeSet{
|
||||
now: timestamppb.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Remove adds a record to the change set.
|
||||
func (cs *DatabrokerChangeSet) Remove(typ string, id string) {
|
||||
cs.updates = append(cs.updates, &databroker.Record{
|
||||
Type: typ,
|
||||
Id: id,
|
||||
DeletedAt: cs.now,
|
||||
})
|
||||
}
|
||||
|
||||
// Upsert adds a record to the change set.
|
||||
func (cs *DatabrokerChangeSet) Upsert(record *databroker.Record) {
|
||||
cs.updates = append(cs.updates, &databroker.Record{
|
||||
Type: record.Type,
|
||||
Id: record.Id,
|
||||
Data: record.Data,
|
||||
})
|
||||
}
|
||||
|
||||
// ApplyChanges applies the changes to the databroker.
|
||||
func ApplyChanges(ctx context.Context, client databroker.DataBrokerServiceClient, changes *DatabrokerChangeSet) error {
|
||||
updates := databroker.OptimumPutRequestsFromRecords(changes.updates)
|
||||
for _, req := range updates {
|
||||
_, err := client.Put(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("put databroker record: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
168
internal/zero/reconciler/download_cache.go
Normal file
168
internal/zero/reconciler/download_cache.go
Normal file
|
@ -0,0 +1,168 @@
|
|||
package reconciler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
zero_sdk "github.com/pomerium/zero-sdk"
|
||||
)
|
||||
|
||||
// BundleCacheEntry is a cache entry for a bundle
|
||||
// that is kept in the databroker to avoid downloading
|
||||
// the same bundle multiple times.
|
||||
//
|
||||
// by using the ETag and LastModified headers, we do not need to
|
||||
// keep caches of the bundles themselves, which can be large.
|
||||
//
|
||||
// also it works in case of multiple instances, as it uses
|
||||
// the databroker database as a shared cache.
|
||||
type BundleCacheEntry struct {
|
||||
zero_sdk.DownloadConditional
|
||||
RecordTypes []string
|
||||
}
|
||||
|
||||
const (
|
||||
bundleCacheEntryRecordType = "pomerium.io/BundleCacheEntry"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrBundleCacheEntryNotFound is returned when a bundle cache entry is not found
|
||||
ErrBundleCacheEntryNotFound = errors.New("bundle cache entry not found")
|
||||
)
|
||||
|
||||
// GetBundleCacheEntry gets a bundle cache entry from the databroker
|
||||
func (c *service) GetBundleCacheEntry(ctx context.Context, id string) (*BundleCacheEntry, error) {
|
||||
record, err := c.config.databrokerClient.Get(ctx, &databroker.GetRequest{
|
||||
Type: bundleCacheEntryRecordType,
|
||||
Id: id,
|
||||
})
|
||||
if err != nil && status.Code(err) == codes.NotFound {
|
||||
return nil, ErrBundleCacheEntryNotFound
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("get bundle cache entry: %w", err)
|
||||
}
|
||||
|
||||
var dst BundleCacheEntry
|
||||
data := record.GetRecord().GetData()
|
||||
err = dst.FromAny(data)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).
|
||||
Str("bundle-id", id).
|
||||
Str("data", protojson.Format(data)).
|
||||
Msg("could not unmarshal bundle cache entry")
|
||||
// we would allow it to be overwritten by the update process
|
||||
return nil, ErrBundleCacheEntryNotFound
|
||||
}
|
||||
|
||||
return &dst, nil
|
||||
}
|
||||
|
||||
// SetBundleCacheEntry sets a bundle cache entry in the databroker
|
||||
func (c *service) SetBundleCacheEntry(ctx context.Context, id string, src BundleCacheEntry) error {
|
||||
val, err := src.ToAny()
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal bundle cache entry: %w", err)
|
||||
}
|
||||
_, err = c.config.databrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
Records: []*databroker.Record{
|
||||
{
|
||||
Type: bundleCacheEntryRecordType,
|
||||
Id: id,
|
||||
Data: val,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("set bundle cache entry: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToAny marshals a BundleCacheEntry into an anypb.Any
|
||||
func (r *BundleCacheEntry) ToAny() (*anypb.Any, error) {
|
||||
err := r.Validate()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate: %w", err)
|
||||
}
|
||||
|
||||
types := make([]*structpb.Value, 0, len(r.RecordTypes))
|
||||
for _, t := range r.RecordTypes {
|
||||
types = append(types, structpb.NewStringValue(t))
|
||||
}
|
||||
|
||||
return protoutil.NewAny(&structpb.Struct{
|
||||
Fields: map[string]*structpb.Value{
|
||||
"etag": structpb.NewStringValue(r.ETag),
|
||||
"last_modified": structpb.NewStringValue(r.LastModified),
|
||||
"record_types": structpb.NewListValue(&structpb.ListValue{Values: types}),
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
|
||||
// FromAny unmarshals an anypb.Any into a BundleCacheEntry
|
||||
func (r *BundleCacheEntry) FromAny(any *anypb.Any) error {
|
||||
var s structpb.Struct
|
||||
err := any.UnmarshalTo(&s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal struct: %w", err)
|
||||
}
|
||||
|
||||
r.ETag = s.GetFields()["etag"].GetStringValue()
|
||||
r.LastModified = s.GetFields()["last_modified"].GetStringValue()
|
||||
|
||||
for _, v := range s.GetFields()["record_types"].GetListValue().GetValues() {
|
||||
r.RecordTypes = append(r.RecordTypes, v.GetStringValue())
|
||||
}
|
||||
|
||||
err = r.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("validate: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate validates a BundleCacheEntry
|
||||
func (r *BundleCacheEntry) Validate() error {
|
||||
var errs *multierror.Error
|
||||
if len(r.RecordTypes) == 0 {
|
||||
errs = multierror.Append(errs, errors.New("record_types is required"))
|
||||
}
|
||||
if err := r.DownloadConditional.Validate(); err != nil {
|
||||
errs = multierror.Append(errs, err)
|
||||
}
|
||||
return errs.ErrorOrNil()
|
||||
}
|
||||
|
||||
// GetDownloadConditional returns conditional download information
|
||||
func (r *BundleCacheEntry) GetDownloadConditional() *zero_sdk.DownloadConditional {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cond := r.DownloadConditional
|
||||
return &cond
|
||||
}
|
||||
|
||||
// GetRecordTypes returns the record types
|
||||
func (r *BundleCacheEntry) GetRecordTypes() []string {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return r.RecordTypes
|
||||
}
|
||||
|
||||
// Equals returns true if the two cache entries are equal
|
||||
func (r *BundleCacheEntry) Equals(other *BundleCacheEntry) bool {
|
||||
return r != nil && other != nil &&
|
||||
r.ETag == other.ETag && r.LastModified == other.LastModified
|
||||
}
|
29
internal/zero/reconciler/download_cache_test.go
Normal file
29
internal/zero/reconciler/download_cache_test.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package reconciler_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/zero/reconciler"
|
||||
zero_sdk "github.com/pomerium/zero-sdk"
|
||||
)
|
||||
|
||||
func TestCacheEntryProto(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
original := reconciler.BundleCacheEntry{
|
||||
DownloadConditional: zero_sdk.DownloadConditional{
|
||||
ETag: "etag value",
|
||||
LastModified: "2009-02-13 18:31:30 -0500 EST",
|
||||
},
|
||||
RecordTypes: []string{"one", "two"},
|
||||
}
|
||||
originalProto, err := original.ToAny()
|
||||
require.NoError(t, err)
|
||||
var unmarshaled reconciler.BundleCacheEntry
|
||||
err = unmarshaled.FromAny(originalProto)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, original.Equals(&unmarshaled))
|
||||
}
|
34
internal/zero/reconciler/reconciler.go
Normal file
34
internal/zero/reconciler/reconciler.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package reconciler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
// Reconcile reconciles the target and current record sets with the databroker.
|
||||
func Reconcile(
|
||||
ctx context.Context,
|
||||
client databroker.DataBrokerServiceClient,
|
||||
target, current RecordSetBundle[DatabrokerRecord],
|
||||
) error {
|
||||
updates := NewDatabrokerChangeSet()
|
||||
|
||||
for _, rec := range current.GetRemoved(target).Flatten() {
|
||||
updates.Remove(rec.GetType(), rec.GetID())
|
||||
}
|
||||
for _, rec := range current.GetModified(target).Flatten() {
|
||||
updates.Upsert(rec.V)
|
||||
}
|
||||
for _, rec := range current.GetAdded(target).Flatten() {
|
||||
updates.Upsert(rec.V)
|
||||
}
|
||||
|
||||
err := ApplyChanges(ctx, client, updates)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply databroker changes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
196
internal/zero/reconciler/reconciler_test.go
Normal file
196
internal/zero/reconciler/reconciler_test.go
Normal file
|
@ -0,0 +1,196 @@
|
|||
package reconciler_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
|
||||
databroker_int "github.com/pomerium/pomerium/internal/databroker"
|
||||
"github.com/pomerium/pomerium/internal/zero/reconciler"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
func newDatabroker(t *testing.T) (context.Context, databroker.DataBrokerServiceClient) {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
gs := grpc.NewServer()
|
||||
srv := databroker_int.New()
|
||||
|
||||
databroker.RegisterDataBrokerServiceServer(gs, srv)
|
||||
|
||||
lis := bufconn.Listen(1)
|
||||
t.Cleanup(func() {
|
||||
lis.Close()
|
||||
gs.Stop()
|
||||
})
|
||||
|
||||
go func() { _ = gs.Serve(lis) }()
|
||||
|
||||
conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(func(context.Context, string) (conn net.Conn, e error) {
|
||||
return lis.Dial()
|
||||
}), grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
return ctx, databroker.NewDataBrokerServiceClient(conn)
|
||||
}
|
||||
|
||||
func newRecordBundle(records []testRecord) reconciler.RecordSetBundle[reconciler.DatabrokerRecord] {
|
||||
bundle := make(reconciler.RecordSetBundle[reconciler.DatabrokerRecord])
|
||||
for _, r := range records {
|
||||
bundle.Add(newRecord(r))
|
||||
}
|
||||
return bundle
|
||||
}
|
||||
|
||||
func newRecord(r testRecord) reconciler.DatabrokerRecord {
|
||||
return reconciler.DatabrokerRecord{
|
||||
V: &databroker.Record{
|
||||
Type: r.Type,
|
||||
Id: r.ID,
|
||||
Data: protoutil.NewAnyString(r.Val),
|
||||
}}
|
||||
}
|
||||
|
||||
func assertBundle(t *testing.T, want []testRecord, got reconciler.RecordSetBundle[reconciler.DatabrokerRecord]) {
|
||||
t.Helper()
|
||||
|
||||
for _, wantRecord := range want {
|
||||
gotRecord, ok := got.Get(wantRecord.Type, wantRecord.ID)
|
||||
if assert.True(t, ok, "record %s/%s not found", wantRecord.Type, wantRecord.ID) {
|
||||
assertRecord(t, wantRecord, gotRecord)
|
||||
}
|
||||
}
|
||||
assert.Len(t, got.Flatten(), len(want))
|
||||
}
|
||||
|
||||
func assertRecord(t *testing.T, want testRecord, got reconciler.DatabrokerRecord) {
|
||||
t.Helper()
|
||||
|
||||
var val wrapperspb.StringValue
|
||||
err := got.V.Data.UnmarshalTo(&val)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, want.Type, got.V.Type)
|
||||
assert.Equal(t, want.ID, got.V.Id)
|
||||
assert.Equal(t, want.Val, val.Value)
|
||||
}
|
||||
|
||||
func TestHelpers(t *testing.T) {
|
||||
want := []testRecord{
|
||||
{"type1", "id1", "value1"},
|
||||
{"type1", "id2", "value2"},
|
||||
}
|
||||
|
||||
bundle := newRecordBundle(want)
|
||||
assertBundle(t, want, bundle)
|
||||
}
|
||||
|
||||
func wantRemoved(want, current []string) []string {
|
||||
wantM := make(map[string]struct{}, len(want))
|
||||
for _, w := range want {
|
||||
wantM[w] = struct{}{}
|
||||
}
|
||||
var toRemove []string
|
||||
for _, c := range current {
|
||||
if _, ok := wantM[c]; !ok {
|
||||
toRemove = append(toRemove, c)
|
||||
}
|
||||
}
|
||||
return toRemove
|
||||
}
|
||||
|
||||
func reconcile(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
client databroker.DataBrokerServiceClient,
|
||||
want []testRecord,
|
||||
current reconciler.RecordSetBundle[reconciler.DatabrokerRecord],
|
||||
) reconciler.RecordSetBundle[reconciler.DatabrokerRecord] {
|
||||
t.Helper()
|
||||
|
||||
wantBundle := newRecordBundle(want)
|
||||
err := reconciler.Reconcile(ctx, client, wantBundle, current)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := reconciler.GetDatabrokerRecords(ctx, client, wantBundle.RecordTypes())
|
||||
require.NoError(t, err)
|
||||
assertBundle(t, want, got)
|
||||
|
||||
res, err := reconciler.GetDatabrokerRecords(ctx, client, wantRemoved(wantBundle.RecordTypes(), current.RecordTypes()))
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, res.Flatten())
|
||||
|
||||
return got
|
||||
}
|
||||
|
||||
func TestReconcile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, client := newDatabroker(t)
|
||||
|
||||
err := reconciler.Reconcile(ctx, client, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
var current reconciler.RecordSetBundle[reconciler.DatabrokerRecord]
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
want []testRecord
|
||||
}{
|
||||
{"empty", nil},
|
||||
{"initial", []testRecord{
|
||||
{"type1", "id1", "value1"},
|
||||
{"type1", "id2", "value2"},
|
||||
}},
|
||||
{"add one", []testRecord{
|
||||
{"type1", "id1", "value1"},
|
||||
{"type1", "id2", "value2"},
|
||||
{"type1", "id3", "value3"},
|
||||
}},
|
||||
{"update one", []testRecord{
|
||||
{"type1", "id1", "value1"},
|
||||
{"type1", "id2", "value2-updated"},
|
||||
{"type1", "id3", "value3"},
|
||||
}},
|
||||
{"delete one", []testRecord{
|
||||
{"type1", "id1", "value1"},
|
||||
{"type1", "id3", "value3"},
|
||||
}},
|
||||
{"delete all", nil},
|
||||
{"multiple types", []testRecord{
|
||||
{"type1", "id1", "value1"},
|
||||
{"type1", "id2", "value2"},
|
||||
{"type2", "id1", "value1"},
|
||||
{"type2", "id2", "value2"},
|
||||
}},
|
||||
{"multiple types update", []testRecord{
|
||||
{"type1", "id1", "value1"},
|
||||
{"type1", "id2", "value2-updated"},
|
||||
{"type2", "id1", "value1"},
|
||||
{"type2", "id2", "value2-updated"},
|
||||
}},
|
||||
{"multiple types delete", []testRecord{
|
||||
{"type1", "id1", "value1"},
|
||||
{"type2", "id1", "value1"},
|
||||
}},
|
||||
{"multiple types delete one type, add one value", []testRecord{
|
||||
{"type1", "id1", "value1"},
|
||||
{"type1", "id4", "value4"},
|
||||
}},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
current = reconcile(ctx, t, client, tc.want, current)
|
||||
})
|
||||
}
|
||||
}
|
132
internal/zero/reconciler/records.go
Normal file
132
internal/zero/reconciler/records.go
Normal file
|
@ -0,0 +1,132 @@
|
|||
package reconciler
|
||||
|
||||
// RecordSetBundle is an index of databroker records by type
|
||||
type RecordSetBundle[T Record[T]] map[string]RecordSet[T]
|
||||
|
||||
// RecordSet is an index of databroker records by their id.
|
||||
type RecordSet[T Record[T]] map[string]T
|
||||
|
||||
// Record is a record
|
||||
type Record[T any] interface {
|
||||
GetID() string
|
||||
GetType() string
|
||||
Equal(other T) bool
|
||||
}
|
||||
|
||||
// RecordTypes returns the types of records in the bundle.
|
||||
func (rsb RecordSetBundle[T]) RecordTypes() []string {
|
||||
types := make([]string, 0, len(rsb))
|
||||
for typ := range rsb {
|
||||
types = append(types, typ)
|
||||
}
|
||||
return types
|
||||
}
|
||||
|
||||
// Add adds a record to the bundle.
|
||||
func (rsb RecordSetBundle[T]) Add(record T) {
|
||||
rs, ok := rsb[record.GetType()]
|
||||
if !ok {
|
||||
rs = make(RecordSet[T])
|
||||
rsb[record.GetType()] = rs
|
||||
}
|
||||
rs[record.GetID()] = record
|
||||
}
|
||||
|
||||
// GetAdded returns the records that are in other but not in rsb.
|
||||
func (rsb RecordSetBundle[T]) GetAdded(other RecordSetBundle[T]) RecordSetBundle[T] {
|
||||
added := make(RecordSetBundle[T])
|
||||
for otherType, otherRS := range other {
|
||||
rs, ok := rsb[otherType]
|
||||
if !ok {
|
||||
added[otherType] = otherRS
|
||||
continue
|
||||
}
|
||||
rss := rs.GetAdded(other[otherType])
|
||||
if len(rss) > 0 {
|
||||
added[otherType] = rss
|
||||
}
|
||||
}
|
||||
return added
|
||||
}
|
||||
|
||||
// GetRemoved returns the records that are in rs but not in other.
|
||||
func (rsb RecordSetBundle[T]) GetRemoved(other RecordSetBundle[T]) RecordSetBundle[T] {
|
||||
return other.GetAdded(rsb)
|
||||
}
|
||||
|
||||
// GetModified returns the records that are in both rs and other but have different data.
|
||||
func (rsb RecordSetBundle[T]) GetModified(other RecordSetBundle[T]) RecordSetBundle[T] {
|
||||
modified := make(RecordSetBundle[T])
|
||||
for otherType, otherRS := range other {
|
||||
rs, ok := rsb[otherType]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
m := rs.GetModified(otherRS)
|
||||
if len(m) > 0 {
|
||||
modified[otherType] = m
|
||||
}
|
||||
}
|
||||
return modified
|
||||
}
|
||||
|
||||
// GetAdded returns the records that are in other but not in rs.
|
||||
func (rs RecordSet[T]) GetAdded(other RecordSet[T]) RecordSet[T] {
|
||||
added := make(RecordSet[T])
|
||||
for id, record := range other {
|
||||
if _, ok := rs[id]; !ok {
|
||||
added[id] = record
|
||||
}
|
||||
}
|
||||
return added
|
||||
}
|
||||
|
||||
// GetRemoved returns the records that are in rs but not in other.
|
||||
func (rs RecordSet[T]) GetRemoved(other RecordSet[T]) RecordSet[T] {
|
||||
return other.GetAdded(rs)
|
||||
}
|
||||
|
||||
// GetModified returns the records that are in both rs and other but have different data.
|
||||
// by comparing the protobuf bytes of the payload.
|
||||
func (rs RecordSet[T]) GetModified(other RecordSet[T]) RecordSet[T] {
|
||||
modified := make(RecordSet[T])
|
||||
for id, record := range other {
|
||||
otherRecord, ok := rs[id]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if !record.Equal(otherRecord) {
|
||||
modified[id] = record
|
||||
}
|
||||
}
|
||||
return modified
|
||||
}
|
||||
|
||||
// Flatten returns all records in the set.
|
||||
func (rs RecordSet[T]) Flatten() []T {
|
||||
records := make([]T, 0, len(rs))
|
||||
for _, record := range rs {
|
||||
records = append(records, record)
|
||||
}
|
||||
return records
|
||||
}
|
||||
|
||||
// Flatten returns all records in the bundle.
|
||||
func (rsb RecordSetBundle[T]) Flatten() []T {
|
||||
records := make([]T, 0)
|
||||
for _, rs := range rsb {
|
||||
records = append(records, rs.Flatten()...)
|
||||
}
|
||||
return records
|
||||
}
|
||||
|
||||
// Get returns a record by type and id.
|
||||
func (rsb RecordSetBundle[T]) Get(typeName, id string) (record T, ok bool) {
|
||||
rs, ok := rsb[typeName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
record, ok = rs[id]
|
||||
return
|
||||
}
|
77
internal/zero/reconciler/records_test.go
Normal file
77
internal/zero/reconciler/records_test.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
package reconciler_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/zero/reconciler"
|
||||
)
|
||||
|
||||
type testRecord struct {
|
||||
Type string
|
||||
ID string
|
||||
Val string
|
||||
}
|
||||
|
||||
func (r testRecord) GetID() string {
|
||||
return r.ID
|
||||
}
|
||||
|
||||
func (r testRecord) GetType() string {
|
||||
return r.Type
|
||||
}
|
||||
|
||||
func (r testRecord) Equal(other testRecord) bool {
|
||||
return r.ID == other.ID && r.Type == other.Type && r.Val == other.Val
|
||||
}
|
||||
|
||||
func TestRecords(t *testing.T) {
|
||||
initial := make(reconciler.RecordSetBundle[testRecord])
|
||||
initial.Add(testRecord{ID: "1", Type: "a", Val: "a-1"})
|
||||
initial.Add(testRecord{ID: "2", Type: "a", Val: "a-2"})
|
||||
initial.Add(testRecord{ID: "1", Type: "b", Val: "b-1"})
|
||||
|
||||
// test record types
|
||||
assert.ElementsMatch(t, []string{"a", "b"}, initial.RecordTypes())
|
||||
|
||||
// test added, deleted and modified
|
||||
updated := make(reconciler.RecordSetBundle[testRecord])
|
||||
updated.Add(testRecord{ID: "1", Type: "a", Val: "a-1-1"})
|
||||
updated.Add(testRecord{ID: "3", Type: "a", Val: "a-3"})
|
||||
updated.Add(testRecord{ID: "1", Type: "b", Val: "b-1"})
|
||||
updated.Add(testRecord{ID: "2", Type: "b", Val: "b-2"})
|
||||
updated.Add(testRecord{ID: "1", Type: "c", Val: "c-1"})
|
||||
|
||||
assert.ElementsMatch(t, []string{"a", "b", "c"}, updated.RecordTypes())
|
||||
|
||||
added := initial.GetAdded(updated)
|
||||
assert.Equal(t,
|
||||
reconciler.RecordSetBundle[testRecord]{
|
||||
"a": reconciler.RecordSet[testRecord]{
|
||||
"3": {ID: "3", Type: "a", Val: "a-3"},
|
||||
},
|
||||
"b": reconciler.RecordSet[testRecord]{
|
||||
"2": {ID: "2", Type: "b", Val: "b-2"},
|
||||
},
|
||||
"c": reconciler.RecordSet[testRecord]{
|
||||
"1": {ID: "1", Type: "c", Val: "c-1"},
|
||||
},
|
||||
}, added)
|
||||
|
||||
removed := initial.GetRemoved(updated)
|
||||
assert.Equal(t,
|
||||
reconciler.RecordSetBundle[testRecord]{
|
||||
"a": reconciler.RecordSet[testRecord]{
|
||||
"2": {ID: "2", Type: "a", Val: "a-2"},
|
||||
},
|
||||
}, removed)
|
||||
|
||||
modified := initial.GetModified(updated)
|
||||
assert.Equal(t,
|
||||
reconciler.RecordSetBundle[testRecord]{
|
||||
"a": reconciler.RecordSet[testRecord]{
|
||||
"1": {ID: "1", Type: "a", Val: "a-1-1"},
|
||||
},
|
||||
}, modified)
|
||||
}
|
89
internal/zero/reconciler/service.go
Normal file
89
internal/zero/reconciler/service.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package reconciler
|
||||
|
||||
/*
|
||||
* This is a main control loop for the reconciler service.
|
||||
*
|
||||
*/
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
connect_mux "github.com/pomerium/zero-sdk/connect-mux"
|
||||
)
|
||||
|
||||
type service struct {
|
||||
config *reconcilerConfig
|
||||
|
||||
databrokerRateLimit *rate.Limiter
|
||||
|
||||
bundles BundleQueue
|
||||
|
||||
fullSyncRequest chan struct{}
|
||||
bundleSyncRequest chan struct{}
|
||||
periodicUpdateInterval atomicutil.Value[time.Duration]
|
||||
}
|
||||
|
||||
// Run creates a new bundle updater client
|
||||
// that runs until the context is canceled or a fatal error occurs.
|
||||
func Run(ctx context.Context, opts ...Option) error {
|
||||
config := newConfig(opts...)
|
||||
|
||||
c := &service{
|
||||
config: config,
|
||||
databrokerRateLimit: rate.NewLimiter(rate.Limit(config.databrokerRPS), 1),
|
||||
fullSyncRequest: make(chan struct{}, 1),
|
||||
}
|
||||
c.periodicUpdateInterval.Store(config.checkForUpdateIntervalWhenDisconnected)
|
||||
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error { return c.watchUpdates(ctx) })
|
||||
eg.Go(func() error { return c.SyncLoop(ctx) })
|
||||
|
||||
return eg.Wait()
|
||||
}
|
||||
|
||||
// run is a main control loop.
|
||||
// it is very simple and sequential download and reconcile.
|
||||
// it may be later optimized by splitting between download and reconciliation process,
|
||||
// as we would get more resource bundles beyond the config.
|
||||
func (c *service) watchUpdates(ctx context.Context) error {
|
||||
return c.config.api.Watch(ctx,
|
||||
connect_mux.WithOnConnected(func(ctx context.Context) {
|
||||
c.triggerFullUpdate(true)
|
||||
}),
|
||||
connect_mux.WithOnDisconnected(func(_ context.Context) {
|
||||
c.triggerFullUpdate(false)
|
||||
}),
|
||||
connect_mux.WithOnBundleUpdated(func(_ context.Context, key string) {
|
||||
c.triggerBundleUpdate(key)
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
func (c *service) triggerBundleUpdate(id string) {
|
||||
c.periodicUpdateInterval.Store(c.config.checkForUpdateIntervalWhenConnected)
|
||||
c.bundles.MarkForSync(id)
|
||||
|
||||
select {
|
||||
case c.fullSyncRequest <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *service) triggerFullUpdate(connected bool) {
|
||||
timeout := c.config.checkForUpdateIntervalWhenDisconnected
|
||||
if connected {
|
||||
timeout = c.config.checkForUpdateIntervalWhenConnected
|
||||
}
|
||||
c.periodicUpdateInterval.Store(timeout)
|
||||
|
||||
select {
|
||||
case c.fullSyncRequest <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
231
internal/zero/reconciler/sync.go
Normal file
231
internal/zero/reconciler/sync.go
Normal file
|
@ -0,0 +1,231 @@
|
|||
package reconciler
|
||||
|
||||
/*
|
||||
* Sync syncs the bundles between their cloud source and the databroker.
|
||||
*
|
||||
* FullSync performs a full sync of the bundles by calling the API,
|
||||
* and walking the list of bundles, and calling SyncBundle on each.
|
||||
* It also removes any records in the databroker that are not in the list of bundles.
|
||||
*
|
||||
* WatchAndSync watches the API for changes, and calls SyncBundle on each change.
|
||||
*
|
||||
*/
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/retry"
|
||||
)
|
||||
|
||||
// Sync synchronizes the bundles between their cloud source and the databroker.
|
||||
func (c *service) SyncLoop(ctx context.Context) error {
|
||||
ticker := time.NewTicker(c.periodicUpdateInterval.Load())
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
dur := c.periodicUpdateInterval.Load()
|
||||
ticker.Reset(dur)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-c.bundleSyncRequest:
|
||||
log.Ctx(ctx).Info().Msg("bundle sync triggered")
|
||||
err := c.syncBundles(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reconciler: sync bundles: %w", err)
|
||||
}
|
||||
case <-c.fullSyncRequest:
|
||||
log.Ctx(ctx).Info().Msg("full sync triggered")
|
||||
err := c.syncAll(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reconciler: sync all: %w", err)
|
||||
}
|
||||
case <-ticker.C:
|
||||
log.Ctx(ctx).Info().Msg("periodic sync triggered")
|
||||
err := c.syncAll(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reconciler: sync all: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *service) syncAll(ctx context.Context) error {
|
||||
err := c.syncBundleList(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sync bundle list: %w", err)
|
||||
}
|
||||
|
||||
err = c.syncBundles(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sync bundles: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// trySyncAllBundles tries to sync all bundles in the queue.
|
||||
func (c *service) syncBundleList(ctx context.Context) error {
|
||||
// refresh bundle list,
|
||||
// ignoring other signals while we're retrying
|
||||
return retry.Retry(ctx,
|
||||
"refresh bundle list", c.refreshBundleList,
|
||||
retry.WithWatch("refresh bundle list", c.fullSyncRequest, nil),
|
||||
retry.WithWatch("bundle update", c.bundleSyncRequest, nil),
|
||||
)
|
||||
}
|
||||
|
||||
// syncBundles retries until there are no more bundles to sync.
|
||||
// updates bundle list if the full bundle update request arrives.
|
||||
func (c *service) syncBundles(ctx context.Context) error {
|
||||
return retry.Retry(ctx,
|
||||
"sync bundles", c.trySyncBundles,
|
||||
retry.WithWatch("refresh bundle list", c.fullSyncRequest, c.refreshBundleList),
|
||||
retry.WithWatch("bundle update", c.bundleSyncRequest, nil),
|
||||
)
|
||||
}
|
||||
|
||||
// trySyncAllBundles tries to sync all bundles in the queue
|
||||
// it returns nil if all bundles were synced successfully
|
||||
func (c *service) trySyncBundles(ctx context.Context) error {
|
||||
for {
|
||||
id, ok := c.bundles.GetNextBundleToSync()
|
||||
if !ok { // no more bundles to sync
|
||||
return nil
|
||||
}
|
||||
|
||||
err := c.syncBundle(ctx, id)
|
||||
if err != nil {
|
||||
c.bundles.MarkForSyncLater(id)
|
||||
return fmt.Errorf("sync bundle %s: %w", id, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// syncBundle syncs the bundle to the databroker.
|
||||
// Databroker holds last synced bundle state in form of a (etag, last-modified) tuple.
|
||||
// This is only persisted in the databroker after all records are successfully synced.
|
||||
// That allows us to ignore any changes based on the same bundle state, without need to re-check all records between bundle and databroker.
|
||||
func (c *service) syncBundle(ctx context.Context, key string) error {
|
||||
cached, err := c.GetBundleCacheEntry(ctx, key)
|
||||
if err != nil && !errors.Is(err, ErrBundleCacheEntryNotFound) {
|
||||
return fmt.Errorf("get bundle cache entry: %w", err)
|
||||
}
|
||||
|
||||
// download is much faster compared to databroker sync,
|
||||
// so we don't use pipe but rather download to a temp file and then sync it to databroker
|
||||
fd, err := c.GetTmpFile(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get tmp file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := fd.Close(); err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("close tmp file")
|
||||
}
|
||||
}()
|
||||
|
||||
conditional := cached.GetDownloadConditional()
|
||||
log.Ctx(ctx).Debug().Str("id", key).Any("conditional", conditional).Msg("downloading bundle")
|
||||
|
||||
result, err := c.config.api.DownloadClusterResourceBundle(ctx, fd, key, conditional)
|
||||
if err != nil {
|
||||
return fmt.Errorf("download bundle: %w", err)
|
||||
}
|
||||
|
||||
if result.NotModified {
|
||||
log.Ctx(ctx).Debug().Str("bundle", key).Msg("bundle not changed")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Ctx(ctx).Debug().Str("bundle", key).
|
||||
Interface("cached-entry", cached).
|
||||
Interface("current-entry", result.DownloadConditional).
|
||||
Msg("bundle updated")
|
||||
|
||||
_, err = fd.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
return fmt.Errorf("seek to start: %w", err)
|
||||
}
|
||||
|
||||
bundleRecordTypes, err := c.syncBundleToDatabroker(ctx, fd, cached.GetRecordTypes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply bundle to databroker: %w", err)
|
||||
}
|
||||
current := BundleCacheEntry{
|
||||
DownloadConditional: *result.DownloadConditional,
|
||||
RecordTypes: bundleRecordTypes,
|
||||
}
|
||||
|
||||
log.Ctx(ctx).Info().
|
||||
Str("bundle", key).
|
||||
Strs("record_types", bundleRecordTypes).
|
||||
Str("etag", current.ETag).
|
||||
Str("last_modified", current.LastModified).
|
||||
Msg("bundle synced")
|
||||
|
||||
err = c.SetBundleCacheEntry(ctx, key, current)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set bundle cache entry: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func strUnion(a, b []string) []string {
|
||||
m := make(map[string]struct{}, len(a)+len(b))
|
||||
for _, s := range a {
|
||||
m[s] = struct{}{}
|
||||
}
|
||||
for _, s := range b {
|
||||
m[s] = struct{}{}
|
||||
}
|
||||
|
||||
out := make([]string, 0, len(m))
|
||||
for s := range m {
|
||||
out = append(out, s)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *service) syncBundleToDatabroker(ctx context.Context, src io.Reader, currentRecordTypes []string) ([]string, error) {
|
||||
bundleRecords, err := ReadBundleRecords(src)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read bundle records: %w", err)
|
||||
}
|
||||
|
||||
databrokerRecords, err := GetDatabrokerRecords(ctx,
|
||||
c.config.databrokerClient,
|
||||
strUnion(bundleRecords.RecordTypes(), currentRecordTypes),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get databroker records: %w", err)
|
||||
}
|
||||
|
||||
err = Reconcile(ctx, c.config.databrokerClient, bundleRecords, databrokerRecords)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reconcile databroker records: %w", err)
|
||||
}
|
||||
|
||||
return bundleRecords.RecordTypes(), nil
|
||||
}
|
||||
|
||||
func (c *service) refreshBundleList(ctx context.Context) error {
|
||||
resp, err := c.config.api.GetClusterResourceBundles(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get bundles: %w", err)
|
||||
}
|
||||
|
||||
ids := make([]string, 0, len(resp.Bundles))
|
||||
for _, v := range resp.Bundles {
|
||||
ids = append(ids, v.Id)
|
||||
}
|
||||
|
||||
c.bundles.Set(ids)
|
||||
return nil
|
||||
}
|
40
internal/zero/reconciler/tmpfile.go
Normal file
40
internal/zero/reconciler/tmpfile.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package reconciler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
// ReadWriteSeekCloser is a file that can be read, written, seeked, and closed.
|
||||
type ReadWriteSeekCloser interface {
|
||||
io.ReadWriteSeeker
|
||||
io.Closer
|
||||
}
|
||||
|
||||
// GetTmpFile returns a temporary file for the reconciler to use.
|
||||
// TODO: encrypt contents to ensure encryption at rest
|
||||
func (c *service) GetTmpFile(key string) (ReadWriteSeekCloser, error) {
|
||||
fd, err := os.CreateTemp(c.config.tmpDir, fmt.Sprintf("pomerium-bundle-%s", key))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
return &tmpFile{File: fd}, nil
|
||||
}
|
||||
|
||||
type tmpFile struct {
|
||||
*os.File
|
||||
}
|
||||
|
||||
func (f *tmpFile) Close() error {
|
||||
var errs *multierror.Error
|
||||
if err := f.File.Close(); err != nil {
|
||||
errs = multierror.Append(errs, err)
|
||||
}
|
||||
if err := os.Remove(f.File.Name()); err != nil {
|
||||
errs = multierror.Append(errs, err)
|
||||
}
|
||||
return errs.ErrorOrNil()
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue