diff --git a/internal/zero/cmd/command.go b/internal/zero/cmd/command.go index d5c34f543..1ee8d9607 100644 --- a/internal/zero/cmd/command.go +++ b/internal/zero/cmd/command.go @@ -56,8 +56,10 @@ func IsManagedMode(configFile string) bool { } func withInterrupt(ctx context.Context) context.Context { - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancelCause(ctx) go func(ctx context.Context) { + defer cancel(context.Canceled) + ch := make(chan os.Signal, 2) defer signal.Stop(ch) @@ -66,10 +68,9 @@ func withInterrupt(ctx context.Context) context.Context { select { case sig := <-ch: - log.Ctx(ctx).Info().Str("signal", sig.String()).Msg("quitting...") + cancel(fmt.Errorf("received signal: %s", sig)) case <-ctx.Done(): } - cancel() }(ctx) return ctx } diff --git a/internal/zero/controller/controller.go b/internal/zero/controller/controller.go index 624f31f19..5fc4f1250 100644 --- a/internal/zero/controller/controller.go +++ b/internal/zero/controller/controller.go @@ -3,7 +3,6 @@ package controller import ( "context" - "errors" "fmt" "net" "net/url" @@ -28,7 +27,6 @@ import ( // Run runs Pomerium is managed mode using the provided token. func Run(ctx context.Context, opts ...Option) error { c := controller{cfg: newControllerConfig(opts...)} - eg, ctx := errgroup.WithContext(ctx) err := c.initAPI(ctx) if err != nil { @@ -58,11 +56,17 @@ func Run(ctx context.Context, opts ...Option) error { } c.bootstrapConfig = src + eg, ctx := errgroup.WithContext(ctx) eg.Go(func() error { return run(ctx, "connect", c.runConnect) }) eg.Go(func() error { return run(ctx, "connect-log", c.RunConnectLog) }) eg.Go(func() error { return run(ctx, "zero-bootstrap", c.runBootstrap) }) eg.Go(func() error { return run(ctx, "pomerium-core", c.runPomeriumCore) }) eg.Go(func() error { return run(ctx, "zero-control-loop", c.runZeroControlLoop) }) + eg.Go(func() error { + <-ctx.Done() + log.Ctx(ctx).Info().Msgf("shutting down: %v", context.Cause(ctx)) + return nil + }) return eg.Wait() } @@ -91,8 +95,9 @@ func (c *controller) initAPI(ctx context.Context) error { func run(ctx context.Context, name string, runFn func(context.Context) error) error { log.Ctx(ctx).Debug().Str("name", name).Msg("starting") + defer log.Ctx(ctx).Debug().Str("name", name).Msg("stopped") err := runFn(ctx) - if err != nil && !errors.Is(err, context.Canceled) { + if err != nil && ctx.Err() == nil { return fmt.Errorf("%s: %w", name, err) } return nil