mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-01 18:33:19 +02:00
zero: restart config reconciliation when databroker storage is changed (#4623)
This commit is contained in:
parent
60ab9dafbe
commit
0e1061d813
5 changed files with 231 additions and 10 deletions
|
@ -4,16 +4,9 @@ import (
|
|||
"context"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/zero/reconciler"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
func (c *controller) RunReconciler(ctx context.Context) error {
|
||||
leaser := databroker.NewLeaser("zero-reconciler", c.cfg.reconcilerLeaseDuration, c)
|
||||
return leaser.Run(ctx)
|
||||
}
|
||||
|
||||
// RunLeased implements the databroker.Leaser interface.
|
||||
func (c *controller) RunLeased(ctx context.Context) error {
|
||||
return reconciler.Run(ctx,
|
||||
reconciler.WithAPI(c.api),
|
||||
reconciler.WithDataBrokerClient(c.GetDataBrokerServiceClient()),
|
||||
|
|
|
@ -33,7 +33,8 @@ type BundleCacheEntry struct {
|
|||
}
|
||||
|
||||
const (
|
||||
bundleCacheEntryRecordType = "pomerium.io/BundleCacheEntry"
|
||||
// BundleCacheEntryRecordType is the databroker record type for BundleCacheEntry
|
||||
BundleCacheEntryRecordType = "pomerium.io/BundleCacheEntry"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -44,7 +45,7 @@ var (
|
|||
// GetBundleCacheEntry gets a bundle cache entry from the databroker
|
||||
func (c *service) GetBundleCacheEntry(ctx context.Context, id string) (*BundleCacheEntry, error) {
|
||||
record, err := c.config.databrokerClient.Get(ctx, &databroker.GetRequest{
|
||||
Type: bundleCacheEntryRecordType,
|
||||
Type: BundleCacheEntryRecordType,
|
||||
Id: id,
|
||||
})
|
||||
if err != nil && status.Code(err) == codes.NotFound {
|
||||
|
@ -77,7 +78,7 @@ func (c *service) SetBundleCacheEntry(ctx context.Context, id string, src Bundle
|
|||
_, err = c.config.databrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
Records: []*databroker.Record{
|
||||
{
|
||||
Type: bundleCacheEntryRecordType,
|
||||
Type: BundleCacheEntryRecordType,
|
||||
Id: id,
|
||||
Data: val,
|
||||
},
|
||||
|
|
72
internal/zero/reconciler/restart.go
Normal file
72
internal/zero/reconciler/restart.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package reconciler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// RunWithRestart executes execFn.
|
||||
// The execution would be restarted, by means of canceling the context provided to execFn, each time restartFn quits.
|
||||
// the error returned by restartFn is purely informational and does not affect the execution; may be nil.
|
||||
// the loop is stopped when the context provided to RunWithRestart is canceled or a genuine error is returned by execFn, not caused by the context.
|
||||
func RunWithRestart(
|
||||
ctx context.Context,
|
||||
execFn func(context.Context) error,
|
||||
restartFn func(context.Context) error,
|
||||
) error {
|
||||
contexts := make(chan context.Context)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var err error
|
||||
go func() {
|
||||
err = restartWithContext(contexts, execFn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
restartContexts(ctx, contexts, restartFn)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
// restartContexts produces one context per execution attempt on contexts.
//
// For every context that is accepted by the receiver, restartFn is called with
// that same context; when restartFn returns, the context is canceled with a
// "requesting restart" cause, which wraps restartFn's error when it is non-nil.
// If base is canceled while waiting for the receiver, the pending context is
// canceled with a "parent context canceled" cause and the function returns.
//
// contexts is closed on return, so the receiver can range over it.
func restartContexts(
	base context.Context,
	contexts chan<- context.Context,
	restartFn func(context.Context) error,
) {
	defer close(contexts)
	for base.Err() == nil {
		ctx, cancel := context.WithCancelCause(base)
		select {
		case contexts <- ctx:
			// restartFn may legitimately return nil; %w with a nil operand
			// would render as "%!w(<nil>)", so only wrap a real error.
			cause := fmt.Errorf("requesting restart")
			if err := restartFn(ctx); err != nil {
				cause = fmt.Errorf("requesting restart: %w", err)
			}
			cancel(cause)
		case <-base.Done():
			cancel(fmt.Errorf("parent context canceled: %w", base.Err()))
			return
		}
	}
}
|
||||
|
||||
// restartWithContext calls execFn once for every context received on contexts.
//
// If execFn returns while the context it was given is still live, that return
// value (nil included) is final and is returned immediately. If the context
// has been canceled by then, the result is remembered and the next context is
// awaited. Once contexts is closed, the most recent execFn result is returned,
// or nil if execFn never ran.
func restartWithContext(
	contexts <-chan context.Context,
	execFn func(context.Context) error,
) error {
	var lastErr error
	for {
		runCtx, open := <-contexts
		if !open {
			return lastErr
		}
		lastErr = execFn(runCtx)
		if runCtx.Err() != nil {
			// The run was interrupted; wait for the next context.
			continue
		}
		return lastErr
	}
}
|
110
internal/zero/reconciler/restart_test.go
Normal file
110
internal/zero/reconciler/restart_test.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
package reconciler_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/zero/reconciler"
|
||||
)
|
||||
|
||||
func TestRestart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
t.Run(fmt.Sprintf("quit on error %d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errExpected := errors.New("execFn error")
|
||||
count := 0
|
||||
err := reconciler.RunWithRestart(context.Background(),
|
||||
func(context.Context) error {
|
||||
count++
|
||||
if count == 1 {
|
||||
return errExpected
|
||||
}
|
||||
return errors.New("execFn should not be called more than once")
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
},
|
||||
)
|
||||
assert.ErrorIs(t, err, errExpected)
|
||||
})
|
||||
|
||||
t.Run(fmt.Sprintf("quit on no error %d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
count := 0
|
||||
err := reconciler.RunWithRestart(context.Background(),
|
||||
func(context.Context) error {
|
||||
count++
|
||||
if count == 1 {
|
||||
return nil
|
||||
}
|
||||
return errors.New("execFn should not be called more than once")
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run(fmt.Sprintf("parent context canceled %d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
ready := make(chan struct{})
|
||||
err := reconciler.RunWithRestart(ctx,
|
||||
func(context.Context) error {
|
||||
<-ready
|
||||
cancel()
|
||||
return ctx.Err()
|
||||
},
|
||||
func(context.Context) error {
|
||||
close(ready)
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
},
|
||||
)
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
})
|
||||
|
||||
t.Run(fmt.Sprintf("triggers restart %d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
errExpected := errors.New("execFn error")
|
||||
count := 0
|
||||
ready := make(chan struct{})
|
||||
err := reconciler.RunWithRestart(ctx,
|
||||
func(ctx context.Context) error {
|
||||
count++
|
||||
if count == 1 { // wait for us to be restarted
|
||||
close(ready)
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
} else if count == 2 { // just quit
|
||||
return errExpected
|
||||
}
|
||||
return errors.New("execFn should not be called more than twice")
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
<-ready
|
||||
return errors.New("restart required")
|
||||
},
|
||||
)
|
||||
assert.ErrorIs(t, err, errExpected)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -7,12 +7,14 @@ package reconciler
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
connect_mux "github.com/pomerium/zero-sdk/connect-mux"
|
||||
)
|
||||
|
||||
|
@ -40,6 +42,11 @@ func Run(ctx context.Context, opts ...Option) error {
|
|||
}
|
||||
c.periodicUpdateInterval.Store(config.checkForUpdateIntervalWhenDisconnected)
|
||||
|
||||
return c.runMainLoop(ctx)
|
||||
}
|
||||
|
||||
// RunLeased implements the databroker.LeaseHandler interface
|
||||
func (c *service) RunLeased(ctx context.Context) error {
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error { return c.watchUpdates(ctx) })
|
||||
eg.Go(func() error { return c.SyncLoop(ctx) })
|
||||
|
@ -47,6 +54,44 @@ func Run(ctx context.Context, opts ...Option) error {
|
|||
return eg.Wait()
|
||||
}
|
||||
|
||||
// GetDataBrokerServiceClient implements the databroker.LeaseHandler interface.
|
||||
func (c *service) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
||||
return c.config.databrokerClient
|
||||
}
|
||||
|
||||
func (c *service) runMainLoop(ctx context.Context) error {
|
||||
leaser := databroker.NewLeaser("zero-reconciler", time.Second*30, c)
|
||||
return RunWithRestart(ctx, func(ctx context.Context) error {
|
||||
return leaser.Run(ctx)
|
||||
}, c.databrokerChangeMonitor)
|
||||
}
|
||||
|
||||
// databrokerChangeMonitor runs infinite sync loop to see if there is any change in databroker
|
||||
func (c *service) databrokerChangeMonitor(ctx context.Context) error {
|
||||
_, recordVersion, serverVersion, err := databroker.InitialSync(ctx, c.GetDataBrokerServiceClient(), &databroker.SyncLatestRequest{
|
||||
Type: BundleCacheEntryRecordType,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error during initial sync: %w", err)
|
||||
}
|
||||
|
||||
stream, err := c.GetDataBrokerServiceClient().Sync(ctx, &databroker.SyncRequest{
|
||||
Type: BundleCacheEntryRecordType,
|
||||
ServerVersion: serverVersion,
|
||||
RecordVersion: recordVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error calling sync: %w", err)
|
||||
}
|
||||
|
||||
for {
|
||||
_, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error receiving record: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// run is a main control loop.
|
||||
// it is very simple and sequential download and reconcile.
|
||||
// it may be later optimized by splitting between download and reconciliation process,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue