From fe31799eb5d20074a9722075b0d003efa124725b Mon Sep 17 00:00:00 2001 From: Joe Kralicky Date: Fri, 25 Oct 2024 14:50:56 -0400 Subject: [PATCH] Fix many instances of contexts and loggers not being propagated (#5340) This also replaces instances where we manually write "return ctx.Err()" with "return context.Cause(ctx)" which is functionally identical, but will also correctly propagate cause errors if present. --- authenticate/authenticate.go | 6 +- authenticate/authenticate_test.go | 3 +- authenticate/state.go | 5 +- authorize/authorize.go | 9 +-- authorize/authorize_test.go | 6 +- authorize/check_response_test.go | 6 +- authorize/databroker_test.go | 2 +- authorize/evaluator/evaluator.go | 10 +-- authorize/evaluator/policy_evaluator.go | 3 +- authorize/state.go | 9 +-- cmd/pomerium/main.go | 2 +- config/config_source.go | 3 +- config/config_source_test.go | 2 +- config/envoyconfig/clusters.go | 4 +- config/envoyconfig/clusters_test.go | 6 +- config/envoyconfig/envoyconfig.go | 10 +-- databroker/cache.go | 10 +-- databroker/cache_test.go | 3 +- databroker/databroker.go | 6 +- databroker/databroker_test.go | 2 +- integration/main_test.go | 2 +- internal/authenticateflow/stateful.go | 4 +- internal/authenticateflow/stateful_test.go | 10 +-- internal/authenticateflow/stateless.go | 3 +- internal/autocert/manager.go | 12 ++-- internal/autocert/manager_test.go | 2 +- internal/autocert/storage_locker.go | 2 +- internal/controlplane/events.go | 2 +- internal/controlplane/http.go | 5 +- internal/controlplane/server.go | 49 +++++++++------ internal/controlplane/server_test.go | 2 +- internal/controlplane/xdsmgr/log.go | 9 +-- internal/controlplane/xdsmgr/xdsmgr.go | 18 +++--- internal/databroker/config_source.go | 7 +-- internal/databroker/config_source_test.go | 2 +- internal/databroker/registry.go | 16 +++-- internal/databroker/server.go | 62 ++++++++++--------- internal/databroker/server_test.go | 11 ++-- internal/enabler/enabler_test.go | 2 +- internal/fileutil/watcher.go | 3 + internal/log/debug.go | 10 +++ internal/log/log.go | 3 + internal/log/middleware.go | 17 +++++ internal/log/warnings.go | 9 +++ internal/retry/retry.go | 2 +- internal/tests/xdserr/health.go | 2 +- internal/zero/bootstrap/bootstrap.go | 2 +- internal/zero/bootstrap/source.go | 2 +- internal/zero/cmd/command_import.go | 2 +- internal/zero/connect-mux/messages.go | 2 +- .../controller/databroker_restart_test.go | 2 +- internal/zero/controller/leaser.go | 2 +- .../usagereporter/usagereporter_test.go | 2 +- internal/zero/healthcheck/syncer.go | 2 +- internal/zero/reconciler/sync.go | 2 +- internal/zero/telemetry/telemetry.go | 2 +- pkg/cmd/pomerium/pomerium.go | 18 +++--- pkg/envoy/envoy.go | 4 +- pkg/fanout/publish.go | 2 +- pkg/grpc/client.go | 2 +- pkg/grpc/databroker/leaser.go | 4 +- pkg/grpc/databroker/leaser_test.go | 2 +- pkg/grpc/databroker/reconciler.go | 2 +- pkg/grpc/databroker/sync_test.go | 2 +- pkg/grpc/databroker/syncer.go | 12 +++- pkg/grpc/databroker/syncer_test.go | 2 +- pkg/identity/legacymanager/manager.go | 6 +- pkg/identity/legacymanager/sync.go | 4 +- pkg/identity/manager/sync.go | 4 +- pkg/storage/postgres/backend.go | 4 +- pkg/storage/postgres/backend_test.go | 2 +- pkg/storage/postgres/registry_test.go | 8 +-- proxy/data_test.go | 4 +- proxy/handlers_test.go | 11 ++-- proxy/proxy.go | 20 +++--- proxy/proxy_test.go | 4 +- proxy/state.go | 8 +-- 77 files changed, 297 insertions(+), 221 deletions(-) create mode 100644 internal/log/debug.go diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 96fedf6c3..79c6fdf01 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -44,7 +44,7 @@ type Authenticate struct { } // New validates and creates a new authenticate service from a set of Options. -func New(cfg *config.Config, options ...Option) (*Authenticate, error) { +func New(ctx context.Context, cfg *config.Config, options ...Option) (*Authenticate, error) { authenticateConfig := getAuthenticateConfig(options...) a := &Authenticate{ cfg: authenticateConfig, @@ -54,7 +54,7 @@ func New(cfg *config.Config, options ...Option) (*Authenticate, error) { a.options.Store(cfg.Options) - state, err := newAuthenticateStateFromConfig(cfg, authenticateConfig) + state, err := newAuthenticateStateFromConfig(ctx, cfg, authenticateConfig) if err != nil { return nil, err } @@ -70,7 +70,7 @@ func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) { } a.options.Store(cfg.Options) - if state, err := newAuthenticateStateFromConfig(cfg, a.cfg); err != nil { + if state, err := newAuthenticateStateFromConfig(ctx, cfg, a.cfg); err != nil { log.Ctx(ctx).Error().Err(err).Msg("authenticate: failed to update state") } else { a.state.Store(state) diff --git a/authenticate/authenticate_test.go b/authenticate/authenticate_test.go index 5583b645a..ddcb696e6 100644 --- a/authenticate/authenticate_test.go +++ b/authenticate/authenticate_test.go @@ -1,6 +1,7 @@ package authenticate import ( + "context" "testing" "github.com/pomerium/pomerium/config" @@ -106,7 +107,7 @@ func TestNew(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - _, err := New(&config.Config{Options: tt.opts}) + _, err := New(context.Background(), &config.Config{Options: tt.opts}) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/authenticate/state.go b/authenticate/state.go index b7e79681d..3680f0f46 100644 --- a/authenticate/state.go +++ b/authenticate/state.go @@ -64,6 +64,7 @@ func newAuthenticateState() *authenticateState { } func newAuthenticateStateFromConfig( + ctx context.Context, cfg *config.Config, authenticateConfig *authenticateConfig, ) (*authenticateState, error) { err := ValidateOptions(cfg.Options) @@ -145,7 +146,7 @@ func newAuthenticateStateFromConfig( } if cfg.Options.UseStatelessAuthenticateFlow() { - state.flow, err = authenticateflow.NewStateless( + state.flow, err = authenticateflow.NewStateless(ctx, cfg, cookieStore, authenticateConfig.getIdentityProvider, @@ -153,7 +154,7 @@ func newAuthenticateStateFromConfig( authenticateConfig.authEventFn, ) } else { - state.flow, err = authenticateflow.NewStateful(cfg, cookieStore) + state.flow, err = authenticateflow.NewStateful(ctx, cfg, cookieStore) } if err != nil { return nil, err diff --git a/authorize/authorize.go b/authorize/authorize.go index d15fddfbd..d6c2b3ffa 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -40,7 +40,7 @@ type Authorize struct { } // New validates and creates a new Authorize service from a set of config options. -func New(cfg *config.Config) (*Authorize, error) { +func New(ctx context.Context, cfg *config.Config) (*Authorize, error) { a := &Authorize{ currentOptions: config.NewAtomicOptions(), store: store.New(), @@ -48,7 +48,7 @@ func New(cfg *config.Config) (*Authorize, error) { } a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod) - state, err := newAuthorizeStateFromConfig(cfg, a.store, nil) + state, err := newAuthorizeStateFromConfig(ctx, cfg, a.store, nil) if err != nil { return nil, err } @@ -89,12 +89,13 @@ func validateOptions(o *config.Options) error { // newPolicyEvaluator returns an policy evaluator. func newPolicyEvaluator( + ctx context.Context, opts *config.Options, store *store.Store, previous *evaluator.Evaluator, ) (*evaluator.Evaluator, error) { metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 { return int64(opts.NumPolicies()) }) - ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context { + ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context { return c.Str("service", "authorize") }) ctx, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator") @@ -150,7 +151,7 @@ func newPolicyEvaluator( func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) { currentState := a.state.Load() a.currentOptions.Store(cfg.Options) - if state, err := newAuthorizeStateFromConfig(cfg, a.store, currentState.evaluator); err != nil { + if state, err := newAuthorizeStateFromConfig(ctx, cfg, a.store, currentState.evaluator); err != nil { log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state") } else { a.state.Store(state) diff --git a/authorize/authorize_test.go b/authorize/authorize_test.go index e545a8acf..3e020429e 100644 --- a/authorize/authorize_test.go +++ b/authorize/authorize_test.go @@ -82,7 +82,7 @@ func TestNew(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := New(&config.Config{Options: &tt.config}) + _, err := New(context.Background(), &config.Config{Options: &tt.config}) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return @@ -114,7 +114,7 @@ func TestAuthorize_OnConfigChange(t *testing.T) { SharedKey: tc.SharedKey, Policies: tc.Policies, } - a, err := New(&config.Config{Options: o}) + a, err := New(context.Background(), &config.Config{Options: o}) require.NoError(t, err) require.NotNil(t, a) @@ -185,7 +185,7 @@ func TestNewPolicyEvaluator_addDefaultClientCertificateRule(t *testing.T) { c.opts.Policies = []config.Policy{{ To: mustParseWeightedURLs(t, "http://example.com"), }} - e, err := newPolicyEvaluator(c.opts, store, nil) + e, err := newPolicyEvaluator(context.Background(), c.opts, store, nil) require.NoError(t, err) r, err := e.Evaluate(context.Background(), &evaluator.Request{ diff --git a/authorize/check_response_test.go b/authorize/check_response_test.go index 9b171eafe..513d366b5 100644 --- a/authorize/check_response_test.go +++ b/authorize/check_response_test.go @@ -34,7 +34,7 @@ func TestAuthorize_handleResult(t *testing.T) { t.Cleanup(authnSrv.Close) opt.AuthenticateURLString = authnSrv.URL - a, err := New(&config.Config{Options: opt}) + a, err := New(context.Background(), &config.Config{Options: opt}) require.NoError(t, err) t.Run("user-unauthenticated", func(t *testing.T) { @@ -129,7 +129,7 @@ func TestAuthorize_okResponse(t *testing.T) { a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} a.currentOptions.Store(opt) a.store = store.New() - pe, err := newPolicyEvaluator(opt, a.store, nil) + pe, err := newPolicyEvaluator(context.Background(), opt, a.store, nil) require.NoError(t, err) a.state.Load().evaluator = pe @@ -327,7 +327,7 @@ func TestRequireLogin(t *testing.T) { t.Cleanup(authnSrv.Close) opt.AuthenticateURLString = authnSrv.URL - a, err := New(&config.Config{Options: opt}) + a, err := New(context.Background(), &config.Config{Options: opt}) require.NoError(t, err) t.Run("accept empty", func(t *testing.T) { diff --git a/authorize/databroker_test.go b/authorize/databroker_test.go index a9e441c21..1c773fc21 100644 --- a/authorize/databroker_test.go +++ b/authorize/databroker_test.go @@ -65,7 +65,7 @@ func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) { t.Cleanup(clearTimeout) opt := config.NewDefaultOptions() - a, err := New(&config.Config{Options: opt}) + a, err := New(context.Background(), &config.Config{Options: opt}) require.NoError(t, err) s1 := &session.Session{Id: "s1", ExpiresAt: timestamppb.New(time.Now().Add(-time.Second))} diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index c0d1da608..50ede2071 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -108,7 +108,7 @@ func New( ) (*Evaluator, error) { cfg := getConfig(options...) - err := updateStore(store, cfg) + err := updateStore(ctx, store, cfg) if err != nil { return nil, err } @@ -325,8 +325,8 @@ func (e *Evaluator) getClientCA(policy *config.Policy) (string, error) { return string(e.clientCA), nil } -func updateStore(store *store.Store, cfg *evaluatorConfig) error { - jwk, err := getJWK(cfg) +func updateStore(ctx context.Context, store *store.Store, cfg *evaluatorConfig) error { + jwk, err := getJWK(ctx, cfg) if err != nil { return fmt.Errorf("authorize: couldn't create signer: %w", err) } @@ -341,7 +341,7 @@ func updateStore(store *store.Store, cfg *evaluatorConfig) error { return nil } -func getJWK(cfg *evaluatorConfig) (*jose.JSONWebKey, error) { +func getJWK(ctx context.Context, cfg *evaluatorConfig) (*jose.JSONWebKey, error) { var decodedCert []byte // if we don't have a signing key, generate one if len(cfg.SigningKey) == 0 { @@ -361,7 +361,7 @@ func getJWK(cfg *evaluatorConfig) (*jose.JSONWebKey, error) { if err != nil { return nil, fmt.Errorf("couldn't generate signing key: %w", err) } - log.Info().Str("Algorithm", jwk.Algorithm). + log.Ctx(ctx).Info().Str("Algorithm", jwk.Algorithm). Str("KeyID", jwk.KeyID). Interface("Public Key", jwk.Public()). Msg("authorize: signing key") diff --git a/authorize/evaluator/policy_evaluator.go b/authorize/evaluator/policy_evaluator.go index 894441139..a0b7fef96 100644 --- a/authorize/evaluator/policy_evaluator.go +++ b/authorize/evaluator/policy_evaluator.go @@ -147,7 +147,8 @@ func NewPolicyEvaluator( // for each script, create a rego and prepare a query. for i := range e.queries { - log.Ctx(ctx).Debug(). + log.Ctx(ctx). + Trace(). Str("script", e.queries[i].script). Str("from", configPolicy.From). Interface("to", configPolicy.To). diff --git a/authorize/state.go b/authorize/state.go index 115b0bea2..a94e0643f 100644 --- a/authorize/state.go +++ b/authorize/state.go @@ -33,6 +33,7 @@ type authorizeState struct { } func newAuthorizeStateFromConfig( + ctx context.Context, cfg *config.Config, store *store.Store, previousPolicyEvaluator *evaluator.Evaluator, ) (*authorizeState, error) { if err := validateOptions(cfg.Options); err != nil { @@ -43,7 +44,7 @@ func newAuthorizeStateFromConfig( var err error - state.evaluator, err = newPolicyEvaluator(cfg.Options, store, previousPolicyEvaluator) + state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousPolicyEvaluator) if err != nil { return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err) } @@ -58,7 +59,7 @@ func newAuthorizeStateFromConfig( return nil, err } - cc, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{ + cc, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{ OutboundPort: cfg.OutboundPort, InstallationID: cfg.Options.InstallationID, ServiceName: cfg.Options.Services, @@ -84,9 +85,9 @@ func newAuthorizeStateFromConfig( } if cfg.Options.UseStatelessAuthenticateFlow() { - state.authenticateFlow, err = authenticateflow.NewStateless(cfg, nil, nil, nil, nil) + state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, nil, nil, nil, nil) } else { - state.authenticateFlow, err = authenticateflow.NewStateful(cfg, nil) + state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, nil) } if err != nil { return nil, err diff --git a/cmd/pomerium/main.go b/cmd/pomerium/main.go index 93c942245..9c035e8fa 100644 --- a/cmd/pomerium/main.go +++ b/cmd/pomerium/main.go @@ -54,7 +54,7 @@ func run(ctx context.Context, configFile string) error { var src config.Source - src, err := config.NewFileOrEnvironmentSource(configFile, files.FullVersion()) + src, err := config.NewFileOrEnvironmentSource(ctx, configFile, files.FullVersion()) if err != nil { return err } diff --git a/config/config_source.go b/config/config_source.go index 4af0b24fd..2de950c1f 100644 --- a/config/config_source.go +++ b/config/config_source.go @@ -103,9 +103,10 @@ type FileOrEnvironmentSource struct { // NewFileOrEnvironmentSource creates a new FileOrEnvironmentSource. func NewFileOrEnvironmentSource( + ctx context.Context, configFile, envoyVersion string, ) (*FileOrEnvironmentSource, error) { - ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context { + ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context { return c.Str("config_file_source", configFile) }) diff --git a/config/config_source_test.go b/config/config_source_test.go index 5c5b18268..4ff6ce7f4 100644 --- a/config/config_source_test.go +++ b/config/config_source_test.go @@ -137,7 +137,7 @@ runtime_flags: require.NoError(t, err) var src Source - src, err = NewFileOrEnvironmentSource(configFilePath, "") + src, err = NewFileOrEnvironmentSource(context.Background(), configFilePath, "") require.NoError(t, err) src = NewFileWatcherSource(context.Background(), src) diff --git a/config/envoyconfig/clusters.go b/config/envoyconfig/clusters.go index e41138292..2d50ddbb0 100644 --- a/config/envoyconfig/clusters.go +++ b/config/envoyconfig/clusters.go @@ -247,7 +247,7 @@ func (b *Builder) buildInternalTransportSocket( b.buildSubjectAltNameMatcher(endpoint, cfg.Options.OverrideCertificateName), }, } - bs, err := getCombinedCertificateAuthority(cfg) + bs, err := getCombinedCertificateAuthority(ctx, cfg) if err != nil { log.Ctx(ctx).Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found") } else { @@ -347,7 +347,7 @@ func (b *Builder) buildPolicyValidationContext( } validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs) } else { - bs, err := getCombinedCertificateAuthority(cfg) + bs, err := getCombinedCertificateAuthority(ctx, cfg) if err != nil { log.Ctx(ctx).Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found") } else { diff --git a/config/envoyconfig/clusters_test.go b/config/envoyconfig/clusters_test.go index dc7c1d5d7..1dcbbf7df 100644 --- a/config/envoyconfig/clusters_test.go +++ b/config/envoyconfig/clusters_test.go @@ -43,14 +43,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) { customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-57394a4e5157303436544830.pem") b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) - rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}}) + rootCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{}}) rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename() o1 := config.NewDefaultOptions() o2 := config.NewDefaultOptions() o2.CA = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}) - combinedCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{CA: o2.CA}}) + combinedCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{CA: o2.CA}}) combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename() t.Run("insecure", func(t *testing.T) { @@ -522,7 +522,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { func Test_buildCluster(t *testing.T) { ctx := context.Background() b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) - rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}}) + rootCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{}}) rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename() o1 := config.NewDefaultOptions() t.Run("insecure", func(t *testing.T) { diff --git a/config/envoyconfig/envoyconfig.go b/config/envoyconfig/envoyconfig.go index 24c749119..77fcf9589 100644 --- a/config/envoyconfig/envoyconfig.go +++ b/config/envoyconfig/envoyconfig.go @@ -189,7 +189,7 @@ var rootCABundle struct { value string } -func getRootCertificateAuthority() (string, error) { +func getRootCertificateAuthority(ctx context.Context) (string, error) { rootCABundle.Do(func() { // from https://www.envoyproxy.io/docs/envoy/latest/intro/arch_overview/security/ssl#arch-overview-ssl-enabling-verification knownRootLocations := []string{ @@ -207,10 +207,10 @@ func getRootCertificateAuthority() (string, error) { } } if rootCABundle.value == "" { - log.Error().Strs("known-locations", knownRootLocations). + log.Ctx(ctx).Error().Strs("known-locations", knownRootLocations). Msgf("no root certificates were found in any of the known locations") } else { - log.Info().Msgf("using %s as the system root certificate authority bundle", rootCABundle.value) + log.Ctx(ctx).Info().Msgf("using %s as the system root certificate authority bundle", rootCABundle.value) } }) if rootCABundle.value == "" { @@ -219,8 +219,8 @@ func getRootCertificateAuthority() (string, error) { return rootCABundle.value, nil } -func getCombinedCertificateAuthority(cfg *config.Config) ([]byte, error) { - rootFile, err := getRootCertificateAuthority() +func getCombinedCertificateAuthority(ctx context.Context, cfg *config.Config) ([]byte, error) { + rootFile, err := getRootCertificateAuthority(ctx) if err != nil { return nil, err } diff --git a/databroker/cache.go b/databroker/cache.go index 55785734b..a26a8ebcd 100644 --- a/databroker/cache.go +++ b/databroker/cache.go @@ -45,7 +45,7 @@ type DataBroker struct { } // New creates a new databroker service. -func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) { +func New(ctx context.Context, cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) { localListener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return nil, err @@ -61,8 +61,8 @@ func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) { // No metrics handler because we have one in the control plane. Add one // if we no longer register with that grpc Server localGRPCServer := grpc.NewServer( - grpc.StreamInterceptor(si), - grpc.UnaryInterceptor(ui), + grpc.ChainStreamInterceptor(log.StreamServerInterceptor(log.Ctx(ctx)), si), + grpc.ChainUnaryInterceptor(log.UnaryServerInterceptor(log.Ctx(ctx)), ui), ) sharedKey, err := cfg.Options.GetSharedKey() @@ -79,7 +79,7 @@ func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) { grpc.WithStatsHandler(clientStatsHandler.Handler), } - ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context { + ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context { return c.Str("service", "databroker").Str("config_source", "bootstrap") }) localGRPCConnection, err := grpc.DialContext( @@ -91,7 +91,7 @@ func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) { return nil, err } - dataBrokerServer, err := newDataBrokerServer(cfg) + dataBrokerServer, err := newDataBrokerServer(ctx, cfg) if err != nil { return nil, err } diff --git a/databroker/cache_test.go b/databroker/cache_test.go index e7ba0355c..64d212931 100644 --- a/databroker/cache_test.go +++ b/databroker/cache_test.go @@ -1,6 +1,7 @@ package databroker import ( + "context" "testing" "github.com/pomerium/pomerium/config" @@ -20,7 +21,7 @@ func TestNew(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.opts.Provider = "google" - _, err := New(&config.Config{Options: &tt.opts}, events.New()) + _, err := New(context.Background(), &config.Config{Options: &tt.opts}, events.New()) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/databroker/databroker.go b/databroker/databroker.go index 75029b747..7817f403d 100644 --- a/databroker/databroker.go +++ b/databroker/databroker.go @@ -23,7 +23,7 @@ type dataBrokerServer struct { } // newDataBrokerServer creates a new databroker service server. -func newDataBrokerServer(cfg *config.Config) (*dataBrokerServer, error) { +func newDataBrokerServer(ctx context.Context, cfg *config.Config) (*dataBrokerServer, error) { srv := &dataBrokerServer{ sharedKey: atomicutil.NewValue([]byte{}), } @@ -33,7 +33,7 @@ func newDataBrokerServer(cfg *config.Config) (*dataBrokerServer, error) { return nil, err } - srv.server = databroker.New(opts...) + srv.server = databroker.New(ctx, opts...) srv.setKey(cfg) return srv, nil } @@ -46,7 +46,7 @@ func (srv *dataBrokerServer) OnConfigChange(ctx context.Context, cfg *config.Con return } - srv.server.UpdateConfig(opts...) + srv.server.UpdateConfig(ctx, opts...) srv.setKey(cfg) } diff --git a/databroker/databroker_test.go b/databroker/databroker_test.go index 04b774cb4..1d1fa46bf 100644 --- a/databroker/databroker_test.go +++ b/databroker/databroker_test.go @@ -29,7 +29,7 @@ var lis *bufconn.Listener func init() { lis = bufconn.Listen(bufSize) s := grpc.NewServer() - internalSrv := internal_databroker.New() + internalSrv := internal_databroker.New(context.Background()) srv := &dataBrokerServer{server: internalSrv, sharedKey: atomicutil.NewValue([]byte{})} databroker.RegisterDataBrokerServiceServer(s, srv) diff --git a/integration/main_test.go b/integration/main_test.go index d7b9f5c88..cc3dec380 100644 --- a/integration/main_test.go +++ b/integration/main_test.go @@ -163,7 +163,7 @@ func waitForHealthy(ctx context.Context) error { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-ticker.C: } } diff --git a/internal/authenticateflow/stateful.go b/internal/authenticateflow/stateful.go index 20a1b5ef0..44050325d 100644 --- a/internal/authenticateflow/stateful.go +++ b/internal/authenticateflow/stateful.go @@ -56,7 +56,7 @@ type Stateful struct { // NewStateful initializes the authentication flow for the given configuration // and session store. -func NewStateful(cfg *config.Config, sessionStore sessions.SessionStore) (*Stateful, error) { +func NewStateful(ctx context.Context, cfg *config.Config, sessionStore sessions.SessionStore) (*Stateful, error) { s := &Stateful{ sessionDuration: cfg.Options.CookieExpire, sessionStore: sessionStore, @@ -88,7 +88,7 @@ func NewStateful(cfg *config.Config, sessionStore sessions.SessionStore) (*State s.defaultIdentityProviderID = idp.GetId() } - dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), + dataBrokerConn, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{ OutboundPort: cfg.OutboundPort, InstallationID: cfg.Options.InstallationID, diff --git a/internal/authenticateflow/stateful_test.go b/internal/authenticateflow/stateful_test.go index 8e2527b90..173374410 100644 --- a/internal/authenticateflow/stateful_test.go +++ b/internal/authenticateflow/stateful_test.go @@ -69,7 +69,7 @@ func TestStatefulSignIn(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { sessionStore := &mstore.Store{SaveError: tt.saveError} - flow, err := NewStateful(&config.Config{Options: opts}, sessionStore) + flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, sessionStore) if err != nil { t.Fatal(err) } @@ -123,7 +123,7 @@ func TestStatefulAuthenticateSignInURL(t *testing.T) { opts.AuthenticateURLString = "https://authenticate.example.com" key := cryptutil.NewKey() opts.SharedKey = base64.StdEncoding.EncodeToString(key) - flow, err := NewStateful(&config.Config{Options: opts}, nil) + flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, nil) require.NoError(t, err) t.Run("NilQueryParams", func(t *testing.T) { @@ -238,7 +238,7 @@ func TestStatefulCallback(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - flow, err := NewStateful(&config.Config{Options: opts}, tt.sessionStore) + flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, tt.sessionStore) if err != nil { t.Fatal(err) } @@ -289,7 +289,7 @@ func TestStatefulCallback(t *testing.T) { func TestStatefulRevokeSession(t *testing.T) { opts := config.NewDefaultOptions() - flow, err := NewStateful(&config.Config{Options: opts}, nil) + flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, nil) require.NoError(t, err) ctrl := gomock.NewController(t) @@ -367,7 +367,7 @@ func TestPersistSession(t *testing.T) { opts := config.NewDefaultOptions() opts.CookieExpire = 4 * time.Hour - flow, err := NewStateful(&config.Config{Options: opts}, nil) + flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, nil) require.NoError(t, err) ctrl := gomock.NewController(t) diff --git a/internal/authenticateflow/stateless.go b/internal/authenticateflow/stateless.go index e1997d98d..e89e8199a 100644 --- a/internal/authenticateflow/stateless.go +++ b/internal/authenticateflow/stateless.go @@ -64,6 +64,7 @@ type Stateless struct { // NewStateless initializes the authentication flow for the given // configuration, session store, and additional options. func NewStateless( + ctx context.Context, cfg *config.Config, sessionStore sessions.SessionStore, getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error), @@ -131,7 +132,7 @@ func NewStateless( return nil, fmt.Errorf("authorize: get authenticate JWKS key fetcher: %w", err) } - dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{ + dataBrokerConn, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{ OutboundPort: cfg.OutboundPort, InstallationID: cfg.Options.InstallationID, ServiceName: cfg.Options.Services, diff --git a/internal/autocert/manager.go b/internal/autocert/manager.go index 3ca276980..429f2173e 100644 --- a/internal/autocert/manager.go +++ b/internal/autocert/manager.go @@ -63,11 +63,12 @@ type Manager struct { } // New creates a new autocert manager. -func New(src config.Source) (*Manager, error) { - return newManager(context.Background(), src, certmagic.DefaultACME, renewalInterval) +func New(ctx context.Context, src config.Source) (*Manager, error) { + return newManager(ctx, src, certmagic.DefaultACME, renewalInterval) } -func newManager(ctx context.Context, +func newManager( + ctx context.Context, src config.Source, acmeTemplate certmagic.ACMEIssuer, checkInterval time.Duration, @@ -96,12 +97,13 @@ func newManager(ctx context.Context, if err != nil { return nil, err } - mgr.certmagic = certmagic.New(certmagic.NewCache(certmagic.CacheOptions{ + cache := certmagic.NewCache(certmagic.CacheOptions{ GetConfigForCert: func(_ certmagic.Certificate) (*certmagic.Config, error) { return mgr.certmagic, nil }, Logger: logger, - }), certmagic.Config{ + }) + mgr.certmagic = certmagic.New(cache, certmagic.Config{ Logger: logger, Storage: certmagicStorage, }) diff --git a/internal/autocert/manager_test.go b/internal/autocert/manager_test.go index ae3fbc1a0..5538a246c 100644 --- a/internal/autocert/manager_test.go +++ b/internal/autocert/manager_test.go @@ -316,7 +316,7 @@ func TestRedirect(t *testing.T) { }, }, }) - _, err = New(src) + _, err = New(context.Background(), src) if !assert.NoError(t, err) { return } diff --git a/internal/autocert/storage_locker.go b/internal/autocert/storage_locker.go index 950d98aef..f215a8e10 100644 --- a/internal/autocert/storage_locker.go +++ b/internal/autocert/storage_locker.go @@ -47,7 +47,7 @@ func (l *locker) Lock(ctx context.Context, name string) error { // wait select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-time.After(lockPollInterval): } continue diff --git a/internal/controlplane/events.go b/internal/controlplane/events.go index 578ac07b1..71c866376 100644 --- a/internal/controlplane/events.go +++ b/internal/controlplane/events.go @@ -71,7 +71,7 @@ func (srv *Server) getDataBrokerClient(ctx context.Context) (databrokerpb.DataBr return nil, err } - cc, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{ + cc, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{ OutboundPort: cfg.OutboundPort, InstallationID: cfg.Options.InstallationID, ServiceName: cfg.Options.Services, diff --git a/internal/controlplane/http.go b/internal/controlplane/http.go index 840f42ab1..d648fefbf 100644 --- a/internal/controlplane/http.go +++ b/internal/controlplane/http.go @@ -8,6 +8,7 @@ import ( "github.com/CAFxX/httpcompression" "github.com/gorilla/mux" + "github.com/rs/zerolog" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/handlers" @@ -19,7 +20,7 @@ import ( "github.com/pomerium/pomerium/pkg/telemetry/requestid" ) -func (srv *Server) addHTTPMiddleware(root *mux.Router, _ *config.Config) { +func (srv *Server) addHTTPMiddleware(root *mux.Router, logger *zerolog.Logger, _ *config.Config) { compressor, err := httpcompression.DefaultAdapter() if err != nil { panic(err) @@ -28,7 +29,7 @@ func (srv *Server) addHTTPMiddleware(root *mux.Router, _ *config.Config) { root.Use(compressor) root.Use(srv.reproxy.Middleware) root.Use(requestid.HTTPMiddleware()) - root.Use(log.NewHandler(log.Logger)) + root.Use(log.NewHandler(func() *zerolog.Logger { return logger })) root.Use(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) { log.FromRequest(r).Debug(). Dur("duration", duration). diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index b3bac98d6..cfd74effc 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -69,10 +69,16 @@ type Server struct { } // NewServer creates a new Server. Listener ports are chosen by the OS. -func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr *events.Manager) (*Server, error) { +func NewServer( + ctx context.Context, + cfg *config.Config, + metricsMgr *config.MetricsManager, + eventsMgr *events.Manager, +) (*Server, error) { srv := &Server{ metricsMgr: metricsMgr, EventsMgr: eventsMgr, + filemgr: filemgr.NewManager(), reproxy: reproxy.New(), haveSetCapacity: map[string]bool{}, updateConfig: make(chan *config.Config, 1), @@ -80,6 +86,10 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr httpRouter: atomicutil.NewValue(mux.NewRouter()), } + ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context { + return c.Str("server_name", cfg.Options.Services) + }) + var err error // setup gRPC @@ -95,8 +105,16 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr ) srv.GRPCServer = grpc.NewServer( grpc.StatsHandler(telemetry.NewGRPCServerStatsHandler(cfg.Options.Services)), - grpc.ChainUnaryInterceptor(requestid.UnaryServerInterceptor(), ui), - grpc.ChainStreamInterceptor(requestid.StreamServerInterceptor(), si), + grpc.ChainUnaryInterceptor( + log.UnaryServerInterceptor(log.Ctx(ctx)), + requestid.UnaryServerInterceptor(), + ui, + ), + grpc.ChainStreamInterceptor( + log.StreamServerInterceptor(log.Ctx(ctx)), + requestid.StreamServerInterceptor(), + si, + ), ) reflection.Register(srv.GRPCServer) srv.registerAccessLogHandlers() @@ -125,7 +143,7 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr return nil, err } - if err := srv.updateRouter(cfg); err != nil { + if err := srv.updateRouter(ctx, cfg); err != nil { return nil, err } srv.DebugRouter = mux.NewRouter() @@ -141,7 +159,6 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr // metrics srv.MetricsRouter.Handle("/metrics", srv.metricsMgr) - srv.filemgr = filemgr.NewManager() srv.filemgr.ClearCache() srv.Builder = envoyconfig.New( @@ -152,10 +169,6 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr srv.reproxy, ) - ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context { - return c.Str("server_name", cfg.Options.Services) - }) - res, err := srv.buildDiscoveryResources(ctx) if err != nil { return nil, err @@ -211,7 +224,7 @@ func (srv *Server) Run(ctx context.Context) error { for { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case cfg := <-srv.updateConfig: err := srv.update(ctx, cfg) if err != nil { @@ -232,29 +245,29 @@ func (srv *Server) OnConfigChange(ctx context.Context, cfg *config.Config) error select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case srv.updateConfig <- cfg: } return nil } // EnableAuthenticate enables the authenticate service. -func (srv *Server) EnableAuthenticate(svc Service) error { +func (srv *Server) EnableAuthenticate(ctx context.Context, svc Service) error { srv.authenticateSvc = svc - return srv.updateRouter(srv.currentConfig.Load()) + return srv.updateRouter(ctx, srv.currentConfig.Load()) } // EnableProxy enables the proxy service. -func (srv *Server) EnableProxy(svc Service) error { +func (srv *Server) EnableProxy(ctx context.Context, svc Service) error { srv.proxySvc = svc - return srv.updateRouter(srv.currentConfig.Load()) + return srv.updateRouter(ctx, srv.currentConfig.Load()) } func (srv *Server) update(ctx context.Context, cfg *config.Config) error { ctx, span := trace.StartSpan(ctx, "controlplane.Server.update") defer span.End() - if err := srv.updateRouter(cfg); err != nil { + if err := srv.updateRouter(ctx, cfg); err != nil { return err } srv.reproxy.Update(ctx, cfg) @@ -267,9 +280,9 @@ func (srv *Server) update(ctx context.Context, cfg *config.Config) error { return nil } -func (srv *Server) updateRouter(cfg *config.Config) error { +func (srv *Server) updateRouter(ctx context.Context, cfg *config.Config) error { httpRouter := mux.NewRouter() - srv.addHTTPMiddleware(httpRouter, cfg) + srv.addHTTPMiddleware(httpRouter, log.Ctx(ctx), cfg) if err := srv.mountCommonEndpoints(httpRouter, cfg); err != nil { return err } diff --git a/internal/controlplane/server_test.go b/internal/controlplane/server_test.go index 792e661b3..f7ae90155 100644 --- a/internal/controlplane/server_test.go +++ b/internal/controlplane/server_test.go @@ -38,7 +38,7 @@ func TestServerHTTP(t *testing.T) { cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs=" src := config.NewStaticSource(cfg) - srv, err := NewServer(cfg, config.NewMetricsManager(ctx, src), events.New()) + srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New()) require.NoError(t, err) go srv.Run(ctx) diff --git a/internal/controlplane/xdsmgr/log.go b/internal/controlplane/xdsmgr/log.go index d975a9d9f..7c7e6446e 100644 --- a/internal/controlplane/xdsmgr/log.go +++ b/internal/controlplane/xdsmgr/log.go @@ -1,6 +1,7 @@ package xdsmgr import ( + "context" "errors" envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" @@ -19,8 +20,8 @@ var ( routeConfigurationTypeURL = protoutil.GetTypeURL((*envoy_config_route_v3.RouteConfiguration)(nil)) ) -func logNACK(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) { - log.Debug(). +func logNACK(ctx context.Context, req *envoy_service_discovery_v3.DeltaDiscoveryRequest) { + log.Ctx(ctx).Debug(). Str("type-url", req.GetTypeUrl()). Any("error-detail", req.GetErrorDetail()). Msg("xdsmgr: nack") @@ -28,8 +29,8 @@ func logNACK(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) { health.ReportError(getHealthCheck(req.GetTypeUrl()), errors.New(req.GetErrorDetail().GetMessage())) } -func logACK(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) { - log.Debug(). +func logACK(ctx context.Context, req *envoy_service_discovery_v3.DeltaDiscoveryRequest) { + log.Ctx(ctx).Debug(). Str("type-url", req.GetTypeUrl()). Msg("xdsmgr: ack") diff --git a/internal/controlplane/xdsmgr/xdsmgr.go b/internal/controlplane/xdsmgr/xdsmgr.go index d46664790..b402ac2d9 100644 --- a/internal/controlplane/xdsmgr/xdsmgr.go +++ b/internal/controlplane/xdsmgr/xdsmgr.go @@ -113,7 +113,7 @@ func (mgr *Manager) DeltaAggregatedResources( for _, resource := range mgr.resources[req.GetTypeUrl()] { state.clientResourceVersions[resource.Name] = resource.Version } - logNACK(req) + logNACK(ctx, req) case req.GetResponseNonce() == mgr.nonce: // an ACK for the last response // - set the client resource versions to the current resource versions @@ -121,10 +121,11 @@ func (mgr *Manager) DeltaAggregatedResources( for _, resource := range mgr.resources[req.GetTypeUrl()] { state.clientResourceVersions[resource.Name] = resource.Version } - logACK(req) + logACK(ctx, req) default: // an ACK for a response that's not the last response - log.Ctx(ctx).Debug(). + log.Ctx(ctx). + Debug(). Str("type-url", req.GetTypeUrl()). Msg("xdsmgr: ack") } @@ -161,7 +162,7 @@ func (mgr *Manager) DeltaAggregatedResources( select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case incoming <- req: } } @@ -173,7 +174,7 @@ func (mgr *Manager) DeltaAggregatedResources( var typeURLs []string select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case req := <-incoming: handleDeltaRequest(changeCtx, req) typeURLs = []string{req.GetTypeUrl()} @@ -193,7 +194,7 @@ func (mgr *Manager) DeltaAggregatedResources( select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case outgoing <- res: } } @@ -204,9 +205,10 @@ func (mgr *Manager) DeltaAggregatedResources( for { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case res := <-outgoing: - log.Ctx(ctx).Debug(). + log.Ctx(ctx). + Debug(). Str("type-url", res.GetTypeUrl()). Int("resource-count", len(res.GetResources())). Int("removed-resource-count", len(res.GetRemovedResources())). diff --git a/internal/databroker/config_source.go b/internal/databroker/config_source.go index d34507fdd..7a379777b 100644 --- a/internal/databroker/config_source.go +++ b/internal/databroker/config_source.go @@ -96,7 +96,7 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) { cfg := src.underlyingConfig.Clone() // start the updater - src.runUpdater(cfg) + src.runUpdater(ctx, cfg) now = time.Now() err := src.buildNewConfigLocked(ctx, cfg) @@ -234,7 +234,7 @@ func (src *ConfigSource) addPolicies(ctx context.Context, cfg *config.Config, po cfg.Options.AdditionalPolicies = append(cfg.Options.AdditionalPolicies, additionalPolicies...) } -func (src *ConfigSource) runUpdater(cfg *config.Config) { +func (src *ConfigSource) runUpdater(ctx context.Context, cfg *config.Config) { sharedKey, _ := cfg.Options.GetSharedKey() connectionOptions := &grpc.OutboundOptions{ OutboundPort: cfg.OutboundPort, @@ -257,7 +257,6 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) { src.cancel = nil } - ctx := context.Background() ctx, src.cancel = context.WithCancel(ctx) cc, err := src.outboundGRPCConnection.Get(ctx, connectionOptions) @@ -268,7 +267,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) { client := databroker.NewDataBrokerServiceClient(cc) - syncer := databroker.NewSyncer("databroker", &syncerHandler{ + syncer := databroker.NewSyncer(ctx, "databroker", &syncerHandler{ client: client, src: src, }, databroker.WithTypeURL(grpcutil.GetTypeURL(new(configpb.Config))), diff --git a/internal/databroker/config_source_test.go b/internal/databroker/config_source_test.go index 063fd94be..4116b5b1e 100644 --- a/internal/databroker/config_source_test.go +++ b/internal/databroker/config_source_test.go @@ -41,7 +41,7 @@ func TestConfigSource(t *testing.T) { defer func() { _ = li.Close() }() _, outboundPort, _ := net.SplitHostPort(li.Addr().String()) - dataBrokerServer := New() + dataBrokerServer := New(ctx) srv := grpc.NewServer() databroker.RegisterDataBrokerServiceServer(srv, dataBrokerServer) go func() { _ = srv.Serve(li) }() diff --git a/internal/databroker/registry.go b/internal/databroker/registry.go index 387525c5c..fee359e4a 100644 --- a/internal/databroker/registry.go +++ b/internal/databroker/registry.go @@ -28,7 +28,7 @@ func (srv *Server) Report(ctx context.Context, req *registrypb.RegisterRequest) ctx, span := trace.StartSpan(ctx, "databroker.grpc.Report") defer span.End() - r, err := srv.getRegistry() + r, err := srv.getRegistry(ctx) if err != nil { return nil, err } @@ -41,7 +41,7 @@ func (srv *Server) List(ctx context.Context, req *registrypb.ListRequest) (*regi ctx, span := trace.StartSpan(ctx, "databroker.grpc.List") defer span.End() - r, err := srv.getRegistry() + r, err := srv.getRegistry(ctx) if err != nil { return nil, err } @@ -55,7 +55,7 @@ func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry ctx, span := trace.StartSpan(ctx, "databroker.grpc.Watch") defer span.End() - r, err := srv.getRegistry() + r, err := srv.getRegistry(ctx) if err != nil { return err } @@ -66,8 +66,8 @@ func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry }) } -func (srv *Server) getRegistry() (registry.Interface, error) { - backend, err := srv.getBackend() +func (srv *Server) getRegistry(ctx context.Context) (registry.Interface, error) { + backend, err := srv.getBackend(ctx) if err != nil { return nil, err } @@ -81,7 +81,7 @@ func (srv *Server) getRegistry() (registry.Interface, error) { r = srv.registry var err error if r == nil { - r, err = srv.newRegistryLocked(backend) + r, err = srv.newRegistryLocked(ctx, backend) srv.registry = r } srv.mu.Unlock() @@ -92,9 +92,7 @@ func (srv *Server) getRegistry() (registry.Interface, error) { return r, nil } -func (srv *Server) newRegistryLocked(backend storage.Backend) (registry.Interface, error) { - ctx := context.Background() - +func (srv *Server) newRegistryLocked(ctx context.Context, backend storage.Backend) (registry.Interface, error) { if hasRegistryServer, ok := backend.(interface { RegistryServer() registrypb.RegistryServer }); ok { diff --git a/internal/databroker/server.go b/internal/databroker/server.go index 17c7731a6..4a6f207b3 100644 --- a/internal/databroker/server.go +++ b/internal/databroker/server.go @@ -28,25 +28,26 @@ import ( type Server struct { cfg *serverConfig - mu sync.RWMutex - backend storage.Backend - registry registry.Interface + mu sync.RWMutex + backend storage.Backend + backendCtx context.Context + registry registry.Interface } // New creates a new server. -func New(options ...ServerOption) *Server { - srv := &Server{} - srv.UpdateConfig(options...) +func New(ctx context.Context, options ...ServerOption) *Server { + srv := &Server{ + backendCtx: ctx, + } + srv.UpdateConfig(ctx, options...) return srv } // UpdateConfig updates the server with the new options. -func (srv *Server) UpdateConfig(options ...ServerOption) { +func (srv *Server) UpdateConfig(ctx context.Context, options ...ServerOption) { srv.mu.Lock() defer srv.mu.Unlock() - ctx := context.TODO() - cfg := newServerConfig(options...) if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) { log.Ctx(ctx).Debug().Msg("databroker: no changes detected, re-using existing DBs") @@ -80,7 +81,7 @@ func (srv *Server) AcquireLease(ctx context.Context, req *databroker.AcquireLeas Dur("duration", req.GetDuration().AsDuration()). Msg("acquire lease") - db, err := srv.getBackend() + db, err := srv.getBackend(ctx) if err != nil { return nil, err } @@ -107,7 +108,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr Str("id", req.GetId()). Msg("get") - db, err := srv.getBackend() + db, err := srv.getBackend(ctx) if err != nil { return nil, err } @@ -131,7 +132,7 @@ func (srv *Server) ListTypes(ctx context.Context, _ *emptypb.Empty) (*databroker defer span.End() log.Ctx(ctx).Debug().Msg("list types") - db, err := srv.getBackend() + db, err := srv.getBackend(ctx) if err != nil { return nil, err } @@ -156,7 +157,7 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da query := strings.ToLower(req.GetQuery()) - db, err := srv.getBackend() + db, err := srv.getBackend(ctx) if err != nil { return nil, err } @@ -217,7 +218,7 @@ func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databr Msg("put") } - db, err := srv.getBackend() + db, err := srv.getBackend(ctx) if err != nil { return nil, err } @@ -256,7 +257,7 @@ func (srv *Server) Patch(ctx context.Context, req *databroker.PatchRequest) (*da Msg("patch") } - db, err := srv.getBackend() + db, err := srv.getBackend(ctx) if err != nil { return nil, err } @@ -282,7 +283,7 @@ func (srv *Server) ReleaseLease(ctx context.Context, req *databroker.ReleaseLeas Str("id", req.GetId()). Msg("release lease") - db, err := srv.getBackend() + db, err := srv.getBackend(ctx) if err != nil { return nil, err } @@ -305,7 +306,7 @@ func (srv *Server) RenewLease(ctx context.Context, req *databroker.RenewLeaseReq Dur("duration", req.GetDuration().AsDuration()). Msg("renew lease") - db, err := srv.getBackend() + db, err := srv.getBackend(ctx) if err != nil { return nil, err } @@ -325,7 +326,7 @@ func (srv *Server) SetOptions(ctx context.Context, req *databroker.SetOptionsReq ctx, span := trace.StartSpan(ctx, "databroker.grpc.SetOptions") defer span.End() - backend, err := srv.getBackend() + backend, err := srv.getBackend(ctx) if err != nil { return nil, err } @@ -351,12 +352,13 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke ctx, cancel := context.WithCancel(ctx) defer cancel() - log.Ctx(ctx).Debug(). + log.Ctx(ctx). + Debug(). Uint64("server_version", req.GetServerVersion()). Uint64("record_version", req.GetRecordVersion()). Msg("sync") - backend, err := srv.getBackend() + backend, err := srv.getBackend(ctx) if err != nil { return err } @@ -392,7 +394,7 @@ func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databrok Str("type", req.GetType()). Msg("sync latest") - backend, err := srv.getBackend() + backend, err := srv.getBackend(ctx) if err != nil { return err } @@ -430,7 +432,7 @@ func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databrok }) } -func (srv *Server) getBackend() (backend storage.Backend, err error) { +func (srv *Server) getBackend(ctx context.Context) (backend storage.Backend, err error) { // double-checked locking: // first try the read lock, then re-try with the write lock, and finally create a new backend if nil srv.mu.RLock() @@ -441,7 +443,7 @@ func (srv *Server) getBackend() (backend storage.Backend, err error) { backend = srv.backend var err error if backend == nil { - backend, err = srv.newBackendLocked() + backend, err = srv.newBackendLocked(ctx) srv.backend = backend } srv.mu.Unlock() @@ -452,18 +454,18 @@ func (srv *Server) getBackend() (backend storage.Backend, err error) { return backend, nil } -func (srv *Server) newBackendLocked() (backend storage.Backend, err error) { - ctx := context.Background() - +func (srv *Server) newBackendLocked(ctx context.Context) (storage.Backend, error) { switch srv.cfg.storageType { case config.StorageInMemoryName: - log.Ctx(ctx).Info().Msg("using in-memory store") + log.Ctx(ctx).Info().Msg("initializing new in-memory store") return inmemory.New(), nil case config.StoragePostgresName: - log.Ctx(ctx).Info().Msg("using postgres store") - backend = postgres.New(srv.cfg.storageConnectionString) + log.Ctx(ctx).Info().Msg("initializing new postgres store") + // NB: the context passed to postgres.New here is a separate context scoped + // to the lifetime of the server itself. 'ctx' may be a short-lived request + // context, since the backend is lazy-initialized. + return postgres.New(srv.backendCtx, srv.cfg.storageConnectionString), nil default: return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType) } - return backend, nil } diff --git a/internal/databroker/server_test.go b/internal/databroker/server_test.go index 0a348c770..ac7c775b0 100644 --- a/internal/databroker/server_test.go +++ b/internal/databroker/server_test.go @@ -48,7 +48,8 @@ func (h testSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint func newServer(cfg *serverConfig) *Server { return &Server{ - cfg: cfg, + cfg: cfg, + backendCtx: context.Background(), } } @@ -277,7 +278,7 @@ func TestServer_Sync(t *testing.T) { updateRecords := make(chan uint64, 10) client := databroker.NewDataBrokerServiceClient(cc) - syncer := databroker.NewSyncer("TEST", testSyncerHandler{ + syncer := databroker.NewSyncer(ctx, "TEST", testSyncerHandler{ getDataBrokerServiceClient: func() databroker.DataBrokerServiceClient { return client }, @@ -292,12 +293,12 @@ func TestServer_Sync(t *testing.T) { select { case <-clearRecords: case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) } select { case <-updateRecords: case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) } @@ -313,7 +314,7 @@ func TestServer_Sync(t *testing.T) { select { case <-updateRecords: case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) } return nil diff --git a/internal/enabler/enabler_test.go b/internal/enabler/enabler_test.go index 91004097f..a419bfdab 100644 --- a/internal/enabler/enabler_test.go +++ b/internal/enabler/enabler_test.go @@ -45,7 +45,7 @@ func TestEnabler(t *testing.T) { started.Add(1) <-ctx.Done() stopped.Add(1) - return ctx.Err() + return context.Cause(ctx) }), true) time.AfterFunc(time.Millisecond*10, e.Disable) go e.Run(ctx) diff --git a/internal/fileutil/watcher.go b/internal/fileutil/watcher.go index 338889fc3..ca972b90a 100644 --- a/internal/fileutil/watcher.go +++ b/internal/fileutil/watcher.go @@ -80,6 +80,9 @@ func (watcher *Watcher) initLocked(ctx context.Context) { if watcher.pollingWatcher == nil { watcher.pollingWatcher = filenotify.NewPollingWatcher(nil) + context.AfterFunc(ctx, func() { + watcher.pollingWatcher.Close() + }) } errors := watcher.pollingWatcher.Errors() diff --git a/internal/log/debug.go b/internal/log/debug.go new file mode 100644 index 000000000..1d4f38bb7 --- /dev/null +++ b/internal/log/debug.go @@ -0,0 +1,10 @@ +package log + +import "sync/atomic" + +var ( + // Debug option to disable the Zap log shim + DebugDisableZapLogger atomic.Bool + // Debug option to suppress global warnings + DebugDisableGlobalWarnings atomic.Bool +) diff --git a/internal/log/log.go b/internal/log/log.go index 9ea271dca..f039cd12f 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -52,6 +52,9 @@ func Logger() *zerolog.Logger { // ZapLogger returns the global zap logger. func ZapLogger() *zap.Logger { + if DebugDisableZapLogger.Load() { + return zap.NewNop() + } return zapLogger.Load() } diff --git a/internal/log/middleware.go b/internal/log/middleware.go index c59880a9a..62575da6f 100644 --- a/internal/log/middleware.go +++ b/internal/log/middleware.go @@ -1,11 +1,14 @@ package log import ( + "context" "net" "net/http" "time" + "github.com/pomerium/protoutil/streams" "github.com/rs/zerolog" + "google.golang.org/grpc" "github.com/pomerium/pomerium/internal/middleware/responsewriter" "github.com/pomerium/pomerium/pkg/telemetry/requestid" @@ -121,3 +124,17 @@ func HeadersHandler(headers []string) func(next http.Handler) http.Handler { }) } } + +func StreamServerInterceptor(lg *zerolog.Logger) grpc.StreamServerInterceptor { + return func(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + s := streams.NewServerStreamWithContext(ss) + s.SetContext(lg.WithContext(s.Ctx)) + return handler(srv, s) + } +} + +func UnaryServerInterceptor(lg *zerolog.Logger) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + return handler(lg.WithContext(ctx), req) + } +} diff --git a/internal/log/warnings.go b/internal/log/warnings.go index d5c418434..a718fabda 100644 --- a/internal/log/warnings.go +++ b/internal/log/warnings.go @@ -11,6 +11,9 @@ var warnCookieSecretOnce sync.Once // WarnCookieSecret warns about the cookie secret. func WarnCookieSecret() { warnCookieSecretOnce.Do(func() { + if DebugDisableGlobalWarnings.Load() { + return + } Info(). Msg("using a generated COOKIE_SECRET. " + "Set the COOKIE_SECRET to avoid users being logged out on restart. " + @@ -23,6 +26,9 @@ var warnNoTLSCertificateOnce syncutil.OnceMap[string] // WarnNoTLSCertificate warns about no TLS certificate. func WarnNoTLSCertificate(domain string) { warnNoTLSCertificateOnce.Do(domain, func() { + if DebugDisableGlobalWarnings.Load() { + return + } Info(). Str("domain", domain). Msg("no TLS certificate found for domain, using a self-signed certificate") @@ -34,6 +40,9 @@ var warnWebSocketHTTP1_1Once syncutil.OnceMap[string] // WarnWebSocketHTTP1_1 warns about falling back to http 1.1 due to web socket support. func WarnWebSocketHTTP1_1(clusterID string) { warnWebSocketHTTP1_1Once.Do(clusterID, func() { + if DebugDisableGlobalWarnings.Load() { + return + } Info(). Str("cluster-id", clusterID). Msg("forcing http/1.1 due to web socket support") diff --git a/internal/retry/retry.go b/internal/retry/retry.go index 5dd5de2bb..ed3d6f8ac 100644 --- a/internal/retry/retry.go +++ b/internal/retry/retry.go @@ -103,7 +103,7 @@ func makeSelect( fn: func(ctx context.Context) error { // unreachable, the context handler will never be called // as its channel can only be closed - return ctx.Err() + return context.Cause(ctx) }, ch: reflect.ValueOf(ctx.Done()), }, diff --git a/internal/tests/xdserr/health.go b/internal/tests/xdserr/health.go index 24f199b3d..29f27a2c2 100644 --- a/internal/tests/xdserr/health.go +++ b/internal/tests/xdserr/health.go @@ -22,7 +22,7 @@ func WaitForHealthy(ctx context.Context, client *http.Client, routes []*config.R healthy++ } } - return ctx.Err() + return context.Cause(ctx) } func checkHealth(ctx context.Context, client *http.Client, addr string) error { diff --git a/internal/zero/bootstrap/bootstrap.go b/internal/zero/bootstrap/bootstrap.go index cfa70d08b..c29a8468d 100644 --- a/internal/zero/bootstrap/bootstrap.go +++ b/internal/zero/bootstrap/bootstrap.go @@ -74,7 +74,7 @@ func (svc *Source) updateLoop(ctx context.Context) error { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-svc.checkForUpdate: case <-ticker.C: } diff --git a/internal/zero/bootstrap/source.go b/internal/zero/bootstrap/source.go index 27a79a22f..694f1954a 100644 --- a/internal/zero/bootstrap/source.go +++ b/internal/zero/bootstrap/source.go @@ -34,7 +34,7 @@ type source struct { func (src *source) WaitReady(ctx context.Context) error { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-src.ready: return nil } diff --git a/internal/zero/cmd/command_import.go b/internal/zero/cmd/command_import.go index d907175a0..b63722b5a 100644 --- a/internal/zero/cmd/command_import.go +++ b/internal/zero/cmd/command_import.go @@ -36,7 +36,7 @@ func BuildImportCmd() *cobra.Command { return fmt.Errorf("no config file provided") } log.SetLevel(zerolog.ErrorLevel) - src, err := config.NewFileOrEnvironmentSource(configFile, files.FullVersion()) + src, err := config.NewFileOrEnvironmentSource(cmd.Context(), configFile, files.FullVersion()) if err != nil { return err } diff --git a/internal/zero/connect-mux/messages.go b/internal/zero/connect-mux/messages.go index 3e648f75e..79f673414 100644 --- a/internal/zero/connect-mux/messages.go +++ b/internal/zero/connect-mux/messages.go @@ -17,7 +17,7 @@ import ( func (svc *Mux) Watch(ctx context.Context, opts ...WatchOption) error { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-svc.ready: } diff --git a/internal/zero/controller/databroker_restart_test.go b/internal/zero/controller/databroker_restart_test.go index 8c25ad530..686bd3163 100644 --- a/internal/zero/controller/databroker_restart_test.go +++ b/internal/zero/controller/databroker_restart_test.go @@ -102,7 +102,7 @@ func TestDatabrokerRestart(t *testing.T) { cl(context.Background(), newConfig()) <-ctx.Done() require.ErrorIs(t, context.Cause(ctx), controller.ErrBootstrapConfigurationChanged) - return ctx.Err() + return context.Cause(ctx) } return nil }) diff --git a/internal/zero/controller/leaser.go b/internal/zero/controller/leaser.go index de3303e7d..e1c8d37f6 100644 --- a/internal/zero/controller/leaser.go +++ b/internal/zero/controller/leaser.go @@ -56,5 +56,5 @@ func (w *LeaseStatus) MonitorLease(ctx context.Context, _ databroker.DataBrokerS w.v.Store(true) <-ctx.Done() w.v.Store(false) - return ctx.Err() + return context.Cause(ctx) } diff --git a/internal/zero/controller/usagereporter/usagereporter_test.go b/internal/zero/controller/usagereporter/usagereporter_test.go index 24a07ba14..633227abd 100644 --- a/internal/zero/controller/usagereporter/usagereporter_test.go +++ b/internal/zero/controller/usagereporter/usagereporter_test.go @@ -37,7 +37,7 @@ func TestUsageReporter(t *testing.T) { t.Cleanup(cancel) cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) { - databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New()) + databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx)) }) t.Cleanup(func() { cc.Close() }) diff --git a/internal/zero/healthcheck/syncer.go b/internal/zero/healthcheck/syncer.go index 19029f085..4f7ae0d98 100644 --- a/internal/zero/healthcheck/syncer.go +++ b/internal/zero/healthcheck/syncer.go @@ -10,7 +10,7 @@ import ( ) func (c *Checker) ConfigSyncer(ctx context.Context) error { - syncer := databroker.NewSyncer("zero-health-check", c, databroker.WithTypeURL(protoutil.GetTypeURL(new(configpb.Config)))) + syncer := databroker.NewSyncer(ctx, "zero-health-check", c, databroker.WithTypeURL(protoutil.GetTypeURL(new(configpb.Config)))) return syncer.Run(ctx) } diff --git a/internal/zero/reconciler/sync.go b/internal/zero/reconciler/sync.go index e929262ff..6dc5b6de2 100644 --- a/internal/zero/reconciler/sync.go +++ b/internal/zero/reconciler/sync.go @@ -35,7 +35,7 @@ func (c *service) SyncLoop(ctx context.Context) error { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-c.bundleSyncRequest: log.Ctx(ctx).Debug().Msg("bundle sync triggered") err := c.syncBundles(ctx) diff --git a/internal/zero/telemetry/telemetry.go b/internal/zero/telemetry/telemetry.go index d4686bb3d..67eb8ad45 100644 --- a/internal/zero/telemetry/telemetry.go +++ b/internal/zero/telemetry/telemetry.go @@ -105,7 +105,7 @@ func (srv *Telemetry) handleRequests(ctx context.Context) error { case req := <-requests: srv.handleRequest(ctx, req) case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) } } }) diff --git a/pkg/cmd/pomerium/pomerium.go b/pkg/cmd/pomerium/pomerium.go index ca18d2985..489c4ed55 100644 --- a/pkg/cmd/pomerium/pomerium.go +++ b/pkg/cmd/pomerium/pomerium.go @@ -32,7 +32,7 @@ import ( // Run runs the main pomerium application. func Run(ctx context.Context, src config.Source) error { - _, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Debug().Msgf(s, i...) })) + _, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Ctx(ctx).Debug().Msgf(s, i...) })) evt := log.Ctx(ctx).Info(). Str("envoy_version", files.FullVersion()). @@ -53,7 +53,7 @@ func Run(ctx context.Context, src config.Source) error { // trigger changes when underlying files are changed src = config.NewFileWatcherSource(ctx, src) - src, err = autocert.New(src) + src, err = autocert.New(ctx, src) if err != nil { return err } @@ -71,7 +71,7 @@ func Run(ctx context.Context, src config.Source) error { cfg := src.GetConfig() // setup the control plane - controlPlane, err := controlplane.NewServer(cfg, metricsMgr, eventsMgr) + controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr) if err != nil { return fmt.Errorf("error creating control plane: %w", err) } @@ -166,11 +166,11 @@ func setupAuthenticate(ctx context.Context, src config.Source, controlPlane *con return nil } - svc, err := authenticate.New(src.GetConfig()) + svc, err := authenticate.New(ctx, src.GetConfig()) if err != nil { return fmt.Errorf("error creating authenticate service: %w", err) } - err = controlPlane.EnableAuthenticate(svc) + err = controlPlane.EnableAuthenticate(ctx, svc) if err != nil { return fmt.Errorf("error adding authenticate service to control plane: %w", err) } @@ -183,7 +183,7 @@ func setupAuthenticate(ctx context.Context, src config.Source, controlPlane *con } func setupAuthorize(ctx context.Context, src config.Source, controlPlane *controlplane.Server) (*authorize.Authorize, error) { - svc, err := authorize.New(src.GetConfig()) + svc, err := authorize.New(ctx, src.GetConfig()) if err != nil { return nil, fmt.Errorf("error creating authorize service: %w", err) } @@ -200,7 +200,7 @@ func setupDataBroker(ctx context.Context, controlPlane *controlplane.Server, eventsMgr *events.Manager, ) (*databroker_service.DataBroker, error) { - svc, err := databroker_service.New(src.GetConfig(), eventsMgr) + svc, err := databroker_service.New(ctx, src.GetConfig(), eventsMgr) if err != nil { return nil, fmt.Errorf("error creating databroker service: %w", err) } @@ -223,11 +223,11 @@ func setupProxy(ctx context.Context, src config.Source, controlPlane *controlpla return nil } - svc, err := proxy.New(src.GetConfig()) + svc, err := proxy.New(ctx, src.GetConfig()) if err != nil { return fmt.Errorf("error creating proxy service: %w", err) } - err = controlPlane.EnableProxy(svc) + err = controlPlane.EnableProxy(ctx, svc) if err != nil { return fmt.Errorf("error adding proxy service to control plane: %w", err) } diff --git a/pkg/envoy/envoy.go b/pkg/envoy/envoy.go index 29369c277..0980b0c82 100644 --- a/pkg/envoy/envoy.go +++ b/pkg/envoy/envoy.go @@ -185,7 +185,7 @@ func (srv *Server) run(ctx context.Context, cfg *config.Config) error { // monitor the process so we exit if it prematurely exits var monitorProcessCtx context.Context - monitorProcessCtx, srv.monitorProcessCancel = context.WithCancel(context.Background()) + monitorProcessCtx, srv.monitorProcessCancel = context.WithCancel(context.WithoutCancel(ctx)) go srv.monitorProcess(monitorProcessCtx, int32(cmd.Process.Pid)) if srv.resourceMonitor != nil { @@ -251,7 +251,7 @@ func (srv *Server) parseLog(line string) (name string, logLevel string, msg stri func (srv *Server) handleLogs(ctx context.Context, rc io.ReadCloser) { defer rc.Close() - l := log.With().Str("service", "envoy").Logger() + l := log.Ctx(ctx).With().Str("service", "envoy").Logger() bo := backoff.NewExponentialBackOff() s := bufio.NewReader(rc) diff --git a/pkg/fanout/publish.go b/pkg/fanout/publish.go index 3bd549d22..33bb67478 100644 --- a/pkg/fanout/publish.go +++ b/pkg/fanout/publish.go @@ -10,7 +10,7 @@ func (f *FanOut[T]) Publish(ctx context.Context, msg T) error { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-f.done: return ErrStopped case f.messages <- msg: diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 621378d32..a8e7aad71 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -141,7 +141,7 @@ func WaitForReady(ctx context.Context, cc *grpc.ClientConn, timeout time.Duratio for { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-ticker.C: } diff --git a/pkg/grpc/databroker/leaser.go b/pkg/grpc/databroker/leaser.go index 4a07f6c74..dcc547670 100644 --- a/pkg/grpc/databroker/leaser.go +++ b/pkg/grpc/databroker/leaser.go @@ -71,13 +71,13 @@ func (locker *Leaser) Run(ctx context.Context) error { case err == nil: select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-retryTicker.C: } case errors.Is(err, retryableError{}): select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-time.After(bo.NextBackOff()): } default: diff --git a/pkg/grpc/databroker/leaser_test.go b/pkg/grpc/databroker/leaser_test.go index f1680b0b4..20d777188 100644 --- a/pkg/grpc/databroker/leaser_test.go +++ b/pkg/grpc/databroker/leaser_test.go @@ -169,7 +169,7 @@ func TestLeasers(t *testing.T) { fn2 := func(ctx context.Context) error { atomic.AddInt64(&counter, 10) <-ctx.Done() - return ctx.Err() + return context.Cause(ctx) } leaser := databroker.NewLeasers("TEST", time.Second*30, client, fn1, fn2) err := leaser.Run(context.Background()) diff --git a/pkg/grpc/databroker/reconciler.go b/pkg/grpc/databroker/reconciler.go index ff4e93bec..c11604087 100644 --- a/pkg/grpc/databroker/reconciler.go +++ b/pkg/grpc/databroker/reconciler.go @@ -110,7 +110,7 @@ func (r *Reconciler) reconcileLoop(ctx context.Context) error { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-r.trigger: } } diff --git a/pkg/grpc/databroker/sync_test.go b/pkg/grpc/databroker/sync_test.go index 85a454876..f047b3058 100644 --- a/pkg/grpc/databroker/sync_test.go +++ b/pkg/grpc/databroker/sync_test.go @@ -24,7 +24,7 @@ func Test_SyncLatestRecords(t *testing.T) { defer clearTimeout() cc := testutil.NewGRPCServer(t, func(s *grpc.Server) { - databrokerpb.RegisterDataBrokerServiceServer(s, databroker.New()) + databrokerpb.RegisterDataBrokerServiceServer(s, databroker.New(ctx)) }) c := databrokerpb.NewDataBrokerServiceClient(cc) diff --git a/pkg/grpc/databroker/syncer.go b/pkg/grpc/databroker/syncer.go index db4256889..769fd3425 100644 --- a/pkg/grpc/databroker/syncer.go +++ b/pkg/grpc/databroker/syncer.go @@ -72,8 +72,8 @@ type Syncer struct { } // NewSyncer creates a new Syncer. -func NewSyncer(id string, handler SyncerHandler, options ...SyncerOption) *Syncer { - closeCtx, closeCtxCancel := context.WithCancel(context.Background()) +func NewSyncer(ctx context.Context, id string, handler SyncerHandler, options ...SyncerOption) *Syncer { + closeCtx, closeCtxCancel := context.WithCancel(context.WithoutCancel(ctx)) bo := backoff.NewExponentialBackOff() bo.MaxElapsedTime = 0 @@ -120,7 +120,7 @@ func (syncer *Syncer) Run(ctx context.Context) error { log.Ctx(ctx).Error().Err(err).Msg("sync") select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-time.After(syncer.backoff.NextBackOff()): } } @@ -133,6 +133,9 @@ func (syncer *Syncer) init(ctx context.Context) error { Type: syncer.cfg.typeURL, }) if err != nil { + if status.Code(err) == codes.Canceled && ctx.Err() != nil { + err = fmt.Errorf("%w: %w", err, context.Cause(ctx)) + } return fmt.Errorf("error during initial sync: %w", err) } syncer.backoff.Reset() @@ -167,6 +170,9 @@ func (syncer *Syncer) sync(ctx context.Context) error { syncer.serverVersion = 0 return nil } else if err != nil { + if status.Code(err) == codes.Canceled && ctx.Err() != nil { + err = fmt.Errorf("%w: %w", err, context.Cause(ctx)) + } return fmt.Errorf("error receiving sync record: %w", err) } diff --git a/pkg/grpc/databroker/syncer_test.go b/pkg/grpc/databroker/syncer_test.go index ee1bc6e23..583c4b852 100644 --- a/pkg/grpc/databroker/syncer_test.go +++ b/pkg/grpc/databroker/syncer_test.go @@ -157,7 +157,7 @@ func TestSyncer(t *testing.T) { clearCh := make(chan struct{}) updateCh := make(chan []*Record) - syncer := NewSyncer("test", testSyncerHandler{ + syncer := NewSyncer(ctx, "test", testSyncerHandler{ getDataBrokerServiceClient: func() DataBrokerServiceClient { return NewDataBrokerServiceClient(gc) }, diff --git a/pkg/identity/legacymanager/manager.go b/pkg/identity/legacymanager/manager.go index 9641fc47d..0bf68b525 100644 --- a/pkg/identity/legacymanager/manager.go +++ b/pkg/identity/legacymanager/manager.go @@ -124,13 +124,13 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords // wait for initial sync select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-clear: mgr.reset() } select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case msg := <-update: mgr.onUpdateRecords(ctx, msg) } @@ -150,7 +150,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords for { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-clear: mgr.reset() case msg := <-update: diff --git a/pkg/identity/legacymanager/sync.go b/pkg/identity/legacymanager/sync.go index 94810a749..00703738a 100644 --- a/pkg/identity/legacymanager/sync.go +++ b/pkg/identity/legacymanager/sync.go @@ -17,7 +17,7 @@ type dataBrokerSyncer struct { } func newDataBrokerSyncer( - _ context.Context, + ctx context.Context, cfg *atomicutil.Value[*config], update chan<- updateRecordsMessage, clear chan<- struct{}, @@ -28,7 +28,7 @@ func newDataBrokerSyncer( update: update, clear: clear, } - syncer.syncer = databroker.NewSyncer("identity_manager", syncer) + syncer.syncer = databroker.NewSyncer(ctx, "identity_manager", syncer) return syncer } diff --git a/pkg/identity/manager/sync.go b/pkg/identity/manager/sync.go index 1eb71d49f..5d90b252d 100644 --- a/pkg/identity/manager/sync.go +++ b/pkg/identity/manager/sync.go @@ -16,7 +16,7 @@ type sessionSyncerHandler struct { } func newSessionSyncer(ctx context.Context, mgr *Manager) *databroker.Syncer { - return databroker.NewSyncer("identity_manager/sessions", sessionSyncerHandler{baseCtx: ctx, mgr: mgr}, + return databroker.NewSyncer(ctx, "identity_manager/sessions", sessionSyncerHandler{baseCtx: ctx, mgr: mgr}, databroker.WithTypeURL(grpcutil.GetTypeURL(new(session.Session)))) } @@ -50,7 +50,7 @@ type userSyncerHandler struct { } func newUserSyncer(ctx context.Context, mgr *Manager) *databroker.Syncer { - return databroker.NewSyncer("identity_manager/users", userSyncerHandler{baseCtx: ctx, mgr: mgr}, + return databroker.NewSyncer(ctx, "identity_manager/users", userSyncerHandler{baseCtx: ctx, mgr: mgr}, databroker.WithTypeURL(grpcutil.GetTypeURL(new(user.User)))) } diff --git a/pkg/storage/postgres/backend.go b/pkg/storage/postgres/backend.go index 8e3530879..91e9d7cfb 100644 --- a/pkg/storage/postgres/backend.go +++ b/pkg/storage/postgres/backend.go @@ -37,14 +37,14 @@ type Backend struct { } // New creates a new Backend. -func New(dsn string, options ...Option) *Backend { +func New(ctx context.Context, dsn string, options ...Option) *Backend { backend := &Backend{ cfg: getConfig(options...), dsn: dsn, onRecordChange: signal.New(), onServiceChange: signal.New(), } - backend.closeCtx, backend.close = context.WithCancel(context.Background()) + backend.closeCtx, backend.close = context.WithCancel(ctx) go backend.doPeriodically(func(ctx context.Context) error { _, pool, err := backend.init(ctx) diff --git a/pkg/storage/postgres/backend_test.go b/pkg/storage/postgres/backend_test.go index b10f16ed4..b41b985a4 100644 --- a/pkg/storage/postgres/backend_test.go +++ b/pkg/storage/postgres/backend_test.go @@ -34,7 +34,7 @@ func TestBackend(t *testing.T) { defer clearTimeout() require.NoError(t, testutil.WithTestPostgres(func(dsn string) error { - backend := New(dsn) + backend := New(ctx, dsn) defer backend.Close() t.Run("put", func(t *testing.T) { diff --git a/pkg/storage/postgres/registry_test.go b/pkg/storage/postgres/registry_test.go index 51cf9b7f3..bc57bc839 100644 --- a/pkg/storage/postgres/registry_test.go +++ b/pkg/storage/postgres/registry_test.go @@ -42,7 +42,7 @@ func TestRegistry(t *testing.T) { defer clearTimeout() require.NoError(t, testutil.WithTestPostgres(func(dsn string) error { - backend := New(dsn) + backend := New(ctx, dsn) defer backend.Close() eg, ctx := errgroup.WithContext(ctx) @@ -53,7 +53,7 @@ func TestRegistry(t *testing.T) { send: func(res *registry.ServiceList) error { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case listResults <- res: } return nil @@ -73,7 +73,7 @@ func TestRegistry(t *testing.T) { eg.Go(func() error { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case res := <-listResults: testutil.AssertProtoEqual(t, ®istry.ServiceList{}, res) } @@ -92,7 +92,7 @@ func TestRegistry(t *testing.T) { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case res := <-listResults: testutil.AssertProtoEqual(t, ®istry.ServiceList{ Services: []*registry.Service{ diff --git a/proxy/data_test.go b/proxy/data_test.go index 7b2e090f4..2e74036dc 100644 --- a/proxy/data_test.go +++ b/proxy/data_test.go @@ -32,14 +32,14 @@ func Test_getUserInfoData(t *testing.T) { defer clearTimeout() cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) { - databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New()) + databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx)) }) t.Cleanup(func() { cc.Close() }) client := databrokerpb.NewDataBrokerServiceClient(cc) opts := testOptions(t) - proxy, err := New(&config.Config{Options: opts}) + proxy, err := New(ctx, &config.Config{Options: opts}) require.NoError(t, err) proxy.state.Load().dataBrokerClient = client diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index 49b1f5287..2817b22d0 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -2,6 +2,7 @@ package proxy import ( "bytes" + "context" "io" "net/http" "net/http/httptest" @@ -36,7 +37,7 @@ func TestProxy_SignOut(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { opts := testOptions(t) - p, err := New(&config.Config{Options: opts}) + p, err := New(context.Background(), &config.Config{Options: opts}) if err != nil { t.Fatal(err) } @@ -129,7 +130,7 @@ func TestProxy_ProgrammaticLogin(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p, err := New(&config.Config{Options: tt.options}) + p, err := New(context.Background(), &config.Config{Options: tt.options}) if err != nil { t.Fatal(err) } @@ -270,7 +271,7 @@ func TestLoadSessionState(t *testing.T) { t.Parallel() opts := testOptions(t) - proxy, err := New(&config.Config{Options: opts}) + proxy, err := New(context.Background(), &config.Config{Options: opts}) require.NoError(t, err) r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) @@ -285,7 +286,7 @@ func TestLoadSessionState(t *testing.T) { t.Parallel() opts := testOptions(t) - proxy, err := New(&config.Config{Options: opts}) + proxy, err := New(context.Background(), &config.Config{Options: opts}) require.NoError(t, err) session := encodeSession(t, opts, &sessions.State{ @@ -308,7 +309,7 @@ func TestLoadSessionState(t *testing.T) { t.Parallel() opts := testOptions(t) - proxy, err := New(&config.Config{Options: opts}) + proxy, err := New(context.Background(), &config.Config{Options: opts}) require.NoError(t, err) session := encodeSession(t, opts, &sessions.State{ diff --git a/proxy/proxy.go b/proxy/proxy.go index a5ffc593d..58f795860 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -60,8 +60,8 @@ type Proxy struct { // New takes a Proxy service from options and a validation function. // Function returns an error if options fail to validate. -func New(cfg *config.Config) (*Proxy, error) { - state, err := newProxyStateFromConfig(cfg) +func New(ctx context.Context, cfg *config.Config) (*Proxy, error) { + state, err := newProxyStateFromConfig(ctx, cfg) if err != nil { return nil, err } @@ -71,7 +71,7 @@ func New(cfg *config.Config) (*Proxy, error) { currentOptions: config.NewAtomicOptions(), currentRouter: atomicutil.NewValue(httputil.NewRouter()), } - p.OnConfigChange(context.Background(), cfg) + p.OnConfigChange(ctx, cfg) p.webauthn = webauthn.New(p.getWebauthnState) metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 { @@ -87,25 +87,25 @@ func (p *Proxy) Mount(r *mux.Router) { } // OnConfigChange updates internal structures based on config.Options -func (p *Proxy) OnConfigChange(_ context.Context, cfg *config.Config) { +func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) { if p == nil { return } p.currentOptions.Store(cfg.Options) - if err := p.setHandlers(cfg.Options); err != nil { - log.Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings") + if err := p.setHandlers(ctx, cfg.Options); err != nil { + log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings") } - if state, err := newProxyStateFromConfig(cfg); err != nil { - log.Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings") + if state, err := newProxyStateFromConfig(ctx, cfg); err != nil { + log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings") } else { p.state.Store(state) } } -func (p *Proxy) setHandlers(opts *config.Options) error { +func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error { if opts.NumPolicies() == 0 { - log.Info().Msg("proxy: configuration has no policies") + log.Ctx(ctx).Info().Msg("proxy: configuration has no policies") } r := httputil.NewRouter() r.NotFoundHandler = httputil.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) error { diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index c17ff8d66..da49e0161 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -104,7 +104,7 @@ func TestNew(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := New(&config.Config{Options: tt.opts}) + got, err := New(context.Background(), &config.Config{Options: tt.opts}) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return @@ -197,7 +197,7 @@ func Test_UpdateOptions(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p, err := New(&config.Config{Options: tt.originalOptions}) + p, err := New(context.Background(), &config.Config{Options: tt.originalOptions}) if err != nil { t.Fatal(err) } diff --git a/proxy/state.go b/proxy/state.go index cd5bc22a3..5a7727e13 100644 --- a/proxy/state.go +++ b/proxy/state.go @@ -31,7 +31,7 @@ type proxyState struct { authenticateFlow authenticateFlow } -func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { +func newProxyStateFromConfig(ctx context.Context, cfg *config.Config) (*proxyState, error) { err := ValidateOptions(cfg.Options) if err != nil { return nil, err @@ -57,7 +57,7 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { return nil, err } - dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{ + dataBrokerConn, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{ OutboundPort: cfg.OutboundPort, InstallationID: cfg.Options.InstallationID, ServiceName: cfg.Options.Services, @@ -71,10 +71,10 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist if cfg.Options.UseStatelessAuthenticateFlow() { - state.authenticateFlow, err = authenticateflow.NewStateless( + state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, state.sessionStore, nil, nil, nil) } else { - state.authenticateFlow, err = authenticateflow.NewStateful(cfg, state.sessionStore) + state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, state.sessionStore) } if err != nil { return nil, err