Fix many instances of contexts and loggers not being propagated (#5340)

This also replaces instances where we manually write "return ctx.Err()"
with "return context.Cause(ctx)" which is functionally identical, but
will also correctly propagate cause errors if present.
This commit is contained in:
Joe Kralicky 2024-10-25 14:50:56 -04:00 committed by GitHub
parent e1880ba20f
commit fe31799eb5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
77 changed files with 297 additions and 221 deletions

View file

@ -44,7 +44,7 @@ type Authenticate struct {
} }
// New validates and creates a new authenticate service from a set of Options. // New validates and creates a new authenticate service from a set of Options.
func New(cfg *config.Config, options ...Option) (*Authenticate, error) { func New(ctx context.Context, cfg *config.Config, options ...Option) (*Authenticate, error) {
authenticateConfig := getAuthenticateConfig(options...) authenticateConfig := getAuthenticateConfig(options...)
a := &Authenticate{ a := &Authenticate{
cfg: authenticateConfig, cfg: authenticateConfig,
@ -54,7 +54,7 @@ func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
a.options.Store(cfg.Options) a.options.Store(cfg.Options)
state, err := newAuthenticateStateFromConfig(cfg, authenticateConfig) state, err := newAuthenticateStateFromConfig(ctx, cfg, authenticateConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -70,7 +70,7 @@ func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) {
} }
a.options.Store(cfg.Options) a.options.Store(cfg.Options)
if state, err := newAuthenticateStateFromConfig(cfg, a.cfg); err != nil { if state, err := newAuthenticateStateFromConfig(ctx, cfg, a.cfg); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authenticate: failed to update state") log.Ctx(ctx).Error().Err(err).Msg("authenticate: failed to update state")
} else { } else {
a.state.Store(state) a.state.Store(state)

View file

@ -1,6 +1,7 @@
package authenticate package authenticate
import ( import (
"context"
"testing" "testing"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
@ -106,7 +107,7 @@ func TestNew(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := New(&config.Config{Options: tt.opts}) _, err := New(context.Background(), &config.Config{Options: tt.opts})
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return return

View file

@ -64,6 +64,7 @@ func newAuthenticateState() *authenticateState {
} }
func newAuthenticateStateFromConfig( func newAuthenticateStateFromConfig(
ctx context.Context,
cfg *config.Config, authenticateConfig *authenticateConfig, cfg *config.Config, authenticateConfig *authenticateConfig,
) (*authenticateState, error) { ) (*authenticateState, error) {
err := ValidateOptions(cfg.Options) err := ValidateOptions(cfg.Options)
@ -145,7 +146,7 @@ func newAuthenticateStateFromConfig(
} }
if cfg.Options.UseStatelessAuthenticateFlow() { if cfg.Options.UseStatelessAuthenticateFlow() {
state.flow, err = authenticateflow.NewStateless( state.flow, err = authenticateflow.NewStateless(ctx,
cfg, cfg,
cookieStore, cookieStore,
authenticateConfig.getIdentityProvider, authenticateConfig.getIdentityProvider,
@ -153,7 +154,7 @@ func newAuthenticateStateFromConfig(
authenticateConfig.authEventFn, authenticateConfig.authEventFn,
) )
} else { } else {
state.flow, err = authenticateflow.NewStateful(cfg, cookieStore) state.flow, err = authenticateflow.NewStateful(ctx, cfg, cookieStore)
} }
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -40,7 +40,7 @@ type Authorize struct {
} }
// New validates and creates a new Authorize service from a set of config options. // New validates and creates a new Authorize service from a set of config options.
func New(cfg *config.Config) (*Authorize, error) { func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
a := &Authorize{ a := &Authorize{
currentOptions: config.NewAtomicOptions(), currentOptions: config.NewAtomicOptions(),
store: store.New(), store: store.New(),
@ -48,7 +48,7 @@ func New(cfg *config.Config) (*Authorize, error) {
} }
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod) a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
state, err := newAuthorizeStateFromConfig(cfg, a.store, nil) state, err := newAuthorizeStateFromConfig(ctx, cfg, a.store, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -89,12 +89,13 @@ func validateOptions(o *config.Options) error {
// newPolicyEvaluator returns an policy evaluator. // newPolicyEvaluator returns an policy evaluator.
func newPolicyEvaluator( func newPolicyEvaluator(
ctx context.Context,
opts *config.Options, store *store.Store, previous *evaluator.Evaluator, opts *config.Options, store *store.Store, previous *evaluator.Evaluator,
) (*evaluator.Evaluator, error) { ) (*evaluator.Evaluator, error) {
metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 { metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 {
return int64(opts.NumPolicies()) return int64(opts.NumPolicies())
}) })
ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context { ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "authorize") return c.Str("service", "authorize")
}) })
ctx, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator") ctx, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator")
@ -150,7 +151,7 @@ func newPolicyEvaluator(
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) { func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
currentState := a.state.Load() currentState := a.state.Load()
a.currentOptions.Store(cfg.Options) a.currentOptions.Store(cfg.Options)
if state, err := newAuthorizeStateFromConfig(cfg, a.store, currentState.evaluator); err != nil { if state, err := newAuthorizeStateFromConfig(ctx, cfg, a.store, currentState.evaluator); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state") log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
} else { } else {
a.state.Store(state) a.state.Store(state)

View file

@ -82,7 +82,7 @@ func TestNew(t *testing.T) {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
_, err := New(&config.Config{Options: &tt.config}) _, err := New(context.Background(), &config.Config{Options: &tt.config})
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -114,7 +114,7 @@ func TestAuthorize_OnConfigChange(t *testing.T) {
SharedKey: tc.SharedKey, SharedKey: tc.SharedKey,
Policies: tc.Policies, Policies: tc.Policies,
} }
a, err := New(&config.Config{Options: o}) a, err := New(context.Background(), &config.Config{Options: o})
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, a) require.NotNil(t, a)
@ -185,7 +185,7 @@ func TestNewPolicyEvaluator_addDefaultClientCertificateRule(t *testing.T) {
c.opts.Policies = []config.Policy{{ c.opts.Policies = []config.Policy{{
To: mustParseWeightedURLs(t, "http://example.com"), To: mustParseWeightedURLs(t, "http://example.com"),
}} }}
e, err := newPolicyEvaluator(c.opts, store, nil) e, err := newPolicyEvaluator(context.Background(), c.opts, store, nil)
require.NoError(t, err) require.NoError(t, err)
r, err := e.Evaluate(context.Background(), &evaluator.Request{ r, err := e.Evaluate(context.Background(), &evaluator.Request{

View file

@ -34,7 +34,7 @@ func TestAuthorize_handleResult(t *testing.T) {
t.Cleanup(authnSrv.Close) t.Cleanup(authnSrv.Close)
opt.AuthenticateURLString = authnSrv.URL opt.AuthenticateURLString = authnSrv.URL
a, err := New(&config.Config{Options: opt}) a, err := New(context.Background(), &config.Config{Options: opt})
require.NoError(t, err) require.NoError(t, err)
t.Run("user-unauthenticated", func(t *testing.T) { t.Run("user-unauthenticated", func(t *testing.T) {
@ -129,7 +129,7 @@ func TestAuthorize_okResponse(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(opt) a.currentOptions.Store(opt)
a.store = store.New() a.store = store.New()
pe, err := newPolicyEvaluator(opt, a.store, nil) pe, err := newPolicyEvaluator(context.Background(), opt, a.store, nil)
require.NoError(t, err) require.NoError(t, err)
a.state.Load().evaluator = pe a.state.Load().evaluator = pe
@ -327,7 +327,7 @@ func TestRequireLogin(t *testing.T) {
t.Cleanup(authnSrv.Close) t.Cleanup(authnSrv.Close)
opt.AuthenticateURLString = authnSrv.URL opt.AuthenticateURLString = authnSrv.URL
a, err := New(&config.Config{Options: opt}) a, err := New(context.Background(), &config.Config{Options: opt})
require.NoError(t, err) require.NoError(t, err)
t.Run("accept empty", func(t *testing.T) { t.Run("accept empty", func(t *testing.T) {

View file

@ -65,7 +65,7 @@ func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) {
t.Cleanup(clearTimeout) t.Cleanup(clearTimeout)
opt := config.NewDefaultOptions() opt := config.NewDefaultOptions()
a, err := New(&config.Config{Options: opt}) a, err := New(context.Background(), &config.Config{Options: opt})
require.NoError(t, err) require.NoError(t, err)
s1 := &session.Session{Id: "s1", ExpiresAt: timestamppb.New(time.Now().Add(-time.Second))} s1 := &session.Session{Id: "s1", ExpiresAt: timestamppb.New(time.Now().Add(-time.Second))}

View file

@ -108,7 +108,7 @@ func New(
) (*Evaluator, error) { ) (*Evaluator, error) {
cfg := getConfig(options...) cfg := getConfig(options...)
err := updateStore(store, cfg) err := updateStore(ctx, store, cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -325,8 +325,8 @@ func (e *Evaluator) getClientCA(policy *config.Policy) (string, error) {
return string(e.clientCA), nil return string(e.clientCA), nil
} }
func updateStore(store *store.Store, cfg *evaluatorConfig) error { func updateStore(ctx context.Context, store *store.Store, cfg *evaluatorConfig) error {
jwk, err := getJWK(cfg) jwk, err := getJWK(ctx, cfg)
if err != nil { if err != nil {
return fmt.Errorf("authorize: couldn't create signer: %w", err) return fmt.Errorf("authorize: couldn't create signer: %w", err)
} }
@ -341,7 +341,7 @@ func updateStore(store *store.Store, cfg *evaluatorConfig) error {
return nil return nil
} }
func getJWK(cfg *evaluatorConfig) (*jose.JSONWebKey, error) { func getJWK(ctx context.Context, cfg *evaluatorConfig) (*jose.JSONWebKey, error) {
var decodedCert []byte var decodedCert []byte
// if we don't have a signing key, generate one // if we don't have a signing key, generate one
if len(cfg.SigningKey) == 0 { if len(cfg.SigningKey) == 0 {
@ -361,7 +361,7 @@ func getJWK(cfg *evaluatorConfig) (*jose.JSONWebKey, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("couldn't generate signing key: %w", err) return nil, fmt.Errorf("couldn't generate signing key: %w", err)
} }
log.Info().Str("Algorithm", jwk.Algorithm). log.Ctx(ctx).Info().Str("Algorithm", jwk.Algorithm).
Str("KeyID", jwk.KeyID). Str("KeyID", jwk.KeyID).
Interface("Public Key", jwk.Public()). Interface("Public Key", jwk.Public()).
Msg("authorize: signing key") Msg("authorize: signing key")

View file

@ -147,7 +147,8 @@ func NewPolicyEvaluator(
// for each script, create a rego and prepare a query. // for each script, create a rego and prepare a query.
for i := range e.queries { for i := range e.queries {
log.Ctx(ctx).Debug(). log.Ctx(ctx).
Trace().
Str("script", e.queries[i].script). Str("script", e.queries[i].script).
Str("from", configPolicy.From). Str("from", configPolicy.From).
Interface("to", configPolicy.To). Interface("to", configPolicy.To).

View file

@ -33,6 +33,7 @@ type authorizeState struct {
} }
func newAuthorizeStateFromConfig( func newAuthorizeStateFromConfig(
ctx context.Context,
cfg *config.Config, store *store.Store, previousPolicyEvaluator *evaluator.Evaluator, cfg *config.Config, store *store.Store, previousPolicyEvaluator *evaluator.Evaluator,
) (*authorizeState, error) { ) (*authorizeState, error) {
if err := validateOptions(cfg.Options); err != nil { if err := validateOptions(cfg.Options); err != nil {
@ -43,7 +44,7 @@ func newAuthorizeStateFromConfig(
var err error var err error
state.evaluator, err = newPolicyEvaluator(cfg.Options, store, previousPolicyEvaluator) state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousPolicyEvaluator)
if err != nil { if err != nil {
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err) return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
} }
@ -58,7 +59,7 @@ func newAuthorizeStateFromConfig(
return nil, err return nil, err
} }
cc, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{ cc, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{
OutboundPort: cfg.OutboundPort, OutboundPort: cfg.OutboundPort,
InstallationID: cfg.Options.InstallationID, InstallationID: cfg.Options.InstallationID,
ServiceName: cfg.Options.Services, ServiceName: cfg.Options.Services,
@ -84,9 +85,9 @@ func newAuthorizeStateFromConfig(
} }
if cfg.Options.UseStatelessAuthenticateFlow() { if cfg.Options.UseStatelessAuthenticateFlow() {
state.authenticateFlow, err = authenticateflow.NewStateless(cfg, nil, nil, nil, nil) state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, nil, nil, nil, nil)
} else { } else {
state.authenticateFlow, err = authenticateflow.NewStateful(cfg, nil) state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, nil)
} }
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -54,7 +54,7 @@ func run(ctx context.Context, configFile string) error {
var src config.Source var src config.Source
src, err := config.NewFileOrEnvironmentSource(configFile, files.FullVersion()) src, err := config.NewFileOrEnvironmentSource(ctx, configFile, files.FullVersion())
if err != nil { if err != nil {
return err return err
} }

View file

@ -103,9 +103,10 @@ type FileOrEnvironmentSource struct {
// NewFileOrEnvironmentSource creates a new FileOrEnvironmentSource. // NewFileOrEnvironmentSource creates a new FileOrEnvironmentSource.
func NewFileOrEnvironmentSource( func NewFileOrEnvironmentSource(
ctx context.Context,
configFile, envoyVersion string, configFile, envoyVersion string,
) (*FileOrEnvironmentSource, error) { ) (*FileOrEnvironmentSource, error) {
ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context { ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("config_file_source", configFile) return c.Str("config_file_source", configFile)
}) })

View file

@ -137,7 +137,7 @@ runtime_flags:
require.NoError(t, err) require.NoError(t, err)
var src Source var src Source
src, err = NewFileOrEnvironmentSource(configFilePath, "") src, err = NewFileOrEnvironmentSource(context.Background(), configFilePath, "")
require.NoError(t, err) require.NoError(t, err)
src = NewFileWatcherSource(context.Background(), src) src = NewFileWatcherSource(context.Background(), src)

View file

@ -247,7 +247,7 @@ func (b *Builder) buildInternalTransportSocket(
b.buildSubjectAltNameMatcher(endpoint, cfg.Options.OverrideCertificateName), b.buildSubjectAltNameMatcher(endpoint, cfg.Options.OverrideCertificateName),
}, },
} }
bs, err := getCombinedCertificateAuthority(cfg) bs, err := getCombinedCertificateAuthority(ctx, cfg)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found") log.Ctx(ctx).Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else { } else {
@ -347,7 +347,7 @@ func (b *Builder) buildPolicyValidationContext(
} }
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs) validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
} else { } else {
bs, err := getCombinedCertificateAuthority(cfg) bs, err := getCombinedCertificateAuthority(ctx, cfg)
if err != nil { if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found") log.Ctx(ctx).Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else { } else {

View file

@ -43,14 +43,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-57394a4e5157303436544830.pem") customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-57394a4e5157303436544830.pem")
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}}) rootCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{}})
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename() rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
o1 := config.NewDefaultOptions() o1 := config.NewDefaultOptions()
o2 := config.NewDefaultOptions() o2 := config.NewDefaultOptions()
o2.CA = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}) o2.CA = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0})
combinedCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{CA: o2.CA}}) combinedCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{CA: o2.CA}})
combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename() combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename()
t.Run("insecure", func(t *testing.T) { t.Run("insecure", func(t *testing.T) {
@ -522,7 +522,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
func Test_buildCluster(t *testing.T) { func Test_buildCluster(t *testing.T) {
ctx := context.Background() ctx := context.Background()
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}}) rootCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{}})
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename() rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
o1 := config.NewDefaultOptions() o1 := config.NewDefaultOptions()
t.Run("insecure", func(t *testing.T) { t.Run("insecure", func(t *testing.T) {

View file

@ -189,7 +189,7 @@ var rootCABundle struct {
value string value string
} }
func getRootCertificateAuthority() (string, error) { func getRootCertificateAuthority(ctx context.Context) (string, error) {
rootCABundle.Do(func() { rootCABundle.Do(func() {
// from https://www.envoyproxy.io/docs/envoy/latest/intro/arch_overview/security/ssl#arch-overview-ssl-enabling-verification // from https://www.envoyproxy.io/docs/envoy/latest/intro/arch_overview/security/ssl#arch-overview-ssl-enabling-verification
knownRootLocations := []string{ knownRootLocations := []string{
@ -207,10 +207,10 @@ func getRootCertificateAuthority() (string, error) {
} }
} }
if rootCABundle.value == "" { if rootCABundle.value == "" {
log.Error().Strs("known-locations", knownRootLocations). log.Ctx(ctx).Error().Strs("known-locations", knownRootLocations).
Msgf("no root certificates were found in any of the known locations") Msgf("no root certificates were found in any of the known locations")
} else { } else {
log.Info().Msgf("using %s as the system root certificate authority bundle", rootCABundle.value) log.Ctx(ctx).Info().Msgf("using %s as the system root certificate authority bundle", rootCABundle.value)
} }
}) })
if rootCABundle.value == "" { if rootCABundle.value == "" {
@ -219,8 +219,8 @@ func getRootCertificateAuthority() (string, error) {
return rootCABundle.value, nil return rootCABundle.value, nil
} }
func getCombinedCertificateAuthority(cfg *config.Config) ([]byte, error) { func getCombinedCertificateAuthority(ctx context.Context, cfg *config.Config) ([]byte, error) {
rootFile, err := getRootCertificateAuthority() rootFile, err := getRootCertificateAuthority(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -45,7 +45,7 @@ type DataBroker struct {
} }
// New creates a new databroker service. // New creates a new databroker service.
func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) { func New(ctx context.Context, cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) {
localListener, err := net.Listen("tcp", "127.0.0.1:0") localListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
return nil, err return nil, err
@ -61,8 +61,8 @@ func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) {
// No metrics handler because we have one in the control plane. Add one // No metrics handler because we have one in the control plane. Add one
// if we no longer register with that grpc Server // if we no longer register with that grpc Server
localGRPCServer := grpc.NewServer( localGRPCServer := grpc.NewServer(
grpc.StreamInterceptor(si), grpc.ChainStreamInterceptor(log.StreamServerInterceptor(log.Ctx(ctx)), si),
grpc.UnaryInterceptor(ui), grpc.ChainUnaryInterceptor(log.UnaryServerInterceptor(log.Ctx(ctx)), ui),
) )
sharedKey, err := cfg.Options.GetSharedKey() sharedKey, err := cfg.Options.GetSharedKey()
@ -79,7 +79,7 @@ func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) {
grpc.WithStatsHandler(clientStatsHandler.Handler), grpc.WithStatsHandler(clientStatsHandler.Handler),
} }
ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context { ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "databroker").Str("config_source", "bootstrap") return c.Str("service", "databroker").Str("config_source", "bootstrap")
}) })
localGRPCConnection, err := grpc.DialContext( localGRPCConnection, err := grpc.DialContext(
@ -91,7 +91,7 @@ func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) {
return nil, err return nil, err
} }
dataBrokerServer, err := newDataBrokerServer(cfg) dataBrokerServer, err := newDataBrokerServer(ctx, cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,6 +1,7 @@
package databroker package databroker
import ( import (
"context"
"testing" "testing"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
@ -20,7 +21,7 @@ func TestNew(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tt.opts.Provider = "google" tt.opts.Provider = "google"
_, err := New(&config.Config{Options: &tt.opts}, events.New()) _, err := New(context.Background(), &config.Config{Options: &tt.opts}, events.New())
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return return

View file

@ -23,7 +23,7 @@ type dataBrokerServer struct {
} }
// newDataBrokerServer creates a new databroker service server. // newDataBrokerServer creates a new databroker service server.
func newDataBrokerServer(cfg *config.Config) (*dataBrokerServer, error) { func newDataBrokerServer(ctx context.Context, cfg *config.Config) (*dataBrokerServer, error) {
srv := &dataBrokerServer{ srv := &dataBrokerServer{
sharedKey: atomicutil.NewValue([]byte{}), sharedKey: atomicutil.NewValue([]byte{}),
} }
@ -33,7 +33,7 @@ func newDataBrokerServer(cfg *config.Config) (*dataBrokerServer, error) {
return nil, err return nil, err
} }
srv.server = databroker.New(opts...) srv.server = databroker.New(ctx, opts...)
srv.setKey(cfg) srv.setKey(cfg)
return srv, nil return srv, nil
} }
@ -46,7 +46,7 @@ func (srv *dataBrokerServer) OnConfigChange(ctx context.Context, cfg *config.Con
return return
} }
srv.server.UpdateConfig(opts...) srv.server.UpdateConfig(ctx, opts...)
srv.setKey(cfg) srv.setKey(cfg)
} }

View file

@ -29,7 +29,7 @@ var lis *bufconn.Listener
func init() { func init() {
lis = bufconn.Listen(bufSize) lis = bufconn.Listen(bufSize)
s := grpc.NewServer() s := grpc.NewServer()
internalSrv := internal_databroker.New() internalSrv := internal_databroker.New(context.Background())
srv := &dataBrokerServer{server: internalSrv, sharedKey: atomicutil.NewValue([]byte{})} srv := &dataBrokerServer{server: internalSrv, sharedKey: atomicutil.NewValue([]byte{})}
databroker.RegisterDataBrokerServiceServer(s, srv) databroker.RegisterDataBrokerServiceServer(s, srv)

View file

@ -163,7 +163,7 @@ func waitForHealthy(ctx context.Context) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case <-ticker.C: case <-ticker.C:
} }
} }

View file

@ -56,7 +56,7 @@ type Stateful struct {
// NewStateful initializes the authentication flow for the given configuration // NewStateful initializes the authentication flow for the given configuration
// and session store. // and session store.
func NewStateful(cfg *config.Config, sessionStore sessions.SessionStore) (*Stateful, error) { func NewStateful(ctx context.Context, cfg *config.Config, sessionStore sessions.SessionStore) (*Stateful, error) {
s := &Stateful{ s := &Stateful{
sessionDuration: cfg.Options.CookieExpire, sessionDuration: cfg.Options.CookieExpire,
sessionStore: sessionStore, sessionStore: sessionStore,
@ -88,7 +88,7 @@ func NewStateful(cfg *config.Config, sessionStore sessions.SessionStore) (*State
s.defaultIdentityProviderID = idp.GetId() s.defaultIdentityProviderID = idp.GetId()
} }
dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), dataBrokerConn, err := outboundGRPCConnection.Get(ctx,
&grpc.OutboundOptions{ &grpc.OutboundOptions{
OutboundPort: cfg.OutboundPort, OutboundPort: cfg.OutboundPort,
InstallationID: cfg.Options.InstallationID, InstallationID: cfg.Options.InstallationID,

View file

@ -69,7 +69,7 @@ func TestStatefulSignIn(t *testing.T) {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
sessionStore := &mstore.Store{SaveError: tt.saveError} sessionStore := &mstore.Store{SaveError: tt.saveError}
flow, err := NewStateful(&config.Config{Options: opts}, sessionStore) flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, sessionStore)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -123,7 +123,7 @@ func TestStatefulAuthenticateSignInURL(t *testing.T) {
opts.AuthenticateURLString = "https://authenticate.example.com" opts.AuthenticateURLString = "https://authenticate.example.com"
key := cryptutil.NewKey() key := cryptutil.NewKey()
opts.SharedKey = base64.StdEncoding.EncodeToString(key) opts.SharedKey = base64.StdEncoding.EncodeToString(key)
flow, err := NewStateful(&config.Config{Options: opts}, nil) flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, nil)
require.NoError(t, err) require.NoError(t, err)
t.Run("NilQueryParams", func(t *testing.T) { t.Run("NilQueryParams", func(t *testing.T) {
@ -238,7 +238,7 @@ func TestStatefulCallback(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
flow, err := NewStateful(&config.Config{Options: opts}, tt.sessionStore) flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, tt.sessionStore)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -289,7 +289,7 @@ func TestStatefulCallback(t *testing.T) {
func TestStatefulRevokeSession(t *testing.T) { func TestStatefulRevokeSession(t *testing.T) {
opts := config.NewDefaultOptions() opts := config.NewDefaultOptions()
flow, err := NewStateful(&config.Config{Options: opts}, nil) flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, nil)
require.NoError(t, err) require.NoError(t, err)
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
@ -367,7 +367,7 @@ func TestPersistSession(t *testing.T) {
opts := config.NewDefaultOptions() opts := config.NewDefaultOptions()
opts.CookieExpire = 4 * time.Hour opts.CookieExpire = 4 * time.Hour
flow, err := NewStateful(&config.Config{Options: opts}, nil) flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, nil)
require.NoError(t, err) require.NoError(t, err)
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)

View file

@ -64,6 +64,7 @@ type Stateless struct {
// NewStateless initializes the authentication flow for the given // NewStateless initializes the authentication flow for the given
// configuration, session store, and additional options. // configuration, session store, and additional options.
func NewStateless( func NewStateless(
ctx context.Context,
cfg *config.Config, cfg *config.Config,
sessionStore sessions.SessionStore, sessionStore sessions.SessionStore,
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error), getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error),
@ -131,7 +132,7 @@ func NewStateless(
return nil, fmt.Errorf("authorize: get authenticate JWKS key fetcher: %w", err) return nil, fmt.Errorf("authorize: get authenticate JWKS key fetcher: %w", err)
} }
dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{ dataBrokerConn, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{
OutboundPort: cfg.OutboundPort, OutboundPort: cfg.OutboundPort,
InstallationID: cfg.Options.InstallationID, InstallationID: cfg.Options.InstallationID,
ServiceName: cfg.Options.Services, ServiceName: cfg.Options.Services,

View file

@ -63,11 +63,12 @@ type Manager struct {
} }
// New creates a new autocert manager. // New creates a new autocert manager.
func New(src config.Source) (*Manager, error) { func New(ctx context.Context, src config.Source) (*Manager, error) {
return newManager(context.Background(), src, certmagic.DefaultACME, renewalInterval) return newManager(ctx, src, certmagic.DefaultACME, renewalInterval)
} }
func newManager(ctx context.Context, func newManager(
ctx context.Context,
src config.Source, src config.Source,
acmeTemplate certmagic.ACMEIssuer, acmeTemplate certmagic.ACMEIssuer,
checkInterval time.Duration, checkInterval time.Duration,
@ -96,12 +97,13 @@ func newManager(ctx context.Context,
if err != nil { if err != nil {
return nil, err return nil, err
} }
mgr.certmagic = certmagic.New(certmagic.NewCache(certmagic.CacheOptions{ cache := certmagic.NewCache(certmagic.CacheOptions{
GetConfigForCert: func(_ certmagic.Certificate) (*certmagic.Config, error) { GetConfigForCert: func(_ certmagic.Certificate) (*certmagic.Config, error) {
return mgr.certmagic, nil return mgr.certmagic, nil
}, },
Logger: logger, Logger: logger,
}), certmagic.Config{ })
mgr.certmagic = certmagic.New(cache, certmagic.Config{
Logger: logger, Logger: logger,
Storage: certmagicStorage, Storage: certmagicStorage,
}) })

View file

@ -316,7 +316,7 @@ func TestRedirect(t *testing.T) {
}, },
}, },
}) })
_, err = New(src) _, err = New(context.Background(), src)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return return
} }

View file

@ -47,7 +47,7 @@ func (l *locker) Lock(ctx context.Context, name string) error {
// wait // wait
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case <-time.After(lockPollInterval): case <-time.After(lockPollInterval):
} }
continue continue

View file

@ -71,7 +71,7 @@ func (srv *Server) getDataBrokerClient(ctx context.Context) (databrokerpb.DataBr
return nil, err return nil, err
} }
cc, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{ cc, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{
OutboundPort: cfg.OutboundPort, OutboundPort: cfg.OutboundPort,
InstallationID: cfg.Options.InstallationID, InstallationID: cfg.Options.InstallationID,
ServiceName: cfg.Options.Services, ServiceName: cfg.Options.Services,

View file

@ -8,6 +8,7 @@ import (
"github.com/CAFxX/httpcompression" "github.com/CAFxX/httpcompression"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/handlers"
@ -19,7 +20,7 @@ import (
"github.com/pomerium/pomerium/pkg/telemetry/requestid" "github.com/pomerium/pomerium/pkg/telemetry/requestid"
) )
func (srv *Server) addHTTPMiddleware(root *mux.Router, _ *config.Config) { func (srv *Server) addHTTPMiddleware(root *mux.Router, logger *zerolog.Logger, _ *config.Config) {
compressor, err := httpcompression.DefaultAdapter() compressor, err := httpcompression.DefaultAdapter()
if err != nil { if err != nil {
panic(err) panic(err)
@ -28,7 +29,7 @@ func (srv *Server) addHTTPMiddleware(root *mux.Router, _ *config.Config) {
root.Use(compressor) root.Use(compressor)
root.Use(srv.reproxy.Middleware) root.Use(srv.reproxy.Middleware)
root.Use(requestid.HTTPMiddleware()) root.Use(requestid.HTTPMiddleware())
root.Use(log.NewHandler(log.Logger)) root.Use(log.NewHandler(func() *zerolog.Logger { return logger }))
root.Use(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) { root.Use(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
log.FromRequest(r).Debug(). log.FromRequest(r).Debug().
Dur("duration", duration). Dur("duration", duration).

View file

@ -69,10 +69,16 @@ type Server struct {
} }
// NewServer creates a new Server. Listener ports are chosen by the OS. // NewServer creates a new Server. Listener ports are chosen by the OS.
func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr *events.Manager) (*Server, error) { func NewServer(
ctx context.Context,
cfg *config.Config,
metricsMgr *config.MetricsManager,
eventsMgr *events.Manager,
) (*Server, error) {
srv := &Server{ srv := &Server{
metricsMgr: metricsMgr, metricsMgr: metricsMgr,
EventsMgr: eventsMgr, EventsMgr: eventsMgr,
filemgr: filemgr.NewManager(),
reproxy: reproxy.New(), reproxy: reproxy.New(),
haveSetCapacity: map[string]bool{}, haveSetCapacity: map[string]bool{},
updateConfig: make(chan *config.Config, 1), updateConfig: make(chan *config.Config, 1),
@ -80,6 +86,10 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
httpRouter: atomicutil.NewValue(mux.NewRouter()), httpRouter: atomicutil.NewValue(mux.NewRouter()),
} }
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("server_name", cfg.Options.Services)
})
var err error var err error
// setup gRPC // setup gRPC
@ -95,8 +105,16 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
) )
srv.GRPCServer = grpc.NewServer( srv.GRPCServer = grpc.NewServer(
grpc.StatsHandler(telemetry.NewGRPCServerStatsHandler(cfg.Options.Services)), grpc.StatsHandler(telemetry.NewGRPCServerStatsHandler(cfg.Options.Services)),
grpc.ChainUnaryInterceptor(requestid.UnaryServerInterceptor(), ui), grpc.ChainUnaryInterceptor(
grpc.ChainStreamInterceptor(requestid.StreamServerInterceptor(), si), log.UnaryServerInterceptor(log.Ctx(ctx)),
requestid.UnaryServerInterceptor(),
ui,
),
grpc.ChainStreamInterceptor(
log.StreamServerInterceptor(log.Ctx(ctx)),
requestid.StreamServerInterceptor(),
si,
),
) )
reflection.Register(srv.GRPCServer) reflection.Register(srv.GRPCServer)
srv.registerAccessLogHandlers() srv.registerAccessLogHandlers()
@ -125,7 +143,7 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
return nil, err return nil, err
} }
if err := srv.updateRouter(cfg); err != nil { if err := srv.updateRouter(ctx, cfg); err != nil {
return nil, err return nil, err
} }
srv.DebugRouter = mux.NewRouter() srv.DebugRouter = mux.NewRouter()
@ -141,7 +159,6 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
// metrics // metrics
srv.MetricsRouter.Handle("/metrics", srv.metricsMgr) srv.MetricsRouter.Handle("/metrics", srv.metricsMgr)
srv.filemgr = filemgr.NewManager()
srv.filemgr.ClearCache() srv.filemgr.ClearCache()
srv.Builder = envoyconfig.New( srv.Builder = envoyconfig.New(
@ -152,10 +169,6 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
srv.reproxy, srv.reproxy,
) )
ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context {
return c.Str("server_name", cfg.Options.Services)
})
res, err := srv.buildDiscoveryResources(ctx) res, err := srv.buildDiscoveryResources(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
@ -211,7 +224,7 @@ func (srv *Server) Run(ctx context.Context) error {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case cfg := <-srv.updateConfig: case cfg := <-srv.updateConfig:
err := srv.update(ctx, cfg) err := srv.update(ctx, cfg)
if err != nil { if err != nil {
@ -232,29 +245,29 @@ func (srv *Server) OnConfigChange(ctx context.Context, cfg *config.Config) error
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case srv.updateConfig <- cfg: case srv.updateConfig <- cfg:
} }
return nil return nil
} }
// EnableAuthenticate enables the authenticate service. // EnableAuthenticate enables the authenticate service.
func (srv *Server) EnableAuthenticate(svc Service) error { func (srv *Server) EnableAuthenticate(ctx context.Context, svc Service) error {
srv.authenticateSvc = svc srv.authenticateSvc = svc
return srv.updateRouter(srv.currentConfig.Load()) return srv.updateRouter(ctx, srv.currentConfig.Load())
} }
// EnableProxy enables the proxy service. // EnableProxy enables the proxy service.
func (srv *Server) EnableProxy(svc Service) error { func (srv *Server) EnableProxy(ctx context.Context, svc Service) error {
srv.proxySvc = svc srv.proxySvc = svc
return srv.updateRouter(srv.currentConfig.Load()) return srv.updateRouter(ctx, srv.currentConfig.Load())
} }
func (srv *Server) update(ctx context.Context, cfg *config.Config) error { func (srv *Server) update(ctx context.Context, cfg *config.Config) error {
ctx, span := trace.StartSpan(ctx, "controlplane.Server.update") ctx, span := trace.StartSpan(ctx, "controlplane.Server.update")
defer span.End() defer span.End()
if err := srv.updateRouter(cfg); err != nil { if err := srv.updateRouter(ctx, cfg); err != nil {
return err return err
} }
srv.reproxy.Update(ctx, cfg) srv.reproxy.Update(ctx, cfg)
@ -267,9 +280,9 @@ func (srv *Server) update(ctx context.Context, cfg *config.Config) error {
return nil return nil
} }
func (srv *Server) updateRouter(cfg *config.Config) error { func (srv *Server) updateRouter(ctx context.Context, cfg *config.Config) error {
httpRouter := mux.NewRouter() httpRouter := mux.NewRouter()
srv.addHTTPMiddleware(httpRouter, cfg) srv.addHTTPMiddleware(httpRouter, log.Ctx(ctx), cfg)
if err := srv.mountCommonEndpoints(httpRouter, cfg); err != nil { if err := srv.mountCommonEndpoints(httpRouter, cfg); err != nil {
return err return err
} }

View file

@ -38,7 +38,7 @@ func TestServerHTTP(t *testing.T) {
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs=" cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
src := config.NewStaticSource(cfg) src := config.NewStaticSource(cfg)
srv, err := NewServer(cfg, config.NewMetricsManager(ctx, src), events.New()) srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New())
require.NoError(t, err) require.NoError(t, err)
go srv.Run(ctx) go srv.Run(ctx)

View file

@ -1,6 +1,7 @@
package xdsmgr package xdsmgr
import ( import (
"context"
"errors" "errors"
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
@ -19,8 +20,8 @@ var (
routeConfigurationTypeURL = protoutil.GetTypeURL((*envoy_config_route_v3.RouteConfiguration)(nil)) routeConfigurationTypeURL = protoutil.GetTypeURL((*envoy_config_route_v3.RouteConfiguration)(nil))
) )
func logNACK(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) { func logNACK(ctx context.Context, req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
log.Debug(). log.Ctx(ctx).Debug().
Str("type-url", req.GetTypeUrl()). Str("type-url", req.GetTypeUrl()).
Any("error-detail", req.GetErrorDetail()). Any("error-detail", req.GetErrorDetail()).
Msg("xdsmgr: nack") Msg("xdsmgr: nack")
@ -28,8 +29,8 @@ func logNACK(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
health.ReportError(getHealthCheck(req.GetTypeUrl()), errors.New(req.GetErrorDetail().GetMessage())) health.ReportError(getHealthCheck(req.GetTypeUrl()), errors.New(req.GetErrorDetail().GetMessage()))
} }
func logACK(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) { func logACK(ctx context.Context, req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
log.Debug(). log.Ctx(ctx).Debug().
Str("type-url", req.GetTypeUrl()). Str("type-url", req.GetTypeUrl()).
Msg("xdsmgr: ack") Msg("xdsmgr: ack")

View file

@ -113,7 +113,7 @@ func (mgr *Manager) DeltaAggregatedResources(
for _, resource := range mgr.resources[req.GetTypeUrl()] { for _, resource := range mgr.resources[req.GetTypeUrl()] {
state.clientResourceVersions[resource.Name] = resource.Version state.clientResourceVersions[resource.Name] = resource.Version
} }
logNACK(req) logNACK(ctx, req)
case req.GetResponseNonce() == mgr.nonce: case req.GetResponseNonce() == mgr.nonce:
// an ACK for the last response // an ACK for the last response
// - set the client resource versions to the current resource versions // - set the client resource versions to the current resource versions
@ -121,10 +121,11 @@ func (mgr *Manager) DeltaAggregatedResources(
for _, resource := range mgr.resources[req.GetTypeUrl()] { for _, resource := range mgr.resources[req.GetTypeUrl()] {
state.clientResourceVersions[resource.Name] = resource.Version state.clientResourceVersions[resource.Name] = resource.Version
} }
logACK(req) logACK(ctx, req)
default: default:
// an ACK for a response that's not the last response // an ACK for a response that's not the last response
log.Ctx(ctx).Debug(). log.Ctx(ctx).
Debug().
Str("type-url", req.GetTypeUrl()). Str("type-url", req.GetTypeUrl()).
Msg("xdsmgr: ack") Msg("xdsmgr: ack")
} }
@ -161,7 +162,7 @@ func (mgr *Manager) DeltaAggregatedResources(
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case incoming <- req: case incoming <- req:
} }
} }
@ -173,7 +174,7 @@ func (mgr *Manager) DeltaAggregatedResources(
var typeURLs []string var typeURLs []string
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case req := <-incoming: case req := <-incoming:
handleDeltaRequest(changeCtx, req) handleDeltaRequest(changeCtx, req)
typeURLs = []string{req.GetTypeUrl()} typeURLs = []string{req.GetTypeUrl()}
@ -193,7 +194,7 @@ func (mgr *Manager) DeltaAggregatedResources(
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case outgoing <- res: case outgoing <- res:
} }
} }
@ -204,9 +205,10 @@ func (mgr *Manager) DeltaAggregatedResources(
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case res := <-outgoing: case res := <-outgoing:
log.Ctx(ctx).Debug(). log.Ctx(ctx).
Debug().
Str("type-url", res.GetTypeUrl()). Str("type-url", res.GetTypeUrl()).
Int("resource-count", len(res.GetResources())). Int("resource-count", len(res.GetResources())).
Int("removed-resource-count", len(res.GetRemovedResources())). Int("removed-resource-count", len(res.GetRemovedResources())).

View file

@ -96,7 +96,7 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
cfg := src.underlyingConfig.Clone() cfg := src.underlyingConfig.Clone()
// start the updater // start the updater
src.runUpdater(cfg) src.runUpdater(ctx, cfg)
now = time.Now() now = time.Now()
err := src.buildNewConfigLocked(ctx, cfg) err := src.buildNewConfigLocked(ctx, cfg)
@ -234,7 +234,7 @@ func (src *ConfigSource) addPolicies(ctx context.Context, cfg *config.Config, po
cfg.Options.AdditionalPolicies = append(cfg.Options.AdditionalPolicies, additionalPolicies...) cfg.Options.AdditionalPolicies = append(cfg.Options.AdditionalPolicies, additionalPolicies...)
} }
func (src *ConfigSource) runUpdater(cfg *config.Config) { func (src *ConfigSource) runUpdater(ctx context.Context, cfg *config.Config) {
sharedKey, _ := cfg.Options.GetSharedKey() sharedKey, _ := cfg.Options.GetSharedKey()
connectionOptions := &grpc.OutboundOptions{ connectionOptions := &grpc.OutboundOptions{
OutboundPort: cfg.OutboundPort, OutboundPort: cfg.OutboundPort,
@ -257,7 +257,6 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
src.cancel = nil src.cancel = nil
} }
ctx := context.Background()
ctx, src.cancel = context.WithCancel(ctx) ctx, src.cancel = context.WithCancel(ctx)
cc, err := src.outboundGRPCConnection.Get(ctx, connectionOptions) cc, err := src.outboundGRPCConnection.Get(ctx, connectionOptions)
@ -268,7 +267,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
client := databroker.NewDataBrokerServiceClient(cc) client := databroker.NewDataBrokerServiceClient(cc)
syncer := databroker.NewSyncer("databroker", &syncerHandler{ syncer := databroker.NewSyncer(ctx, "databroker", &syncerHandler{
client: client, client: client,
src: src, src: src,
}, databroker.WithTypeURL(grpcutil.GetTypeURL(new(configpb.Config))), }, databroker.WithTypeURL(grpcutil.GetTypeURL(new(configpb.Config))),

View file

@ -41,7 +41,7 @@ func TestConfigSource(t *testing.T) {
defer func() { _ = li.Close() }() defer func() { _ = li.Close() }()
_, outboundPort, _ := net.SplitHostPort(li.Addr().String()) _, outboundPort, _ := net.SplitHostPort(li.Addr().String())
dataBrokerServer := New() dataBrokerServer := New(ctx)
srv := grpc.NewServer() srv := grpc.NewServer()
databroker.RegisterDataBrokerServiceServer(srv, dataBrokerServer) databroker.RegisterDataBrokerServiceServer(srv, dataBrokerServer)
go func() { _ = srv.Serve(li) }() go func() { _ = srv.Serve(li) }()

View file

@ -28,7 +28,7 @@ func (srv *Server) Report(ctx context.Context, req *registrypb.RegisterRequest)
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Report") ctx, span := trace.StartSpan(ctx, "databroker.grpc.Report")
defer span.End() defer span.End()
r, err := srv.getRegistry() r, err := srv.getRegistry(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -41,7 +41,7 @@ func (srv *Server) List(ctx context.Context, req *registrypb.ListRequest) (*regi
ctx, span := trace.StartSpan(ctx, "databroker.grpc.List") ctx, span := trace.StartSpan(ctx, "databroker.grpc.List")
defer span.End() defer span.End()
r, err := srv.getRegistry() r, err := srv.getRegistry(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -55,7 +55,7 @@ func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Watch") ctx, span := trace.StartSpan(ctx, "databroker.grpc.Watch")
defer span.End() defer span.End()
r, err := srv.getRegistry() r, err := srv.getRegistry(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -66,8 +66,8 @@ func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry
}) })
} }
func (srv *Server) getRegistry() (registry.Interface, error) { func (srv *Server) getRegistry(ctx context.Context) (registry.Interface, error) {
backend, err := srv.getBackend() backend, err := srv.getBackend(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -81,7 +81,7 @@ func (srv *Server) getRegistry() (registry.Interface, error) {
r = srv.registry r = srv.registry
var err error var err error
if r == nil { if r == nil {
r, err = srv.newRegistryLocked(backend) r, err = srv.newRegistryLocked(ctx, backend)
srv.registry = r srv.registry = r
} }
srv.mu.Unlock() srv.mu.Unlock()
@ -92,9 +92,7 @@ func (srv *Server) getRegistry() (registry.Interface, error) {
return r, nil return r, nil
} }
func (srv *Server) newRegistryLocked(backend storage.Backend) (registry.Interface, error) { func (srv *Server) newRegistryLocked(ctx context.Context, backend storage.Backend) (registry.Interface, error) {
ctx := context.Background()
if hasRegistryServer, ok := backend.(interface { if hasRegistryServer, ok := backend.(interface {
RegistryServer() registrypb.RegistryServer RegistryServer() registrypb.RegistryServer
}); ok { }); ok {

View file

@ -28,25 +28,26 @@ import (
type Server struct { type Server struct {
cfg *serverConfig cfg *serverConfig
mu sync.RWMutex mu sync.RWMutex
backend storage.Backend backend storage.Backend
registry registry.Interface backendCtx context.Context
registry registry.Interface
} }
// New creates a new server. // New creates a new server.
func New(options ...ServerOption) *Server { func New(ctx context.Context, options ...ServerOption) *Server {
srv := &Server{} srv := &Server{
srv.UpdateConfig(options...) backendCtx: ctx,
}
srv.UpdateConfig(ctx, options...)
return srv return srv
} }
// UpdateConfig updates the server with the new options. // UpdateConfig updates the server with the new options.
func (srv *Server) UpdateConfig(options ...ServerOption) { func (srv *Server) UpdateConfig(ctx context.Context, options ...ServerOption) {
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
ctx := context.TODO()
cfg := newServerConfig(options...) cfg := newServerConfig(options...)
if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) { if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) {
log.Ctx(ctx).Debug().Msg("databroker: no changes detected, re-using existing DBs") log.Ctx(ctx).Debug().Msg("databroker: no changes detected, re-using existing DBs")
@ -80,7 +81,7 @@ func (srv *Server) AcquireLease(ctx context.Context, req *databroker.AcquireLeas
Dur("duration", req.GetDuration().AsDuration()). Dur("duration", req.GetDuration().AsDuration()).
Msg("acquire lease") Msg("acquire lease")
db, err := srv.getBackend() db, err := srv.getBackend(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -107,7 +108,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
Str("id", req.GetId()). Str("id", req.GetId()).
Msg("get") Msg("get")
db, err := srv.getBackend() db, err := srv.getBackend(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -131,7 +132,7 @@ func (srv *Server) ListTypes(ctx context.Context, _ *emptypb.Empty) (*databroker
defer span.End() defer span.End()
log.Ctx(ctx).Debug().Msg("list types") log.Ctx(ctx).Debug().Msg("list types")
db, err := srv.getBackend() db, err := srv.getBackend(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -156,7 +157,7 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
query := strings.ToLower(req.GetQuery()) query := strings.ToLower(req.GetQuery())
db, err := srv.getBackend() db, err := srv.getBackend(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -217,7 +218,7 @@ func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databr
Msg("put") Msg("put")
} }
db, err := srv.getBackend() db, err := srv.getBackend(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -256,7 +257,7 @@ func (srv *Server) Patch(ctx context.Context, req *databroker.PatchRequest) (*da
Msg("patch") Msg("patch")
} }
db, err := srv.getBackend() db, err := srv.getBackend(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -282,7 +283,7 @@ func (srv *Server) ReleaseLease(ctx context.Context, req *databroker.ReleaseLeas
Str("id", req.GetId()). Str("id", req.GetId()).
Msg("release lease") Msg("release lease")
db, err := srv.getBackend() db, err := srv.getBackend(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -305,7 +306,7 @@ func (srv *Server) RenewLease(ctx context.Context, req *databroker.RenewLeaseReq
Dur("duration", req.GetDuration().AsDuration()). Dur("duration", req.GetDuration().AsDuration()).
Msg("renew lease") Msg("renew lease")
db, err := srv.getBackend() db, err := srv.getBackend(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -325,7 +326,7 @@ func (srv *Server) SetOptions(ctx context.Context, req *databroker.SetOptionsReq
ctx, span := trace.StartSpan(ctx, "databroker.grpc.SetOptions") ctx, span := trace.StartSpan(ctx, "databroker.grpc.SetOptions")
defer span.End() defer span.End()
backend, err := srv.getBackend() backend, err := srv.getBackend(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -351,12 +352,13 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
log.Ctx(ctx).Debug(). log.Ctx(ctx).
Debug().
Uint64("server_version", req.GetServerVersion()). Uint64("server_version", req.GetServerVersion()).
Uint64("record_version", req.GetRecordVersion()). Uint64("record_version", req.GetRecordVersion()).
Msg("sync") Msg("sync")
backend, err := srv.getBackend() backend, err := srv.getBackend(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -392,7 +394,7 @@ func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databrok
Str("type", req.GetType()). Str("type", req.GetType()).
Msg("sync latest") Msg("sync latest")
backend, err := srv.getBackend() backend, err := srv.getBackend(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -430,7 +432,7 @@ func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databrok
}) })
} }
func (srv *Server) getBackend() (backend storage.Backend, err error) { func (srv *Server) getBackend(ctx context.Context) (backend storage.Backend, err error) {
// double-checked locking: // double-checked locking:
// first try the read lock, then re-try with the write lock, and finally create a new backend if nil // first try the read lock, then re-try with the write lock, and finally create a new backend if nil
srv.mu.RLock() srv.mu.RLock()
@ -441,7 +443,7 @@ func (srv *Server) getBackend() (backend storage.Backend, err error) {
backend = srv.backend backend = srv.backend
var err error var err error
if backend == nil { if backend == nil {
backend, err = srv.newBackendLocked() backend, err = srv.newBackendLocked(ctx)
srv.backend = backend srv.backend = backend
} }
srv.mu.Unlock() srv.mu.Unlock()
@ -452,18 +454,18 @@ func (srv *Server) getBackend() (backend storage.Backend, err error) {
return backend, nil return backend, nil
} }
func (srv *Server) newBackendLocked() (backend storage.Backend, err error) { func (srv *Server) newBackendLocked(ctx context.Context) (storage.Backend, error) {
ctx := context.Background()
switch srv.cfg.storageType { switch srv.cfg.storageType {
case config.StorageInMemoryName: case config.StorageInMemoryName:
log.Ctx(ctx).Info().Msg("using in-memory store") log.Ctx(ctx).Info().Msg("initializing new in-memory store")
return inmemory.New(), nil return inmemory.New(), nil
case config.StoragePostgresName: case config.StoragePostgresName:
log.Ctx(ctx).Info().Msg("using postgres store") log.Ctx(ctx).Info().Msg("initializing new postgres store")
backend = postgres.New(srv.cfg.storageConnectionString) // NB: the context passed to postgres.New here is a separate context scoped
// to the lifetime of the server itself. 'ctx' may be a short-lived request
// context, since the backend is lazy-initialized.
return postgres.New(srv.backendCtx, srv.cfg.storageConnectionString), nil
default: default:
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType) return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
} }
return backend, nil
} }

View file

@ -48,7 +48,8 @@ func (h testSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint
func newServer(cfg *serverConfig) *Server { func newServer(cfg *serverConfig) *Server {
return &Server{ return &Server{
cfg: cfg, cfg: cfg,
backendCtx: context.Background(),
} }
} }
@ -277,7 +278,7 @@ func TestServer_Sync(t *testing.T) {
updateRecords := make(chan uint64, 10) updateRecords := make(chan uint64, 10)
client := databroker.NewDataBrokerServiceClient(cc) client := databroker.NewDataBrokerServiceClient(cc)
syncer := databroker.NewSyncer("TEST", testSyncerHandler{ syncer := databroker.NewSyncer(ctx, "TEST", testSyncerHandler{
getDataBrokerServiceClient: func() databroker.DataBrokerServiceClient { getDataBrokerServiceClient: func() databroker.DataBrokerServiceClient {
return client return client
}, },
@ -292,12 +293,12 @@ func TestServer_Sync(t *testing.T) {
select { select {
case <-clearRecords: case <-clearRecords:
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
} }
select { select {
case <-updateRecords: case <-updateRecords:
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
} }
@ -313,7 +314,7 @@ func TestServer_Sync(t *testing.T) {
select { select {
case <-updateRecords: case <-updateRecords:
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
} }
return nil return nil

View file

@ -45,7 +45,7 @@ func TestEnabler(t *testing.T) {
started.Add(1) started.Add(1)
<-ctx.Done() <-ctx.Done()
stopped.Add(1) stopped.Add(1)
return ctx.Err() return context.Cause(ctx)
}), true) }), true)
time.AfterFunc(time.Millisecond*10, e.Disable) time.AfterFunc(time.Millisecond*10, e.Disable)
go e.Run(ctx) go e.Run(ctx)

View file

@ -80,6 +80,9 @@ func (watcher *Watcher) initLocked(ctx context.Context) {
if watcher.pollingWatcher == nil { if watcher.pollingWatcher == nil {
watcher.pollingWatcher = filenotify.NewPollingWatcher(nil) watcher.pollingWatcher = filenotify.NewPollingWatcher(nil)
context.AfterFunc(ctx, func() {
watcher.pollingWatcher.Close()
})
} }
errors := watcher.pollingWatcher.Errors() errors := watcher.pollingWatcher.Errors()

10
internal/log/debug.go Normal file
View file

@ -0,0 +1,10 @@
package log
import "sync/atomic"
var (
// Debug option to disable the Zap log shim
DebugDisableZapLogger atomic.Bool
// Debug option to suppress global warnings
DebugDisableGlobalWarnings atomic.Bool
)

View file

@ -52,6 +52,9 @@ func Logger() *zerolog.Logger {
// ZapLogger returns the global zap logger. // ZapLogger returns the global zap logger.
func ZapLogger() *zap.Logger { func ZapLogger() *zap.Logger {
if DebugDisableZapLogger.Load() {
return zap.NewNop()
}
return zapLogger.Load() return zapLogger.Load()
} }

View file

@ -1,11 +1,14 @@
package log package log
import ( import (
"context"
"net" "net"
"net/http" "net/http"
"time" "time"
"github.com/pomerium/protoutil/streams"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/middleware/responsewriter" "github.com/pomerium/pomerium/internal/middleware/responsewriter"
"github.com/pomerium/pomerium/pkg/telemetry/requestid" "github.com/pomerium/pomerium/pkg/telemetry/requestid"
@ -121,3 +124,17 @@ func HeadersHandler(headers []string) func(next http.Handler) http.Handler {
}) })
} }
} }
func StreamServerInterceptor(lg *zerolog.Logger) grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
s := streams.NewServerStreamWithContext(ss)
s.SetContext(lg.WithContext(s.Ctx))
return handler(srv, s)
}
}
func UnaryServerInterceptor(lg *zerolog.Logger) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
return handler(lg.WithContext(ctx), req)
}
}

View file

@ -11,6 +11,9 @@ var warnCookieSecretOnce sync.Once
// WarnCookieSecret warns about the cookie secret. // WarnCookieSecret warns about the cookie secret.
func WarnCookieSecret() { func WarnCookieSecret() {
warnCookieSecretOnce.Do(func() { warnCookieSecretOnce.Do(func() {
if DebugDisableGlobalWarnings.Load() {
return
}
Info(). Info().
Msg("using a generated COOKIE_SECRET. " + Msg("using a generated COOKIE_SECRET. " +
"Set the COOKIE_SECRET to avoid users being logged out on restart. " + "Set the COOKIE_SECRET to avoid users being logged out on restart. " +
@ -23,6 +26,9 @@ var warnNoTLSCertificateOnce syncutil.OnceMap[string]
// WarnNoTLSCertificate warns about no TLS certificate. // WarnNoTLSCertificate warns about no TLS certificate.
func WarnNoTLSCertificate(domain string) { func WarnNoTLSCertificate(domain string) {
warnNoTLSCertificateOnce.Do(domain, func() { warnNoTLSCertificateOnce.Do(domain, func() {
if DebugDisableGlobalWarnings.Load() {
return
}
Info(). Info().
Str("domain", domain). Str("domain", domain).
Msg("no TLS certificate found for domain, using a self-signed certificate") Msg("no TLS certificate found for domain, using a self-signed certificate")
@ -34,6 +40,9 @@ var warnWebSocketHTTP1_1Once syncutil.OnceMap[string]
// WarnWebSocketHTTP1_1 warns about falling back to http 1.1 due to web socket support. // WarnWebSocketHTTP1_1 warns about falling back to http 1.1 due to web socket support.
func WarnWebSocketHTTP1_1(clusterID string) { func WarnWebSocketHTTP1_1(clusterID string) {
warnWebSocketHTTP1_1Once.Do(clusterID, func() { warnWebSocketHTTP1_1Once.Do(clusterID, func() {
if DebugDisableGlobalWarnings.Load() {
return
}
Info(). Info().
Str("cluster-id", clusterID). Str("cluster-id", clusterID).
Msg("forcing http/1.1 due to web socket support") Msg("forcing http/1.1 due to web socket support")

View file

@ -103,7 +103,7 @@ func makeSelect(
fn: func(ctx context.Context) error { fn: func(ctx context.Context) error {
// unreachable, the context handler will never be called // unreachable, the context handler will never be called
// as its channel can only be closed // as its channel can only be closed
return ctx.Err() return context.Cause(ctx)
}, },
ch: reflect.ValueOf(ctx.Done()), ch: reflect.ValueOf(ctx.Done()),
}, },

View file

@ -22,7 +22,7 @@ func WaitForHealthy(ctx context.Context, client *http.Client, routes []*config.R
healthy++ healthy++
} }
} }
return ctx.Err() return context.Cause(ctx)
} }
func checkHealth(ctx context.Context, client *http.Client, addr string) error { func checkHealth(ctx context.Context, client *http.Client, addr string) error {

View file

@ -74,7 +74,7 @@ func (svc *Source) updateLoop(ctx context.Context) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case <-svc.checkForUpdate: case <-svc.checkForUpdate:
case <-ticker.C: case <-ticker.C:
} }

View file

@ -34,7 +34,7 @@ type source struct {
func (src *source) WaitReady(ctx context.Context) error { func (src *source) WaitReady(ctx context.Context) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case <-src.ready: case <-src.ready:
return nil return nil
} }

View file

@ -36,7 +36,7 @@ func BuildImportCmd() *cobra.Command {
return fmt.Errorf("no config file provided") return fmt.Errorf("no config file provided")
} }
log.SetLevel(zerolog.ErrorLevel) log.SetLevel(zerolog.ErrorLevel)
src, err := config.NewFileOrEnvironmentSource(configFile, files.FullVersion()) src, err := config.NewFileOrEnvironmentSource(cmd.Context(), configFile, files.FullVersion())
if err != nil { if err != nil {
return err return err
} }

View file

@ -17,7 +17,7 @@ import (
func (svc *Mux) Watch(ctx context.Context, opts ...WatchOption) error { func (svc *Mux) Watch(ctx context.Context, opts ...WatchOption) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case <-svc.ready: case <-svc.ready:
} }

View file

@ -102,7 +102,7 @@ func TestDatabrokerRestart(t *testing.T) {
cl(context.Background(), newConfig()) cl(context.Background(), newConfig())
<-ctx.Done() <-ctx.Done()
require.ErrorIs(t, context.Cause(ctx), controller.ErrBootstrapConfigurationChanged) require.ErrorIs(t, context.Cause(ctx), controller.ErrBootstrapConfigurationChanged)
return ctx.Err() return context.Cause(ctx)
} }
return nil return nil
}) })

View file

@ -56,5 +56,5 @@ func (w *LeaseStatus) MonitorLease(ctx context.Context, _ databroker.DataBrokerS
w.v.Store(true) w.v.Store(true)
<-ctx.Done() <-ctx.Done()
w.v.Store(false) w.v.Store(false)
return ctx.Err() return context.Cause(ctx)
} }

View file

@ -37,7 +37,7 @@ func TestUsageReporter(t *testing.T) {
t.Cleanup(cancel) t.Cleanup(cancel)
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) { cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New()) databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx))
}) })
t.Cleanup(func() { cc.Close() }) t.Cleanup(func() { cc.Close() })

View file

@ -10,7 +10,7 @@ import (
) )
func (c *Checker) ConfigSyncer(ctx context.Context) error { func (c *Checker) ConfigSyncer(ctx context.Context) error {
syncer := databroker.NewSyncer("zero-health-check", c, databroker.WithTypeURL(protoutil.GetTypeURL(new(configpb.Config)))) syncer := databroker.NewSyncer(ctx, "zero-health-check", c, databroker.WithTypeURL(protoutil.GetTypeURL(new(configpb.Config))))
return syncer.Run(ctx) return syncer.Run(ctx)
} }

View file

@ -35,7 +35,7 @@ func (c *service) SyncLoop(ctx context.Context) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case <-c.bundleSyncRequest: case <-c.bundleSyncRequest:
log.Ctx(ctx).Debug().Msg("bundle sync triggered") log.Ctx(ctx).Debug().Msg("bundle sync triggered")
err := c.syncBundles(ctx) err := c.syncBundles(ctx)

View file

@ -105,7 +105,7 @@ func (srv *Telemetry) handleRequests(ctx context.Context) error {
case req := <-requests: case req := <-requests:
srv.handleRequest(ctx, req) srv.handleRequest(ctx, req)
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
} }
} }
}) })

View file

@ -32,7 +32,7 @@ import (
// Run runs the main pomerium application. // Run runs the main pomerium application.
func Run(ctx context.Context, src config.Source) error { func Run(ctx context.Context, src config.Source) error {
_, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Debug().Msgf(s, i...) })) _, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Ctx(ctx).Debug().Msgf(s, i...) }))
evt := log.Ctx(ctx).Info(). evt := log.Ctx(ctx).Info().
Str("envoy_version", files.FullVersion()). Str("envoy_version", files.FullVersion()).
@ -53,7 +53,7 @@ func Run(ctx context.Context, src config.Source) error {
// trigger changes when underlying files are changed // trigger changes when underlying files are changed
src = config.NewFileWatcherSource(ctx, src) src = config.NewFileWatcherSource(ctx, src)
src, err = autocert.New(src) src, err = autocert.New(ctx, src)
if err != nil { if err != nil {
return err return err
} }
@ -71,7 +71,7 @@ func Run(ctx context.Context, src config.Source) error {
cfg := src.GetConfig() cfg := src.GetConfig()
// setup the control plane // setup the control plane
controlPlane, err := controlplane.NewServer(cfg, metricsMgr, eventsMgr) controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr)
if err != nil { if err != nil {
return fmt.Errorf("error creating control plane: %w", err) return fmt.Errorf("error creating control plane: %w", err)
} }
@ -166,11 +166,11 @@ func setupAuthenticate(ctx context.Context, src config.Source, controlPlane *con
return nil return nil
} }
svc, err := authenticate.New(src.GetConfig()) svc, err := authenticate.New(ctx, src.GetConfig())
if err != nil { if err != nil {
return fmt.Errorf("error creating authenticate service: %w", err) return fmt.Errorf("error creating authenticate service: %w", err)
} }
err = controlPlane.EnableAuthenticate(svc) err = controlPlane.EnableAuthenticate(ctx, svc)
if err != nil { if err != nil {
return fmt.Errorf("error adding authenticate service to control plane: %w", err) return fmt.Errorf("error adding authenticate service to control plane: %w", err)
} }
@ -183,7 +183,7 @@ func setupAuthenticate(ctx context.Context, src config.Source, controlPlane *con
} }
func setupAuthorize(ctx context.Context, src config.Source, controlPlane *controlplane.Server) (*authorize.Authorize, error) { func setupAuthorize(ctx context.Context, src config.Source, controlPlane *controlplane.Server) (*authorize.Authorize, error) {
svc, err := authorize.New(src.GetConfig()) svc, err := authorize.New(ctx, src.GetConfig())
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating authorize service: %w", err) return nil, fmt.Errorf("error creating authorize service: %w", err)
} }
@ -200,7 +200,7 @@ func setupDataBroker(ctx context.Context,
controlPlane *controlplane.Server, controlPlane *controlplane.Server,
eventsMgr *events.Manager, eventsMgr *events.Manager,
) (*databroker_service.DataBroker, error) { ) (*databroker_service.DataBroker, error) {
svc, err := databroker_service.New(src.GetConfig(), eventsMgr) svc, err := databroker_service.New(ctx, src.GetConfig(), eventsMgr)
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating databroker service: %w", err) return nil, fmt.Errorf("error creating databroker service: %w", err)
} }
@ -223,11 +223,11 @@ func setupProxy(ctx context.Context, src config.Source, controlPlane *controlpla
return nil return nil
} }
svc, err := proxy.New(src.GetConfig()) svc, err := proxy.New(ctx, src.GetConfig())
if err != nil { if err != nil {
return fmt.Errorf("error creating proxy service: %w", err) return fmt.Errorf("error creating proxy service: %w", err)
} }
err = controlPlane.EnableProxy(svc) err = controlPlane.EnableProxy(ctx, svc)
if err != nil { if err != nil {
return fmt.Errorf("error adding proxy service to control plane: %w", err) return fmt.Errorf("error adding proxy service to control plane: %w", err)
} }

View file

@ -185,7 +185,7 @@ func (srv *Server) run(ctx context.Context, cfg *config.Config) error {
// monitor the process so we exit if it prematurely exits // monitor the process so we exit if it prematurely exits
var monitorProcessCtx context.Context var monitorProcessCtx context.Context
monitorProcessCtx, srv.monitorProcessCancel = context.WithCancel(context.Background()) monitorProcessCtx, srv.monitorProcessCancel = context.WithCancel(context.WithoutCancel(ctx))
go srv.monitorProcess(monitorProcessCtx, int32(cmd.Process.Pid)) go srv.monitorProcess(monitorProcessCtx, int32(cmd.Process.Pid))
if srv.resourceMonitor != nil { if srv.resourceMonitor != nil {
@ -251,7 +251,7 @@ func (srv *Server) parseLog(line string) (name string, logLevel string, msg stri
func (srv *Server) handleLogs(ctx context.Context, rc io.ReadCloser) { func (srv *Server) handleLogs(ctx context.Context, rc io.ReadCloser) {
defer rc.Close() defer rc.Close()
l := log.With().Str("service", "envoy").Logger() l := log.Ctx(ctx).With().Str("service", "envoy").Logger()
bo := backoff.NewExponentialBackOff() bo := backoff.NewExponentialBackOff()
s := bufio.NewReader(rc) s := bufio.NewReader(rc)

View file

@ -10,7 +10,7 @@ func (f *FanOut[T]) Publish(ctx context.Context, msg T) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case <-f.done: case <-f.done:
return ErrStopped return ErrStopped
case f.messages <- msg: case f.messages <- msg:

View file

@ -141,7 +141,7 @@ func WaitForReady(ctx context.Context, cc *grpc.ClientConn, timeout time.Duratio
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case <-ticker.C: case <-ticker.C:
} }

View file

@ -71,13 +71,13 @@ func (locker *Leaser) Run(ctx context.Context) error {
case err == nil: case err == nil:
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case <-retryTicker.C: case <-retryTicker.C:
} }
case errors.Is(err, retryableError{}): case errors.Is(err, retryableError{}):
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case <-time.After(bo.NextBackOff()): case <-time.After(bo.NextBackOff()):
} }
default: default:

View file

@ -169,7 +169,7 @@ func TestLeasers(t *testing.T) {
fn2 := func(ctx context.Context) error { fn2 := func(ctx context.Context) error {
atomic.AddInt64(&counter, 10) atomic.AddInt64(&counter, 10)
<-ctx.Done() <-ctx.Done()
return ctx.Err() return context.Cause(ctx)
} }
leaser := databroker.NewLeasers("TEST", time.Second*30, client, fn1, fn2) leaser := databroker.NewLeasers("TEST", time.Second*30, client, fn1, fn2)
err := leaser.Run(context.Background()) err := leaser.Run(context.Background())

View file

@ -110,7 +110,7 @@ func (r *Reconciler) reconcileLoop(ctx context.Context) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case <-r.trigger: case <-r.trigger:
} }
} }

View file

@ -24,7 +24,7 @@ func Test_SyncLatestRecords(t *testing.T) {
defer clearTimeout() defer clearTimeout()
cc := testutil.NewGRPCServer(t, func(s *grpc.Server) { cc := testutil.NewGRPCServer(t, func(s *grpc.Server) {
databrokerpb.RegisterDataBrokerServiceServer(s, databroker.New()) databrokerpb.RegisterDataBrokerServiceServer(s, databroker.New(ctx))
}) })
c := databrokerpb.NewDataBrokerServiceClient(cc) c := databrokerpb.NewDataBrokerServiceClient(cc)

View file

@ -72,8 +72,8 @@ type Syncer struct {
} }
// NewSyncer creates a new Syncer. // NewSyncer creates a new Syncer.
func NewSyncer(id string, handler SyncerHandler, options ...SyncerOption) *Syncer { func NewSyncer(ctx context.Context, id string, handler SyncerHandler, options ...SyncerOption) *Syncer {
closeCtx, closeCtxCancel := context.WithCancel(context.Background()) closeCtx, closeCtxCancel := context.WithCancel(context.WithoutCancel(ctx))
bo := backoff.NewExponentialBackOff() bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0 bo.MaxElapsedTime = 0
@ -120,7 +120,7 @@ func (syncer *Syncer) Run(ctx context.Context) error {
log.Ctx(ctx).Error().Err(err).Msg("sync") log.Ctx(ctx).Error().Err(err).Msg("sync")
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case <-time.After(syncer.backoff.NextBackOff()): case <-time.After(syncer.backoff.NextBackOff()):
} }
} }
@ -133,6 +133,9 @@ func (syncer *Syncer) init(ctx context.Context) error {
Type: syncer.cfg.typeURL, Type: syncer.cfg.typeURL,
}) })
if err != nil { if err != nil {
if status.Code(err) == codes.Canceled && ctx.Err() != nil {
err = fmt.Errorf("%w: %w", err, context.Cause(ctx))
}
return fmt.Errorf("error during initial sync: %w", err) return fmt.Errorf("error during initial sync: %w", err)
} }
syncer.backoff.Reset() syncer.backoff.Reset()
@ -167,6 +170,9 @@ func (syncer *Syncer) sync(ctx context.Context) error {
syncer.serverVersion = 0 syncer.serverVersion = 0
return nil return nil
} else if err != nil { } else if err != nil {
if status.Code(err) == codes.Canceled && ctx.Err() != nil {
err = fmt.Errorf("%w: %w", err, context.Cause(ctx))
}
return fmt.Errorf("error receiving sync record: %w", err) return fmt.Errorf("error receiving sync record: %w", err)
} }

View file

@ -157,7 +157,7 @@ func TestSyncer(t *testing.T) {
clearCh := make(chan struct{}) clearCh := make(chan struct{})
updateCh := make(chan []*Record) updateCh := make(chan []*Record)
syncer := NewSyncer("test", testSyncerHandler{ syncer := NewSyncer(ctx, "test", testSyncerHandler{
getDataBrokerServiceClient: func() DataBrokerServiceClient { getDataBrokerServiceClient: func() DataBrokerServiceClient {
return NewDataBrokerServiceClient(gc) return NewDataBrokerServiceClient(gc)
}, },

View file

@ -124,13 +124,13 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
// wait for initial sync // wait for initial sync
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case <-clear: case <-clear:
mgr.reset() mgr.reset()
} }
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case msg := <-update: case msg := <-update:
mgr.onUpdateRecords(ctx, msg) mgr.onUpdateRecords(ctx, msg)
} }
@ -150,7 +150,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case <-clear: case <-clear:
mgr.reset() mgr.reset()
case msg := <-update: case msg := <-update:

View file

@ -17,7 +17,7 @@ type dataBrokerSyncer struct {
} }
func newDataBrokerSyncer( func newDataBrokerSyncer(
_ context.Context, ctx context.Context,
cfg *atomicutil.Value[*config], cfg *atomicutil.Value[*config],
update chan<- updateRecordsMessage, update chan<- updateRecordsMessage,
clear chan<- struct{}, clear chan<- struct{},
@ -28,7 +28,7 @@ func newDataBrokerSyncer(
update: update, update: update,
clear: clear, clear: clear,
} }
syncer.syncer = databroker.NewSyncer("identity_manager", syncer) syncer.syncer = databroker.NewSyncer(ctx, "identity_manager", syncer)
return syncer return syncer
} }

View file

@ -16,7 +16,7 @@ type sessionSyncerHandler struct {
} }
func newSessionSyncer(ctx context.Context, mgr *Manager) *databroker.Syncer { func newSessionSyncer(ctx context.Context, mgr *Manager) *databroker.Syncer {
return databroker.NewSyncer("identity_manager/sessions", sessionSyncerHandler{baseCtx: ctx, mgr: mgr}, return databroker.NewSyncer(ctx, "identity_manager/sessions", sessionSyncerHandler{baseCtx: ctx, mgr: mgr},
databroker.WithTypeURL(grpcutil.GetTypeURL(new(session.Session)))) databroker.WithTypeURL(grpcutil.GetTypeURL(new(session.Session))))
} }
@ -50,7 +50,7 @@ type userSyncerHandler struct {
} }
func newUserSyncer(ctx context.Context, mgr *Manager) *databroker.Syncer { func newUserSyncer(ctx context.Context, mgr *Manager) *databroker.Syncer {
return databroker.NewSyncer("identity_manager/users", userSyncerHandler{baseCtx: ctx, mgr: mgr}, return databroker.NewSyncer(ctx, "identity_manager/users", userSyncerHandler{baseCtx: ctx, mgr: mgr},
databroker.WithTypeURL(grpcutil.GetTypeURL(new(user.User)))) databroker.WithTypeURL(grpcutil.GetTypeURL(new(user.User))))
} }

View file

@ -37,14 +37,14 @@ type Backend struct {
} }
// New creates a new Backend. // New creates a new Backend.
func New(dsn string, options ...Option) *Backend { func New(ctx context.Context, dsn string, options ...Option) *Backend {
backend := &Backend{ backend := &Backend{
cfg: getConfig(options...), cfg: getConfig(options...),
dsn: dsn, dsn: dsn,
onRecordChange: signal.New(), onRecordChange: signal.New(),
onServiceChange: signal.New(), onServiceChange: signal.New(),
} }
backend.closeCtx, backend.close = context.WithCancel(context.Background()) backend.closeCtx, backend.close = context.WithCancel(ctx)
go backend.doPeriodically(func(ctx context.Context) error { go backend.doPeriodically(func(ctx context.Context) error {
_, pool, err := backend.init(ctx) _, pool, err := backend.init(ctx)

View file

@ -34,7 +34,7 @@ func TestBackend(t *testing.T) {
defer clearTimeout() defer clearTimeout()
require.NoError(t, testutil.WithTestPostgres(func(dsn string) error { require.NoError(t, testutil.WithTestPostgres(func(dsn string) error {
backend := New(dsn) backend := New(ctx, dsn)
defer backend.Close() defer backend.Close()
t.Run("put", func(t *testing.T) { t.Run("put", func(t *testing.T) {

View file

@ -42,7 +42,7 @@ func TestRegistry(t *testing.T) {
defer clearTimeout() defer clearTimeout()
require.NoError(t, testutil.WithTestPostgres(func(dsn string) error { require.NoError(t, testutil.WithTestPostgres(func(dsn string) error {
backend := New(dsn) backend := New(ctx, dsn)
defer backend.Close() defer backend.Close()
eg, ctx := errgroup.WithContext(ctx) eg, ctx := errgroup.WithContext(ctx)
@ -53,7 +53,7 @@ func TestRegistry(t *testing.T) {
send: func(res *registry.ServiceList) error { send: func(res *registry.ServiceList) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case listResults <- res: case listResults <- res:
} }
return nil return nil
@ -73,7 +73,7 @@ func TestRegistry(t *testing.T) {
eg.Go(func() error { eg.Go(func() error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case res := <-listResults: case res := <-listResults:
testutil.AssertProtoEqual(t, &registry.ServiceList{}, res) testutil.AssertProtoEqual(t, &registry.ServiceList{}, res)
} }
@ -92,7 +92,7 @@ func TestRegistry(t *testing.T) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return context.Cause(ctx)
case res := <-listResults: case res := <-listResults:
testutil.AssertProtoEqual(t, &registry.ServiceList{ testutil.AssertProtoEqual(t, &registry.ServiceList{
Services: []*registry.Service{ Services: []*registry.Service{

View file

@ -32,14 +32,14 @@ func Test_getUserInfoData(t *testing.T) {
defer clearTimeout() defer clearTimeout()
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) { cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New()) databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx))
}) })
t.Cleanup(func() { cc.Close() }) t.Cleanup(func() { cc.Close() })
client := databrokerpb.NewDataBrokerServiceClient(cc) client := databrokerpb.NewDataBrokerServiceClient(cc)
opts := testOptions(t) opts := testOptions(t)
proxy, err := New(&config.Config{Options: opts}) proxy, err := New(ctx, &config.Config{Options: opts})
require.NoError(t, err) require.NoError(t, err)
proxy.state.Load().dataBrokerClient = client proxy.state.Load().dataBrokerClient = client

View file

@ -2,6 +2,7 @@ package proxy
import ( import (
"bytes" "bytes"
"context"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -36,7 +37,7 @@ func TestProxy_SignOut(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
opts := testOptions(t) opts := testOptions(t)
p, err := New(&config.Config{Options: opts}) p, err := New(context.Background(), &config.Config{Options: opts})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -129,7 +130,7 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.options}) p, err := New(context.Background(), &config.Config{Options: tt.options})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -270,7 +271,7 @@ func TestLoadSessionState(t *testing.T) {
t.Parallel() t.Parallel()
opts := testOptions(t) opts := testOptions(t)
proxy, err := New(&config.Config{Options: opts}) proxy, err := New(context.Background(), &config.Config{Options: opts})
require.NoError(t, err) require.NoError(t, err)
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
@ -285,7 +286,7 @@ func TestLoadSessionState(t *testing.T) {
t.Parallel() t.Parallel()
opts := testOptions(t) opts := testOptions(t)
proxy, err := New(&config.Config{Options: opts}) proxy, err := New(context.Background(), &config.Config{Options: opts})
require.NoError(t, err) require.NoError(t, err)
session := encodeSession(t, opts, &sessions.State{ session := encodeSession(t, opts, &sessions.State{
@ -308,7 +309,7 @@ func TestLoadSessionState(t *testing.T) {
t.Parallel() t.Parallel()
opts := testOptions(t) opts := testOptions(t)
proxy, err := New(&config.Config{Options: opts}) proxy, err := New(context.Background(), &config.Config{Options: opts})
require.NoError(t, err) require.NoError(t, err)
session := encodeSession(t, opts, &sessions.State{ session := encodeSession(t, opts, &sessions.State{

View file

@ -60,8 +60,8 @@ type Proxy struct {
// New takes a Proxy service from options and a validation function. // New takes a Proxy service from options and a validation function.
// Function returns an error if options fail to validate. // Function returns an error if options fail to validate.
func New(cfg *config.Config) (*Proxy, error) { func New(ctx context.Context, cfg *config.Config) (*Proxy, error) {
state, err := newProxyStateFromConfig(cfg) state, err := newProxyStateFromConfig(ctx, cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -71,7 +71,7 @@ func New(cfg *config.Config) (*Proxy, error) {
currentOptions: config.NewAtomicOptions(), currentOptions: config.NewAtomicOptions(),
currentRouter: atomicutil.NewValue(httputil.NewRouter()), currentRouter: atomicutil.NewValue(httputil.NewRouter()),
} }
p.OnConfigChange(context.Background(), cfg) p.OnConfigChange(ctx, cfg)
p.webauthn = webauthn.New(p.getWebauthnState) p.webauthn = webauthn.New(p.getWebauthnState)
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 { metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
@ -87,25 +87,25 @@ func (p *Proxy) Mount(r *mux.Router) {
} }
// OnConfigChange updates internal structures based on config.Options // OnConfigChange updates internal structures based on config.Options
func (p *Proxy) OnConfigChange(_ context.Context, cfg *config.Config) { func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
if p == nil { if p == nil {
return return
} }
p.currentOptions.Store(cfg.Options) p.currentOptions.Store(cfg.Options)
if err := p.setHandlers(cfg.Options); err != nil { if err := p.setHandlers(ctx, cfg.Options); err != nil {
log.Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings") log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
} }
if state, err := newProxyStateFromConfig(cfg); err != nil { if state, err := newProxyStateFromConfig(ctx, cfg); err != nil {
log.Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings") log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings")
} else { } else {
p.state.Store(state) p.state.Store(state)
} }
} }
func (p *Proxy) setHandlers(opts *config.Options) error { func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error {
if opts.NumPolicies() == 0 { if opts.NumPolicies() == 0 {
log.Info().Msg("proxy: configuration has no policies") log.Ctx(ctx).Info().Msg("proxy: configuration has no policies")
} }
r := httputil.NewRouter() r := httputil.NewRouter()
r.NotFoundHandler = httputil.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) error { r.NotFoundHandler = httputil.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) error {

View file

@ -104,7 +104,7 @@ func TestNew(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := New(&config.Config{Options: tt.opts}) got, err := New(context.Background(), &config.Config{Options: tt.opts})
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -197,7 +197,7 @@ func Test_UpdateOptions(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.originalOptions}) p, err := New(context.Background(), &config.Config{Options: tt.originalOptions})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -31,7 +31,7 @@ type proxyState struct {
authenticateFlow authenticateFlow authenticateFlow authenticateFlow
} }
func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { func newProxyStateFromConfig(ctx context.Context, cfg *config.Config) (*proxyState, error) {
err := ValidateOptions(cfg.Options) err := ValidateOptions(cfg.Options)
if err != nil { if err != nil {
return nil, err return nil, err
@ -57,7 +57,7 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
return nil, err return nil, err
} }
dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{ dataBrokerConn, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{
OutboundPort: cfg.OutboundPort, OutboundPort: cfg.OutboundPort,
InstallationID: cfg.Options.InstallationID, InstallationID: cfg.Options.InstallationID,
ServiceName: cfg.Options.Services, ServiceName: cfg.Options.Services,
@ -71,10 +71,10 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist
if cfg.Options.UseStatelessAuthenticateFlow() { if cfg.Options.UseStatelessAuthenticateFlow() {
state.authenticateFlow, err = authenticateflow.NewStateless( state.authenticateFlow, err = authenticateflow.NewStateless(ctx,
cfg, state.sessionStore, nil, nil, nil) cfg, state.sessionStore, nil, nil, nil)
} else { } else {
state.authenticateFlow, err = authenticateflow.NewStateful(cfg, state.sessionStore) state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, state.sessionStore)
} }
if err != nil { if err != nil {
return nil, err return nil, err