diff --git a/internal/zero/controller/reconciler.go b/internal/zero/controller/reconciler.go index 9c6993e87..f8ecd4aac 100644 --- a/internal/zero/controller/reconciler.go +++ b/internal/zero/controller/reconciler.go @@ -4,16 +4,9 @@ import ( "context" "github.com/pomerium/pomerium/internal/zero/reconciler" - "github.com/pomerium/pomerium/pkg/grpc/databroker" ) func (c *controller) RunReconciler(ctx context.Context) error { - leaser := databroker.NewLeaser("zero-reconciler", c.cfg.reconcilerLeaseDuration, c) - return leaser.Run(ctx) -} - -// RunLeased implements the databroker.Leaser interface. -func (c *controller) RunLeased(ctx context.Context) error { return reconciler.Run(ctx, reconciler.WithAPI(c.api), reconciler.WithDataBrokerClient(c.GetDataBrokerServiceClient()), diff --git a/internal/zero/reconciler/download_cache.go b/internal/zero/reconciler/download_cache.go index 2a709965b..ec6b295c0 100644 --- a/internal/zero/reconciler/download_cache.go +++ b/internal/zero/reconciler/download_cache.go @@ -33,7 +33,8 @@ type BundleCacheEntry struct { } const ( - bundleCacheEntryRecordType = "pomerium.io/BundleCacheEntry" + // BundleCacheEntryRecordType is the databroker record type for BundleCacheEntry + BundleCacheEntryRecordType = "pomerium.io/BundleCacheEntry" ) var ( @@ -44,7 +45,7 @@ var ( // GetBundleCacheEntry gets a bundle cache entry from the databroker func (c *service) GetBundleCacheEntry(ctx context.Context, id string) (*BundleCacheEntry, error) { record, err := c.config.databrokerClient.Get(ctx, &databroker.GetRequest{ - Type: bundleCacheEntryRecordType, + Type: BundleCacheEntryRecordType, Id: id, }) if err != nil && status.Code(err) == codes.NotFound { @@ -77,7 +78,7 @@ func (c *service) SetBundleCacheEntry(ctx context.Context, id string, src Bundle _, err = c.config.databrokerClient.Put(ctx, &databroker.PutRequest{ Records: []*databroker.Record{ { - Type: bundleCacheEntryRecordType, + Type: BundleCacheEntryRecordType, Id: id, Data: val, }, diff --git 
a/internal/zero/reconciler/restart.go b/internal/zero/reconciler/restart.go
new file mode 100644
index 000000000..079b09302
--- /dev/null
+++ b/internal/zero/reconciler/restart.go
@@ -0,0 +1,101 @@
+package reconciler
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"sync"
+)
+
+// RunWithRestart executes execFn, restarting it each time restartFn returns.
+//
+// A restart is requested by canceling the context passed to execFn; execFn is
+// then invoked again with a fresh context. The error returned by restartFn is
+// purely informational (it is recorded as the cancellation cause of that
+// context) and may be nil.
+//
+// The loop stops when the context passed to RunWithRestart is canceled, or
+// when execFn returns while its own context is still live; in both cases the
+// result of the last execFn call is returned.
+func RunWithRestart(
+	ctx context.Context,
+	execFn func(context.Context) error,
+	restartFn func(context.Context) error,
+) error {
+	contexts := make(chan context.Context)
+
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	var err error
+	go func() {
+		defer wg.Done()
+		err = restartWithContext(contexts, execFn)
+		// execFn is done for good: stop handing out new contexts.
+		cancel()
+	}()
+	go func() {
+		defer wg.Done()
+		restartContexts(ctx, contexts, restartFn)
+	}()
+
+	// wg.Wait orders the write to err above before the read below.
+	wg.Wait()
+	return err
+}
+
+// restartContexts feeds execution contexts into the contexts channel until
+// base is canceled, and closes the channel on return.
+//
+// Each context handed out is canceled as soon as restartFn returns, which is
+// what requests a restart of whatever is running under it.
+func restartContexts(
+	base context.Context,
+	contexts chan<- context.Context,
+	restartFn func(context.Context) error,
+) {
+	defer close(contexts)
+	for base.Err() == nil {
+		ctx, cancel := context.WithCancelCause(base)
+		select {
+		case contexts <- ctx:
+			cancel(restartCause(restartFn(ctx)))
+		case <-base.Done():
+			cancel(fmt.Errorf("parent context canceled: %w", base.Err()))
+			return
+		}
+	}
+}
+
+// restartCause builds the cancellation cause recorded when restartFn returns.
+// restartFn may legitimately return nil, and wrapping a nil error with %w
+// would render as "%!w(<nil>)", so the nil case gets a plain message.
+func restartCause(err error) error {
+	if err == nil {
+		return errors.New("requesting restart")
+	}
+	return fmt.Errorf("requesting restart: %w", err)
+}
+
+// restartWithContext runs execFn once for every context received from the
+// contexts channel.
+//
+// It returns execFn's result as soon as execFn returns while its context is
+// still live. If the channel is closed instead, it returns the result of the
+// last execFn call (nil if execFn was never called).
+func restartWithContext(
+	contexts <-chan context.Context,
+	execFn func(context.Context) error,
+) error {
+	var err error
+	for ctx := range contexts {
+		err = execFn(ctx)
+		if ctx.Err() == nil {
+			return err
+		}
+	}
+	return err
+}
diff --git a/internal/zero/reconciler/restart_test.go b/internal/zero/reconciler/restart_test.go
new file mode 100644
index 000000000..2fbe0fdff
---
/dev/null +++ b/internal/zero/reconciler/restart_test.go @@ -0,0 +1,110 @@ +package reconciler_test + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/internal/zero/reconciler" +) + +func TestRestart(t *testing.T) { + t.Parallel() + + for i := 0; i < 20; i++ { + t.Run(fmt.Sprintf("quit on error %d", i), func(t *testing.T) { + t.Parallel() + + errExpected := errors.New("execFn error") + count := 0 + err := reconciler.RunWithRestart(context.Background(), + func(context.Context) error { + count++ + if count == 1 { + return errExpected + } + return errors.New("execFn should not be called more than once") + }, + func(ctx context.Context) error { + <-ctx.Done() + return ctx.Err() + }, + ) + assert.ErrorIs(t, err, errExpected) + }) + + t.Run(fmt.Sprintf("quit on no error %d", i), func(t *testing.T) { + t.Parallel() + + count := 0 + err := reconciler.RunWithRestart(context.Background(), + func(context.Context) error { + count++ + if count == 1 { + return nil + } + return errors.New("execFn should not be called more than once") + }, + func(ctx context.Context) error { + <-ctx.Done() + return ctx.Err() + }, + ) + assert.NoError(t, err) + }) + + t.Run(fmt.Sprintf("parent context canceled %d", i), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + ready := make(chan struct{}) + err := reconciler.RunWithRestart(ctx, + func(context.Context) error { + <-ready + cancel() + return ctx.Err() + }, + func(context.Context) error { + close(ready) + <-ctx.Done() + return ctx.Err() + }, + ) + assert.ErrorIs(t, err, context.Canceled) + }) + + t.Run(fmt.Sprintf("triggers restart %d", i), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + errExpected := errors.New("execFn error") + count := 0 + ready := make(chan struct{}) + err := reconciler.RunWithRestart(ctx, + func(ctx 
context.Context) error { + count++ + if count == 1 { // wait for us to be restarted + close(ready) + <-ctx.Done() + return ctx.Err() + } else if count == 2 { // just quit + return errExpected + } + return errors.New("execFn should not be called more than twice") + }, + func(ctx context.Context) error { + <-ready + return errors.New("restart required") + }, + ) + assert.ErrorIs(t, err, errExpected) + }) + } +} diff --git a/internal/zero/reconciler/service.go b/internal/zero/reconciler/service.go index 1fb951d59..7ab537ae5 100644 --- a/internal/zero/reconciler/service.go +++ b/internal/zero/reconciler/service.go @@ -7,12 +7,14 @@ package reconciler import ( "context" + "fmt" "time" "golang.org/x/sync/errgroup" "golang.org/x/time/rate" "github.com/pomerium/pomerium/internal/atomicutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" connect_mux "github.com/pomerium/zero-sdk/connect-mux" ) @@ -40,6 +42,11 @@ func Run(ctx context.Context, opts ...Option) error { } c.periodicUpdateInterval.Store(config.checkForUpdateIntervalWhenDisconnected) + return c.runMainLoop(ctx) +} + +// RunLeased implements the databroker.LeaseHandler interface +func (c *service) RunLeased(ctx context.Context) error { eg, ctx := errgroup.WithContext(ctx) eg.Go(func() error { return c.watchUpdates(ctx) }) eg.Go(func() error { return c.SyncLoop(ctx) }) @@ -47,6 +54,44 @@ func Run(ctx context.Context, opts ...Option) error { return eg.Wait() } +// GetDataBrokerServiceClient implements the databroker.LeaseHandler interface. 
+func (c *service) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { + return c.config.databrokerClient +} + +func (c *service) runMainLoop(ctx context.Context) error { + leaser := databroker.NewLeaser("zero-reconciler", time.Second*30, c) + return RunWithRestart(ctx, func(ctx context.Context) error { + return leaser.Run(ctx) + }, c.databrokerChangeMonitor) +} + +// databrokerChangeMonitor runs infinite sync loop to see if there is any change in databroker +func (c *service) databrokerChangeMonitor(ctx context.Context) error { + _, recordVersion, serverVersion, err := databroker.InitialSync(ctx, c.GetDataBrokerServiceClient(), &databroker.SyncLatestRequest{ + Type: BundleCacheEntryRecordType, + }) + if err != nil { + return fmt.Errorf("error during initial sync: %w", err) + } + + stream, err := c.GetDataBrokerServiceClient().Sync(ctx, &databroker.SyncRequest{ + Type: BundleCacheEntryRecordType, + ServerVersion: serverVersion, + RecordVersion: recordVersion, + }) + if err != nil { + return fmt.Errorf("error calling sync: %w", err) + } + + for { + _, err := stream.Recv() + if err != nil { + return fmt.Errorf("error receiving record: %w", err) + } + } +} + // run is a main control loop. // it is very simple and sequential download and reconcile. // it may be later optimized by splitting between download and reconciliation process,