Fix many instances of contexts and loggers not being propagated (#5340)

This also replaces instances where we manually write "return ctx.Err()"
with "return context.Cause(ctx)" which is functionally identical, but
will also correctly propagate cause errors if present.
This commit is contained in:
Joe Kralicky 2024-10-25 14:50:56 -04:00 committed by GitHub
parent e1880ba20f
commit fe31799eb5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
77 changed files with 297 additions and 221 deletions

View file

@ -44,7 +44,7 @@ type Authenticate struct {
}
// New validates and creates a new authenticate service from a set of Options.
func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
func New(ctx context.Context, cfg *config.Config, options ...Option) (*Authenticate, error) {
authenticateConfig := getAuthenticateConfig(options...)
a := &Authenticate{
cfg: authenticateConfig,
@ -54,7 +54,7 @@ func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
a.options.Store(cfg.Options)
state, err := newAuthenticateStateFromConfig(cfg, authenticateConfig)
state, err := newAuthenticateStateFromConfig(ctx, cfg, authenticateConfig)
if err != nil {
return nil, err
}
@ -70,7 +70,7 @@ func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) {
}
a.options.Store(cfg.Options)
if state, err := newAuthenticateStateFromConfig(cfg, a.cfg); err != nil {
if state, err := newAuthenticateStateFromConfig(ctx, cfg, a.cfg); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authenticate: failed to update state")
} else {
a.state.Store(state)

View file

@ -1,6 +1,7 @@
package authenticate
import (
"context"
"testing"
"github.com/pomerium/pomerium/config"
@ -106,7 +107,7 @@ func TestNew(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
_, err := New(&config.Config{Options: tt.opts})
_, err := New(context.Background(), &config.Config{Options: tt.opts})
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return

View file

@ -64,6 +64,7 @@ func newAuthenticateState() *authenticateState {
}
func newAuthenticateStateFromConfig(
ctx context.Context,
cfg *config.Config, authenticateConfig *authenticateConfig,
) (*authenticateState, error) {
err := ValidateOptions(cfg.Options)
@ -145,7 +146,7 @@ func newAuthenticateStateFromConfig(
}
if cfg.Options.UseStatelessAuthenticateFlow() {
state.flow, err = authenticateflow.NewStateless(
state.flow, err = authenticateflow.NewStateless(ctx,
cfg,
cookieStore,
authenticateConfig.getIdentityProvider,
@ -153,7 +154,7 @@ func newAuthenticateStateFromConfig(
authenticateConfig.authEventFn,
)
} else {
state.flow, err = authenticateflow.NewStateful(cfg, cookieStore)
state.flow, err = authenticateflow.NewStateful(ctx, cfg, cookieStore)
}
if err != nil {
return nil, err

View file

@ -40,7 +40,7 @@ type Authorize struct {
}
// New validates and creates a new Authorize service from a set of config options.
func New(cfg *config.Config) (*Authorize, error) {
func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
a := &Authorize{
currentOptions: config.NewAtomicOptions(),
store: store.New(),
@ -48,7 +48,7 @@ func New(cfg *config.Config) (*Authorize, error) {
}
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
state, err := newAuthorizeStateFromConfig(cfg, a.store, nil)
state, err := newAuthorizeStateFromConfig(ctx, cfg, a.store, nil)
if err != nil {
return nil, err
}
@ -89,12 +89,13 @@ func validateOptions(o *config.Options) error {
// newPolicyEvaluator returns an policy evaluator.
func newPolicyEvaluator(
ctx context.Context,
opts *config.Options, store *store.Store, previous *evaluator.Evaluator,
) (*evaluator.Evaluator, error) {
metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 {
return int64(opts.NumPolicies())
})
ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context {
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "authorize")
})
ctx, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator")
@ -150,7 +151,7 @@ func newPolicyEvaluator(
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
currentState := a.state.Load()
a.currentOptions.Store(cfg.Options)
if state, err := newAuthorizeStateFromConfig(cfg, a.store, currentState.evaluator); err != nil {
if state, err := newAuthorizeStateFromConfig(ctx, cfg, a.store, currentState.evaluator); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
} else {
a.state.Store(state)

View file

@ -82,7 +82,7 @@ func TestNew(t *testing.T) {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, err := New(&config.Config{Options: &tt.config})
_, err := New(context.Background(), &config.Config{Options: &tt.config})
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
@ -114,7 +114,7 @@ func TestAuthorize_OnConfigChange(t *testing.T) {
SharedKey: tc.SharedKey,
Policies: tc.Policies,
}
a, err := New(&config.Config{Options: o})
a, err := New(context.Background(), &config.Config{Options: o})
require.NoError(t, err)
require.NotNil(t, a)
@ -185,7 +185,7 @@ func TestNewPolicyEvaluator_addDefaultClientCertificateRule(t *testing.T) {
c.opts.Policies = []config.Policy{{
To: mustParseWeightedURLs(t, "http://example.com"),
}}
e, err := newPolicyEvaluator(c.opts, store, nil)
e, err := newPolicyEvaluator(context.Background(), c.opts, store, nil)
require.NoError(t, err)
r, err := e.Evaluate(context.Background(), &evaluator.Request{

View file

@ -34,7 +34,7 @@ func TestAuthorize_handleResult(t *testing.T) {
t.Cleanup(authnSrv.Close)
opt.AuthenticateURLString = authnSrv.URL
a, err := New(&config.Config{Options: opt})
a, err := New(context.Background(), &config.Config{Options: opt})
require.NoError(t, err)
t.Run("user-unauthenticated", func(t *testing.T) {
@ -129,7 +129,7 @@ func TestAuthorize_okResponse(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(opt)
a.store = store.New()
pe, err := newPolicyEvaluator(opt, a.store, nil)
pe, err := newPolicyEvaluator(context.Background(), opt, a.store, nil)
require.NoError(t, err)
a.state.Load().evaluator = pe
@ -327,7 +327,7 @@ func TestRequireLogin(t *testing.T) {
t.Cleanup(authnSrv.Close)
opt.AuthenticateURLString = authnSrv.URL
a, err := New(&config.Config{Options: opt})
a, err := New(context.Background(), &config.Config{Options: opt})
require.NoError(t, err)
t.Run("accept empty", func(t *testing.T) {

View file

@ -65,7 +65,7 @@ func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) {
t.Cleanup(clearTimeout)
opt := config.NewDefaultOptions()
a, err := New(&config.Config{Options: opt})
a, err := New(context.Background(), &config.Config{Options: opt})
require.NoError(t, err)
s1 := &session.Session{Id: "s1", ExpiresAt: timestamppb.New(time.Now().Add(-time.Second))}

View file

@ -108,7 +108,7 @@ func New(
) (*Evaluator, error) {
cfg := getConfig(options...)
err := updateStore(store, cfg)
err := updateStore(ctx, store, cfg)
if err != nil {
return nil, err
}
@ -325,8 +325,8 @@ func (e *Evaluator) getClientCA(policy *config.Policy) (string, error) {
return string(e.clientCA), nil
}
func updateStore(store *store.Store, cfg *evaluatorConfig) error {
jwk, err := getJWK(cfg)
func updateStore(ctx context.Context, store *store.Store, cfg *evaluatorConfig) error {
jwk, err := getJWK(ctx, cfg)
if err != nil {
return fmt.Errorf("authorize: couldn't create signer: %w", err)
}
@ -341,7 +341,7 @@ func updateStore(store *store.Store, cfg *evaluatorConfig) error {
return nil
}
func getJWK(cfg *evaluatorConfig) (*jose.JSONWebKey, error) {
func getJWK(ctx context.Context, cfg *evaluatorConfig) (*jose.JSONWebKey, error) {
var decodedCert []byte
// if we don't have a signing key, generate one
if len(cfg.SigningKey) == 0 {
@ -361,7 +361,7 @@ func getJWK(cfg *evaluatorConfig) (*jose.JSONWebKey, error) {
if err != nil {
return nil, fmt.Errorf("couldn't generate signing key: %w", err)
}
log.Info().Str("Algorithm", jwk.Algorithm).
log.Ctx(ctx).Info().Str("Algorithm", jwk.Algorithm).
Str("KeyID", jwk.KeyID).
Interface("Public Key", jwk.Public()).
Msg("authorize: signing key")

View file

@ -147,7 +147,8 @@ func NewPolicyEvaluator(
// for each script, create a rego and prepare a query.
for i := range e.queries {
log.Ctx(ctx).Debug().
log.Ctx(ctx).
Trace().
Str("script", e.queries[i].script).
Str("from", configPolicy.From).
Interface("to", configPolicy.To).

View file

@ -33,6 +33,7 @@ type authorizeState struct {
}
func newAuthorizeStateFromConfig(
ctx context.Context,
cfg *config.Config, store *store.Store, previousPolicyEvaluator *evaluator.Evaluator,
) (*authorizeState, error) {
if err := validateOptions(cfg.Options); err != nil {
@ -43,7 +44,7 @@ func newAuthorizeStateFromConfig(
var err error
state.evaluator, err = newPolicyEvaluator(cfg.Options, store, previousPolicyEvaluator)
state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousPolicyEvaluator)
if err != nil {
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
}
@ -58,7 +59,7 @@ func newAuthorizeStateFromConfig(
return nil, err
}
cc, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{
cc, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{
OutboundPort: cfg.OutboundPort,
InstallationID: cfg.Options.InstallationID,
ServiceName: cfg.Options.Services,
@ -84,9 +85,9 @@ func newAuthorizeStateFromConfig(
}
if cfg.Options.UseStatelessAuthenticateFlow() {
state.authenticateFlow, err = authenticateflow.NewStateless(cfg, nil, nil, nil, nil)
state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, nil, nil, nil, nil)
} else {
state.authenticateFlow, err = authenticateflow.NewStateful(cfg, nil)
state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, nil)
}
if err != nil {
return nil, err

View file

@ -54,7 +54,7 @@ func run(ctx context.Context, configFile string) error {
var src config.Source
src, err := config.NewFileOrEnvironmentSource(configFile, files.FullVersion())
src, err := config.NewFileOrEnvironmentSource(ctx, configFile, files.FullVersion())
if err != nil {
return err
}

View file

@ -103,9 +103,10 @@ type FileOrEnvironmentSource struct {
// NewFileOrEnvironmentSource creates a new FileOrEnvironmentSource.
func NewFileOrEnvironmentSource(
ctx context.Context,
configFile, envoyVersion string,
) (*FileOrEnvironmentSource, error) {
ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context {
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("config_file_source", configFile)
})

View file

@ -137,7 +137,7 @@ runtime_flags:
require.NoError(t, err)
var src Source
src, err = NewFileOrEnvironmentSource(configFilePath, "")
src, err = NewFileOrEnvironmentSource(context.Background(), configFilePath, "")
require.NoError(t, err)
src = NewFileWatcherSource(context.Background(), src)

View file

@ -247,7 +247,7 @@ func (b *Builder) buildInternalTransportSocket(
b.buildSubjectAltNameMatcher(endpoint, cfg.Options.OverrideCertificateName),
},
}
bs, err := getCombinedCertificateAuthority(cfg)
bs, err := getCombinedCertificateAuthority(ctx, cfg)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else {
@ -347,7 +347,7 @@ func (b *Builder) buildPolicyValidationContext(
}
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
} else {
bs, err := getCombinedCertificateAuthority(cfg)
bs, err := getCombinedCertificateAuthority(ctx, cfg)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else {

View file

@ -43,14 +43,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-57394a4e5157303436544830.pem")
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}})
rootCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{}})
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
o1 := config.NewDefaultOptions()
o2 := config.NewDefaultOptions()
o2.CA = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0})
combinedCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{CA: o2.CA}})
combinedCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{CA: o2.CA}})
combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename()
t.Run("insecure", func(t *testing.T) {
@ -522,7 +522,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
func Test_buildCluster(t *testing.T) {
ctx := context.Background()
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}})
rootCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{}})
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
o1 := config.NewDefaultOptions()
t.Run("insecure", func(t *testing.T) {

View file

@ -189,7 +189,7 @@ var rootCABundle struct {
value string
}
func getRootCertificateAuthority() (string, error) {
func getRootCertificateAuthority(ctx context.Context) (string, error) {
rootCABundle.Do(func() {
// from https://www.envoyproxy.io/docs/envoy/latest/intro/arch_overview/security/ssl#arch-overview-ssl-enabling-verification
knownRootLocations := []string{
@ -207,10 +207,10 @@ func getRootCertificateAuthority() (string, error) {
}
}
if rootCABundle.value == "" {
log.Error().Strs("known-locations", knownRootLocations).
log.Ctx(ctx).Error().Strs("known-locations", knownRootLocations).
Msgf("no root certificates were found in any of the known locations")
} else {
log.Info().Msgf("using %s as the system root certificate authority bundle", rootCABundle.value)
log.Ctx(ctx).Info().Msgf("using %s as the system root certificate authority bundle", rootCABundle.value)
}
})
if rootCABundle.value == "" {
@ -219,8 +219,8 @@ func getRootCertificateAuthority() (string, error) {
return rootCABundle.value, nil
}
func getCombinedCertificateAuthority(cfg *config.Config) ([]byte, error) {
rootFile, err := getRootCertificateAuthority()
func getCombinedCertificateAuthority(ctx context.Context, cfg *config.Config) ([]byte, error) {
rootFile, err := getRootCertificateAuthority(ctx)
if err != nil {
return nil, err
}

View file

@ -45,7 +45,7 @@ type DataBroker struct {
}
// New creates a new databroker service.
func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) {
func New(ctx context.Context, cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) {
localListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, err
@ -61,8 +61,8 @@ func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) {
// No metrics handler because we have one in the control plane. Add one
// if we no longer register with that grpc Server
localGRPCServer := grpc.NewServer(
grpc.StreamInterceptor(si),
grpc.UnaryInterceptor(ui),
grpc.ChainStreamInterceptor(log.StreamServerInterceptor(log.Ctx(ctx)), si),
grpc.ChainUnaryInterceptor(log.UnaryServerInterceptor(log.Ctx(ctx)), ui),
)
sharedKey, err := cfg.Options.GetSharedKey()
@ -79,7 +79,7 @@ func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) {
grpc.WithStatsHandler(clientStatsHandler.Handler),
}
ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context {
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "databroker").Str("config_source", "bootstrap")
})
localGRPCConnection, err := grpc.DialContext(
@ -91,7 +91,7 @@ func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) {
return nil, err
}
dataBrokerServer, err := newDataBrokerServer(cfg)
dataBrokerServer, err := newDataBrokerServer(ctx, cfg)
if err != nil {
return nil, err
}

View file

@ -1,6 +1,7 @@
package databroker
import (
"context"
"testing"
"github.com/pomerium/pomerium/config"
@ -20,7 +21,7 @@ func TestNew(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.opts.Provider = "google"
_, err := New(&config.Config{Options: &tt.opts}, events.New())
_, err := New(context.Background(), &config.Config{Options: &tt.opts}, events.New())
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return

View file

@ -23,7 +23,7 @@ type dataBrokerServer struct {
}
// newDataBrokerServer creates a new databroker service server.
func newDataBrokerServer(cfg *config.Config) (*dataBrokerServer, error) {
func newDataBrokerServer(ctx context.Context, cfg *config.Config) (*dataBrokerServer, error) {
srv := &dataBrokerServer{
sharedKey: atomicutil.NewValue([]byte{}),
}
@ -33,7 +33,7 @@ func newDataBrokerServer(cfg *config.Config) (*dataBrokerServer, error) {
return nil, err
}
srv.server = databroker.New(opts...)
srv.server = databroker.New(ctx, opts...)
srv.setKey(cfg)
return srv, nil
}
@ -46,7 +46,7 @@ func (srv *dataBrokerServer) OnConfigChange(ctx context.Context, cfg *config.Con
return
}
srv.server.UpdateConfig(opts...)
srv.server.UpdateConfig(ctx, opts...)
srv.setKey(cfg)
}

View file

@ -29,7 +29,7 @@ var lis *bufconn.Listener
func init() {
lis = bufconn.Listen(bufSize)
s := grpc.NewServer()
internalSrv := internal_databroker.New()
internalSrv := internal_databroker.New(context.Background())
srv := &dataBrokerServer{server: internalSrv, sharedKey: atomicutil.NewValue([]byte{})}
databroker.RegisterDataBrokerServiceServer(s, srv)

View file

@ -163,7 +163,7 @@ func waitForHealthy(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case <-ticker.C:
}
}

View file

@ -56,7 +56,7 @@ type Stateful struct {
// NewStateful initializes the authentication flow for the given configuration
// and session store.
func NewStateful(cfg *config.Config, sessionStore sessions.SessionStore) (*Stateful, error) {
func NewStateful(ctx context.Context, cfg *config.Config, sessionStore sessions.SessionStore) (*Stateful, error) {
s := &Stateful{
sessionDuration: cfg.Options.CookieExpire,
sessionStore: sessionStore,
@ -88,7 +88,7 @@ func NewStateful(cfg *config.Config, sessionStore sessions.SessionStore) (*State
s.defaultIdentityProviderID = idp.GetId()
}
dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(),
dataBrokerConn, err := outboundGRPCConnection.Get(ctx,
&grpc.OutboundOptions{
OutboundPort: cfg.OutboundPort,
InstallationID: cfg.Options.InstallationID,

View file

@ -69,7 +69,7 @@ func TestStatefulSignIn(t *testing.T) {
tt := tt
t.Run(tt.name, func(t *testing.T) {
sessionStore := &mstore.Store{SaveError: tt.saveError}
flow, err := NewStateful(&config.Config{Options: opts}, sessionStore)
flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, sessionStore)
if err != nil {
t.Fatal(err)
}
@ -123,7 +123,7 @@ func TestStatefulAuthenticateSignInURL(t *testing.T) {
opts.AuthenticateURLString = "https://authenticate.example.com"
key := cryptutil.NewKey()
opts.SharedKey = base64.StdEncoding.EncodeToString(key)
flow, err := NewStateful(&config.Config{Options: opts}, nil)
flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, nil)
require.NoError(t, err)
t.Run("NilQueryParams", func(t *testing.T) {
@ -238,7 +238,7 @@ func TestStatefulCallback(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
flow, err := NewStateful(&config.Config{Options: opts}, tt.sessionStore)
flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, tt.sessionStore)
if err != nil {
t.Fatal(err)
}
@ -289,7 +289,7 @@ func TestStatefulCallback(t *testing.T) {
func TestStatefulRevokeSession(t *testing.T) {
opts := config.NewDefaultOptions()
flow, err := NewStateful(&config.Config{Options: opts}, nil)
flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, nil)
require.NoError(t, err)
ctrl := gomock.NewController(t)
@ -367,7 +367,7 @@ func TestPersistSession(t *testing.T) {
opts := config.NewDefaultOptions()
opts.CookieExpire = 4 * time.Hour
flow, err := NewStateful(&config.Config{Options: opts}, nil)
flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, nil)
require.NoError(t, err)
ctrl := gomock.NewController(t)

View file

@ -64,6 +64,7 @@ type Stateless struct {
// NewStateless initializes the authentication flow for the given
// configuration, session store, and additional options.
func NewStateless(
ctx context.Context,
cfg *config.Config,
sessionStore sessions.SessionStore,
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error),
@ -131,7 +132,7 @@ func NewStateless(
return nil, fmt.Errorf("authorize: get authenticate JWKS key fetcher: %w", err)
}
dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{
dataBrokerConn, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{
OutboundPort: cfg.OutboundPort,
InstallationID: cfg.Options.InstallationID,
ServiceName: cfg.Options.Services,

View file

@ -63,11 +63,12 @@ type Manager struct {
}
// New creates a new autocert manager.
func New(src config.Source) (*Manager, error) {
return newManager(context.Background(), src, certmagic.DefaultACME, renewalInterval)
func New(ctx context.Context, src config.Source) (*Manager, error) {
return newManager(ctx, src, certmagic.DefaultACME, renewalInterval)
}
func newManager(ctx context.Context,
func newManager(
ctx context.Context,
src config.Source,
acmeTemplate certmagic.ACMEIssuer,
checkInterval time.Duration,
@ -96,12 +97,13 @@ func newManager(ctx context.Context,
if err != nil {
return nil, err
}
mgr.certmagic = certmagic.New(certmagic.NewCache(certmagic.CacheOptions{
cache := certmagic.NewCache(certmagic.CacheOptions{
GetConfigForCert: func(_ certmagic.Certificate) (*certmagic.Config, error) {
return mgr.certmagic, nil
},
Logger: logger,
}), certmagic.Config{
})
mgr.certmagic = certmagic.New(cache, certmagic.Config{
Logger: logger,
Storage: certmagicStorage,
})

View file

@ -316,7 +316,7 @@ func TestRedirect(t *testing.T) {
},
},
})
_, err = New(src)
_, err = New(context.Background(), src)
if !assert.NoError(t, err) {
return
}

View file

@ -47,7 +47,7 @@ func (l *locker) Lock(ctx context.Context, name string) error {
// wait
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case <-time.After(lockPollInterval):
}
continue

View file

@ -71,7 +71,7 @@ func (srv *Server) getDataBrokerClient(ctx context.Context) (databrokerpb.DataBr
return nil, err
}
cc, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{
cc, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{
OutboundPort: cfg.OutboundPort,
InstallationID: cfg.Options.InstallationID,
ServiceName: cfg.Options.Services,

View file

@ -8,6 +8,7 @@ import (
"github.com/CAFxX/httpcompression"
"github.com/gorilla/mux"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/handlers"
@ -19,7 +20,7 @@ import (
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
)
func (srv *Server) addHTTPMiddleware(root *mux.Router, _ *config.Config) {
func (srv *Server) addHTTPMiddleware(root *mux.Router, logger *zerolog.Logger, _ *config.Config) {
compressor, err := httpcompression.DefaultAdapter()
if err != nil {
panic(err)
@ -28,7 +29,7 @@ func (srv *Server) addHTTPMiddleware(root *mux.Router, _ *config.Config) {
root.Use(compressor)
root.Use(srv.reproxy.Middleware)
root.Use(requestid.HTTPMiddleware())
root.Use(log.NewHandler(log.Logger))
root.Use(log.NewHandler(func() *zerolog.Logger { return logger }))
root.Use(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
log.FromRequest(r).Debug().
Dur("duration", duration).

View file

@ -69,10 +69,16 @@ type Server struct {
}
// NewServer creates a new Server. Listener ports are chosen by the OS.
func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr *events.Manager) (*Server, error) {
func NewServer(
ctx context.Context,
cfg *config.Config,
metricsMgr *config.MetricsManager,
eventsMgr *events.Manager,
) (*Server, error) {
srv := &Server{
metricsMgr: metricsMgr,
EventsMgr: eventsMgr,
filemgr: filemgr.NewManager(),
reproxy: reproxy.New(),
haveSetCapacity: map[string]bool{},
updateConfig: make(chan *config.Config, 1),
@ -80,6 +86,10 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
httpRouter: atomicutil.NewValue(mux.NewRouter()),
}
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("server_name", cfg.Options.Services)
})
var err error
// setup gRPC
@ -95,8 +105,16 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
)
srv.GRPCServer = grpc.NewServer(
grpc.StatsHandler(telemetry.NewGRPCServerStatsHandler(cfg.Options.Services)),
grpc.ChainUnaryInterceptor(requestid.UnaryServerInterceptor(), ui),
grpc.ChainStreamInterceptor(requestid.StreamServerInterceptor(), si),
grpc.ChainUnaryInterceptor(
log.UnaryServerInterceptor(log.Ctx(ctx)),
requestid.UnaryServerInterceptor(),
ui,
),
grpc.ChainStreamInterceptor(
log.StreamServerInterceptor(log.Ctx(ctx)),
requestid.StreamServerInterceptor(),
si,
),
)
reflection.Register(srv.GRPCServer)
srv.registerAccessLogHandlers()
@ -125,7 +143,7 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
return nil, err
}
if err := srv.updateRouter(cfg); err != nil {
if err := srv.updateRouter(ctx, cfg); err != nil {
return nil, err
}
srv.DebugRouter = mux.NewRouter()
@ -141,7 +159,6 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
// metrics
srv.MetricsRouter.Handle("/metrics", srv.metricsMgr)
srv.filemgr = filemgr.NewManager()
srv.filemgr.ClearCache()
srv.Builder = envoyconfig.New(
@ -152,10 +169,6 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
srv.reproxy,
)
ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context {
return c.Str("server_name", cfg.Options.Services)
})
res, err := srv.buildDiscoveryResources(ctx)
if err != nil {
return nil, err
@ -211,7 +224,7 @@ func (srv *Server) Run(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case cfg := <-srv.updateConfig:
err := srv.update(ctx, cfg)
if err != nil {
@ -232,29 +245,29 @@ func (srv *Server) OnConfigChange(ctx context.Context, cfg *config.Config) error
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case srv.updateConfig <- cfg:
}
return nil
}
// EnableAuthenticate enables the authenticate service.
func (srv *Server) EnableAuthenticate(svc Service) error {
func (srv *Server) EnableAuthenticate(ctx context.Context, svc Service) error {
srv.authenticateSvc = svc
return srv.updateRouter(srv.currentConfig.Load())
return srv.updateRouter(ctx, srv.currentConfig.Load())
}
// EnableProxy enables the proxy service.
func (srv *Server) EnableProxy(svc Service) error {
func (srv *Server) EnableProxy(ctx context.Context, svc Service) error {
srv.proxySvc = svc
return srv.updateRouter(srv.currentConfig.Load())
return srv.updateRouter(ctx, srv.currentConfig.Load())
}
func (srv *Server) update(ctx context.Context, cfg *config.Config) error {
ctx, span := trace.StartSpan(ctx, "controlplane.Server.update")
defer span.End()
if err := srv.updateRouter(cfg); err != nil {
if err := srv.updateRouter(ctx, cfg); err != nil {
return err
}
srv.reproxy.Update(ctx, cfg)
@ -267,9 +280,9 @@ func (srv *Server) update(ctx context.Context, cfg *config.Config) error {
return nil
}
func (srv *Server) updateRouter(cfg *config.Config) error {
func (srv *Server) updateRouter(ctx context.Context, cfg *config.Config) error {
httpRouter := mux.NewRouter()
srv.addHTTPMiddleware(httpRouter, cfg)
srv.addHTTPMiddleware(httpRouter, log.Ctx(ctx), cfg)
if err := srv.mountCommonEndpoints(httpRouter, cfg); err != nil {
return err
}

View file

@ -38,7 +38,7 @@ func TestServerHTTP(t *testing.T) {
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
src := config.NewStaticSource(cfg)
srv, err := NewServer(cfg, config.NewMetricsManager(ctx, src), events.New())
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New())
require.NoError(t, err)
go srv.Run(ctx)

View file

@ -1,6 +1,7 @@
package xdsmgr
import (
"context"
"errors"
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
@ -19,8 +20,8 @@ var (
routeConfigurationTypeURL = protoutil.GetTypeURL((*envoy_config_route_v3.RouteConfiguration)(nil))
)
func logNACK(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
log.Debug().
func logNACK(ctx context.Context, req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
log.Ctx(ctx).Debug().
Str("type-url", req.GetTypeUrl()).
Any("error-detail", req.GetErrorDetail()).
Msg("xdsmgr: nack")
@ -28,8 +29,8 @@ func logNACK(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
health.ReportError(getHealthCheck(req.GetTypeUrl()), errors.New(req.GetErrorDetail().GetMessage()))
}
func logACK(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
log.Debug().
func logACK(ctx context.Context, req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
log.Ctx(ctx).Debug().
Str("type-url", req.GetTypeUrl()).
Msg("xdsmgr: ack")

View file

@ -113,7 +113,7 @@ func (mgr *Manager) DeltaAggregatedResources(
for _, resource := range mgr.resources[req.GetTypeUrl()] {
state.clientResourceVersions[resource.Name] = resource.Version
}
logNACK(req)
logNACK(ctx, req)
case req.GetResponseNonce() == mgr.nonce:
// an ACK for the last response
// - set the client resource versions to the current resource versions
@ -121,10 +121,11 @@ func (mgr *Manager) DeltaAggregatedResources(
for _, resource := range mgr.resources[req.GetTypeUrl()] {
state.clientResourceVersions[resource.Name] = resource.Version
}
logACK(req)
logACK(ctx, req)
default:
// an ACK for a response that's not the last response
log.Ctx(ctx).Debug().
log.Ctx(ctx).
Debug().
Str("type-url", req.GetTypeUrl()).
Msg("xdsmgr: ack")
}
@ -161,7 +162,7 @@ func (mgr *Manager) DeltaAggregatedResources(
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case incoming <- req:
}
}
@ -173,7 +174,7 @@ func (mgr *Manager) DeltaAggregatedResources(
var typeURLs []string
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case req := <-incoming:
handleDeltaRequest(changeCtx, req)
typeURLs = []string{req.GetTypeUrl()}
@ -193,7 +194,7 @@ func (mgr *Manager) DeltaAggregatedResources(
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case outgoing <- res:
}
}
@ -204,9 +205,10 @@ func (mgr *Manager) DeltaAggregatedResources(
for {
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case res := <-outgoing:
log.Ctx(ctx).Debug().
log.Ctx(ctx).
Debug().
Str("type-url", res.GetTypeUrl()).
Int("resource-count", len(res.GetResources())).
Int("removed-resource-count", len(res.GetRemovedResources())).

View file

@ -96,7 +96,7 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
cfg := src.underlyingConfig.Clone()
// start the updater
src.runUpdater(cfg)
src.runUpdater(ctx, cfg)
now = time.Now()
err := src.buildNewConfigLocked(ctx, cfg)
@ -234,7 +234,7 @@ func (src *ConfigSource) addPolicies(ctx context.Context, cfg *config.Config, po
cfg.Options.AdditionalPolicies = append(cfg.Options.AdditionalPolicies, additionalPolicies...)
}
func (src *ConfigSource) runUpdater(cfg *config.Config) {
func (src *ConfigSource) runUpdater(ctx context.Context, cfg *config.Config) {
sharedKey, _ := cfg.Options.GetSharedKey()
connectionOptions := &grpc.OutboundOptions{
OutboundPort: cfg.OutboundPort,
@ -257,7 +257,6 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
src.cancel = nil
}
ctx := context.Background()
ctx, src.cancel = context.WithCancel(ctx)
cc, err := src.outboundGRPCConnection.Get(ctx, connectionOptions)
@ -268,7 +267,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
client := databroker.NewDataBrokerServiceClient(cc)
syncer := databroker.NewSyncer("databroker", &syncerHandler{
syncer := databroker.NewSyncer(ctx, "databroker", &syncerHandler{
client: client,
src: src,
}, databroker.WithTypeURL(grpcutil.GetTypeURL(new(configpb.Config))),

View file

@ -41,7 +41,7 @@ func TestConfigSource(t *testing.T) {
defer func() { _ = li.Close() }()
_, outboundPort, _ := net.SplitHostPort(li.Addr().String())
dataBrokerServer := New()
dataBrokerServer := New(ctx)
srv := grpc.NewServer()
databroker.RegisterDataBrokerServiceServer(srv, dataBrokerServer)
go func() { _ = srv.Serve(li) }()

View file

@ -28,7 +28,7 @@ func (srv *Server) Report(ctx context.Context, req *registrypb.RegisterRequest)
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Report")
defer span.End()
r, err := srv.getRegistry()
r, err := srv.getRegistry(ctx)
if err != nil {
return nil, err
}
@ -41,7 +41,7 @@ func (srv *Server) List(ctx context.Context, req *registrypb.ListRequest) (*regi
ctx, span := trace.StartSpan(ctx, "databroker.grpc.List")
defer span.End()
r, err := srv.getRegistry()
r, err := srv.getRegistry(ctx)
if err != nil {
return nil, err
}
@ -55,7 +55,7 @@ func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Watch")
defer span.End()
r, err := srv.getRegistry()
r, err := srv.getRegistry(ctx)
if err != nil {
return err
}
@ -66,8 +66,8 @@ func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry
})
}
func (srv *Server) getRegistry() (registry.Interface, error) {
backend, err := srv.getBackend()
func (srv *Server) getRegistry(ctx context.Context) (registry.Interface, error) {
backend, err := srv.getBackend(ctx)
if err != nil {
return nil, err
}
@ -81,7 +81,7 @@ func (srv *Server) getRegistry() (registry.Interface, error) {
r = srv.registry
var err error
if r == nil {
r, err = srv.newRegistryLocked(backend)
r, err = srv.newRegistryLocked(ctx, backend)
srv.registry = r
}
srv.mu.Unlock()
@ -92,9 +92,7 @@ func (srv *Server) getRegistry() (registry.Interface, error) {
return r, nil
}
func (srv *Server) newRegistryLocked(backend storage.Backend) (registry.Interface, error) {
ctx := context.Background()
func (srv *Server) newRegistryLocked(ctx context.Context, backend storage.Backend) (registry.Interface, error) {
if hasRegistryServer, ok := backend.(interface {
RegistryServer() registrypb.RegistryServer
}); ok {

View file

@ -28,25 +28,26 @@ import (
type Server struct {
cfg *serverConfig
mu sync.RWMutex
backend storage.Backend
registry registry.Interface
mu sync.RWMutex
backend storage.Backend
backendCtx context.Context
registry registry.Interface
}
// New creates a new server.
func New(options ...ServerOption) *Server {
srv := &Server{}
srv.UpdateConfig(options...)
func New(ctx context.Context, options ...ServerOption) *Server {
srv := &Server{
backendCtx: ctx,
}
srv.UpdateConfig(ctx, options...)
return srv
}
// UpdateConfig updates the server with the new options.
func (srv *Server) UpdateConfig(options ...ServerOption) {
func (srv *Server) UpdateConfig(ctx context.Context, options ...ServerOption) {
srv.mu.Lock()
defer srv.mu.Unlock()
ctx := context.TODO()
cfg := newServerConfig(options...)
if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) {
log.Ctx(ctx).Debug().Msg("databroker: no changes detected, re-using existing DBs")
@ -80,7 +81,7 @@ func (srv *Server) AcquireLease(ctx context.Context, req *databroker.AcquireLeas
Dur("duration", req.GetDuration().AsDuration()).
Msg("acquire lease")
db, err := srv.getBackend()
db, err := srv.getBackend(ctx)
if err != nil {
return nil, err
}
@ -107,7 +108,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
Str("id", req.GetId()).
Msg("get")
db, err := srv.getBackend()
db, err := srv.getBackend(ctx)
if err != nil {
return nil, err
}
@ -131,7 +132,7 @@ func (srv *Server) ListTypes(ctx context.Context, _ *emptypb.Empty) (*databroker
defer span.End()
log.Ctx(ctx).Debug().Msg("list types")
db, err := srv.getBackend()
db, err := srv.getBackend(ctx)
if err != nil {
return nil, err
}
@ -156,7 +157,7 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
query := strings.ToLower(req.GetQuery())
db, err := srv.getBackend()
db, err := srv.getBackend(ctx)
if err != nil {
return nil, err
}
@ -217,7 +218,7 @@ func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databr
Msg("put")
}
db, err := srv.getBackend()
db, err := srv.getBackend(ctx)
if err != nil {
return nil, err
}
@ -256,7 +257,7 @@ func (srv *Server) Patch(ctx context.Context, req *databroker.PatchRequest) (*da
Msg("patch")
}
db, err := srv.getBackend()
db, err := srv.getBackend(ctx)
if err != nil {
return nil, err
}
@ -282,7 +283,7 @@ func (srv *Server) ReleaseLease(ctx context.Context, req *databroker.ReleaseLeas
Str("id", req.GetId()).
Msg("release lease")
db, err := srv.getBackend()
db, err := srv.getBackend(ctx)
if err != nil {
return nil, err
}
@ -305,7 +306,7 @@ func (srv *Server) RenewLease(ctx context.Context, req *databroker.RenewLeaseReq
Dur("duration", req.GetDuration().AsDuration()).
Msg("renew lease")
db, err := srv.getBackend()
db, err := srv.getBackend(ctx)
if err != nil {
return nil, err
}
@ -325,7 +326,7 @@ func (srv *Server) SetOptions(ctx context.Context, req *databroker.SetOptionsReq
ctx, span := trace.StartSpan(ctx, "databroker.grpc.SetOptions")
defer span.End()
backend, err := srv.getBackend()
backend, err := srv.getBackend(ctx)
if err != nil {
return nil, err
}
@ -351,12 +352,13 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
ctx, cancel := context.WithCancel(ctx)
defer cancel()
log.Ctx(ctx).Debug().
log.Ctx(ctx).
Debug().
Uint64("server_version", req.GetServerVersion()).
Uint64("record_version", req.GetRecordVersion()).
Msg("sync")
backend, err := srv.getBackend()
backend, err := srv.getBackend(ctx)
if err != nil {
return err
}
@ -392,7 +394,7 @@ func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databrok
Str("type", req.GetType()).
Msg("sync latest")
backend, err := srv.getBackend()
backend, err := srv.getBackend(ctx)
if err != nil {
return err
}
@ -430,7 +432,7 @@ func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databrok
})
}
func (srv *Server) getBackend() (backend storage.Backend, err error) {
func (srv *Server) getBackend(ctx context.Context) (backend storage.Backend, err error) {
// double-checked locking:
// first try the read lock, then re-try with the write lock, and finally create a new backend if nil
srv.mu.RLock()
@ -441,7 +443,7 @@ func (srv *Server) getBackend() (backend storage.Backend, err error) {
backend = srv.backend
var err error
if backend == nil {
backend, err = srv.newBackendLocked()
backend, err = srv.newBackendLocked(ctx)
srv.backend = backend
}
srv.mu.Unlock()
@ -452,18 +454,18 @@ func (srv *Server) getBackend() (backend storage.Backend, err error) {
return backend, nil
}
func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
ctx := context.Background()
func (srv *Server) newBackendLocked(ctx context.Context) (storage.Backend, error) {
switch srv.cfg.storageType {
case config.StorageInMemoryName:
log.Ctx(ctx).Info().Msg("using in-memory store")
log.Ctx(ctx).Info().Msg("initializing new in-memory store")
return inmemory.New(), nil
case config.StoragePostgresName:
log.Ctx(ctx).Info().Msg("using postgres store")
backend = postgres.New(srv.cfg.storageConnectionString)
log.Ctx(ctx).Info().Msg("initializing new postgres store")
// NB: the context passed to postgres.New here is a separate context scoped
// to the lifetime of the server itself. 'ctx' may be a short-lived request
// context, since the backend is lazy-initialized.
return postgres.New(srv.backendCtx, srv.cfg.storageConnectionString), nil
default:
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
}
return backend, nil
}

View file

@ -48,7 +48,8 @@ func (h testSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint
func newServer(cfg *serverConfig) *Server {
return &Server{
cfg: cfg,
cfg: cfg,
backendCtx: context.Background(),
}
}
@ -277,7 +278,7 @@ func TestServer_Sync(t *testing.T) {
updateRecords := make(chan uint64, 10)
client := databroker.NewDataBrokerServiceClient(cc)
syncer := databroker.NewSyncer("TEST", testSyncerHandler{
syncer := databroker.NewSyncer(ctx, "TEST", testSyncerHandler{
getDataBrokerServiceClient: func() databroker.DataBrokerServiceClient {
return client
},
@ -292,12 +293,12 @@ func TestServer_Sync(t *testing.T) {
select {
case <-clearRecords:
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
}
select {
case <-updateRecords:
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
}
@ -313,7 +314,7 @@ func TestServer_Sync(t *testing.T) {
select {
case <-updateRecords:
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
}
return nil

View file

@ -45,7 +45,7 @@ func TestEnabler(t *testing.T) {
started.Add(1)
<-ctx.Done()
stopped.Add(1)
return ctx.Err()
return context.Cause(ctx)
}), true)
time.AfterFunc(time.Millisecond*10, e.Disable)
go e.Run(ctx)

View file

@ -80,6 +80,9 @@ func (watcher *Watcher) initLocked(ctx context.Context) {
if watcher.pollingWatcher == nil {
watcher.pollingWatcher = filenotify.NewPollingWatcher(nil)
context.AfterFunc(ctx, func() {
watcher.pollingWatcher.Close()
})
}
errors := watcher.pollingWatcher.Errors()

10
internal/log/debug.go Normal file
View file

@ -0,0 +1,10 @@
package log
import "sync/atomic"
var (
// Debug option to disable the Zap log shim
DebugDisableZapLogger atomic.Bool
// Debug option to suppress global warnings
DebugDisableGlobalWarnings atomic.Bool
)

View file

@ -52,6 +52,9 @@ func Logger() *zerolog.Logger {
// ZapLogger returns the global zap logger.
func ZapLogger() *zap.Logger {
if DebugDisableZapLogger.Load() {
return zap.NewNop()
}
return zapLogger.Load()
}

View file

@ -1,11 +1,14 @@
package log
import (
"context"
"net"
"net/http"
"time"
"github.com/pomerium/protoutil/streams"
"github.com/rs/zerolog"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/middleware/responsewriter"
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
@ -121,3 +124,17 @@ func HeadersHandler(headers []string) func(next http.Handler) http.Handler {
})
}
}
func StreamServerInterceptor(lg *zerolog.Logger) grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
s := streams.NewServerStreamWithContext(ss)
s.SetContext(lg.WithContext(s.Ctx))
return handler(srv, s)
}
}
func UnaryServerInterceptor(lg *zerolog.Logger) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
return handler(lg.WithContext(ctx), req)
}
}

View file

@ -11,6 +11,9 @@ var warnCookieSecretOnce sync.Once
// WarnCookieSecret warns about the cookie secret.
func WarnCookieSecret() {
warnCookieSecretOnce.Do(func() {
if DebugDisableGlobalWarnings.Load() {
return
}
Info().
Msg("using a generated COOKIE_SECRET. " +
"Set the COOKIE_SECRET to avoid users being logged out on restart. " +
@ -23,6 +26,9 @@ var warnNoTLSCertificateOnce syncutil.OnceMap[string]
// WarnNoTLSCertificate warns about no TLS certificate.
func WarnNoTLSCertificate(domain string) {
warnNoTLSCertificateOnce.Do(domain, func() {
if DebugDisableGlobalWarnings.Load() {
return
}
Info().
Str("domain", domain).
Msg("no TLS certificate found for domain, using a self-signed certificate")
@ -34,6 +40,9 @@ var warnWebSocketHTTP1_1Once syncutil.OnceMap[string]
// WarnWebSocketHTTP1_1 warns about falling back to http 1.1 due to web socket support.
func WarnWebSocketHTTP1_1(clusterID string) {
warnWebSocketHTTP1_1Once.Do(clusterID, func() {
if DebugDisableGlobalWarnings.Load() {
return
}
Info().
Str("cluster-id", clusterID).
Msg("forcing http/1.1 due to web socket support")

View file

@ -103,7 +103,7 @@ func makeSelect(
fn: func(ctx context.Context) error {
// unreachable, the context handler will never be called
// as its channel can only be closed
return ctx.Err()
return context.Cause(ctx)
},
ch: reflect.ValueOf(ctx.Done()),
},

View file

@ -22,7 +22,7 @@ func WaitForHealthy(ctx context.Context, client *http.Client, routes []*config.R
healthy++
}
}
return ctx.Err()
return context.Cause(ctx)
}
func checkHealth(ctx context.Context, client *http.Client, addr string) error {

View file

@ -74,7 +74,7 @@ func (svc *Source) updateLoop(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case <-svc.checkForUpdate:
case <-ticker.C:
}

View file

@ -34,7 +34,7 @@ type source struct {
func (src *source) WaitReady(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case <-src.ready:
return nil
}

View file

@ -36,7 +36,7 @@ func BuildImportCmd() *cobra.Command {
return fmt.Errorf("no config file provided")
}
log.SetLevel(zerolog.ErrorLevel)
src, err := config.NewFileOrEnvironmentSource(configFile, files.FullVersion())
src, err := config.NewFileOrEnvironmentSource(cmd.Context(), configFile, files.FullVersion())
if err != nil {
return err
}

View file

@ -17,7 +17,7 @@ import (
func (svc *Mux) Watch(ctx context.Context, opts ...WatchOption) error {
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case <-svc.ready:
}

View file

@ -102,7 +102,7 @@ func TestDatabrokerRestart(t *testing.T) {
cl(context.Background(), newConfig())
<-ctx.Done()
require.ErrorIs(t, context.Cause(ctx), controller.ErrBootstrapConfigurationChanged)
return ctx.Err()
return context.Cause(ctx)
}
return nil
})

View file

@ -56,5 +56,5 @@ func (w *LeaseStatus) MonitorLease(ctx context.Context, _ databroker.DataBrokerS
w.v.Store(true)
<-ctx.Done()
w.v.Store(false)
return ctx.Err()
return context.Cause(ctx)
}

View file

@ -37,7 +37,7 @@ func TestUsageReporter(t *testing.T) {
t.Cleanup(cancel)
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New())
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx))
})
t.Cleanup(func() { cc.Close() })

View file

@ -10,7 +10,7 @@ import (
)
func (c *Checker) ConfigSyncer(ctx context.Context) error {
syncer := databroker.NewSyncer("zero-health-check", c, databroker.WithTypeURL(protoutil.GetTypeURL(new(configpb.Config))))
syncer := databroker.NewSyncer(ctx, "zero-health-check", c, databroker.WithTypeURL(protoutil.GetTypeURL(new(configpb.Config))))
return syncer.Run(ctx)
}

View file

@ -35,7 +35,7 @@ func (c *service) SyncLoop(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case <-c.bundleSyncRequest:
log.Ctx(ctx).Debug().Msg("bundle sync triggered")
err := c.syncBundles(ctx)

View file

@ -105,7 +105,7 @@ func (srv *Telemetry) handleRequests(ctx context.Context) error {
case req := <-requests:
srv.handleRequest(ctx, req)
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
}
}
})

View file

@ -32,7 +32,7 @@ import (
// Run runs the main pomerium application.
func Run(ctx context.Context, src config.Source) error {
_, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Debug().Msgf(s, i...) }))
_, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Ctx(ctx).Debug().Msgf(s, i...) }))
evt := log.Ctx(ctx).Info().
Str("envoy_version", files.FullVersion()).
@ -53,7 +53,7 @@ func Run(ctx context.Context, src config.Source) error {
// trigger changes when underlying files are changed
src = config.NewFileWatcherSource(ctx, src)
src, err = autocert.New(src)
src, err = autocert.New(ctx, src)
if err != nil {
return err
}
@ -71,7 +71,7 @@ func Run(ctx context.Context, src config.Source) error {
cfg := src.GetConfig()
// setup the control plane
controlPlane, err := controlplane.NewServer(cfg, metricsMgr, eventsMgr)
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr)
if err != nil {
return fmt.Errorf("error creating control plane: %w", err)
}
@ -166,11 +166,11 @@ func setupAuthenticate(ctx context.Context, src config.Source, controlPlane *con
return nil
}
svc, err := authenticate.New(src.GetConfig())
svc, err := authenticate.New(ctx, src.GetConfig())
if err != nil {
return fmt.Errorf("error creating authenticate service: %w", err)
}
err = controlPlane.EnableAuthenticate(svc)
err = controlPlane.EnableAuthenticate(ctx, svc)
if err != nil {
return fmt.Errorf("error adding authenticate service to control plane: %w", err)
}
@ -183,7 +183,7 @@ func setupAuthenticate(ctx context.Context, src config.Source, controlPlane *con
}
func setupAuthorize(ctx context.Context, src config.Source, controlPlane *controlplane.Server) (*authorize.Authorize, error) {
svc, err := authorize.New(src.GetConfig())
svc, err := authorize.New(ctx, src.GetConfig())
if err != nil {
return nil, fmt.Errorf("error creating authorize service: %w", err)
}
@ -200,7 +200,7 @@ func setupDataBroker(ctx context.Context,
controlPlane *controlplane.Server,
eventsMgr *events.Manager,
) (*databroker_service.DataBroker, error) {
svc, err := databroker_service.New(src.GetConfig(), eventsMgr)
svc, err := databroker_service.New(ctx, src.GetConfig(), eventsMgr)
if err != nil {
return nil, fmt.Errorf("error creating databroker service: %w", err)
}
@ -223,11 +223,11 @@ func setupProxy(ctx context.Context, src config.Source, controlPlane *controlpla
return nil
}
svc, err := proxy.New(src.GetConfig())
svc, err := proxy.New(ctx, src.GetConfig())
if err != nil {
return fmt.Errorf("error creating proxy service: %w", err)
}
err = controlPlane.EnableProxy(svc)
err = controlPlane.EnableProxy(ctx, svc)
if err != nil {
return fmt.Errorf("error adding proxy service to control plane: %w", err)
}

View file

@ -185,7 +185,7 @@ func (srv *Server) run(ctx context.Context, cfg *config.Config) error {
// monitor the process so we exit if it prematurely exits
var monitorProcessCtx context.Context
monitorProcessCtx, srv.monitorProcessCancel = context.WithCancel(context.Background())
monitorProcessCtx, srv.monitorProcessCancel = context.WithCancel(context.WithoutCancel(ctx))
go srv.monitorProcess(monitorProcessCtx, int32(cmd.Process.Pid))
if srv.resourceMonitor != nil {
@ -251,7 +251,7 @@ func (srv *Server) parseLog(line string) (name string, logLevel string, msg stri
func (srv *Server) handleLogs(ctx context.Context, rc io.ReadCloser) {
defer rc.Close()
l := log.With().Str("service", "envoy").Logger()
l := log.Ctx(ctx).With().Str("service", "envoy").Logger()
bo := backoff.NewExponentialBackOff()
s := bufio.NewReader(rc)

View file

@ -10,7 +10,7 @@ func (f *FanOut[T]) Publish(ctx context.Context, msg T) error {
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case <-f.done:
return ErrStopped
case f.messages <- msg:

View file

@ -141,7 +141,7 @@ func WaitForReady(ctx context.Context, cc *grpc.ClientConn, timeout time.Duratio
for {
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case <-ticker.C:
}

View file

@ -71,13 +71,13 @@ func (locker *Leaser) Run(ctx context.Context) error {
case err == nil:
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case <-retryTicker.C:
}
case errors.Is(err, retryableError{}):
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case <-time.After(bo.NextBackOff()):
}
default:

View file

@ -169,7 +169,7 @@ func TestLeasers(t *testing.T) {
fn2 := func(ctx context.Context) error {
atomic.AddInt64(&counter, 10)
<-ctx.Done()
return ctx.Err()
return context.Cause(ctx)
}
leaser := databroker.NewLeasers("TEST", time.Second*30, client, fn1, fn2)
err := leaser.Run(context.Background())

View file

@ -110,7 +110,7 @@ func (r *Reconciler) reconcileLoop(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case <-r.trigger:
}
}

View file

@ -24,7 +24,7 @@ func Test_SyncLatestRecords(t *testing.T) {
defer clearTimeout()
cc := testutil.NewGRPCServer(t, func(s *grpc.Server) {
databrokerpb.RegisterDataBrokerServiceServer(s, databroker.New())
databrokerpb.RegisterDataBrokerServiceServer(s, databroker.New(ctx))
})
c := databrokerpb.NewDataBrokerServiceClient(cc)

View file

@ -72,8 +72,8 @@ type Syncer struct {
}
// NewSyncer creates a new Syncer.
func NewSyncer(id string, handler SyncerHandler, options ...SyncerOption) *Syncer {
closeCtx, closeCtxCancel := context.WithCancel(context.Background())
func NewSyncer(ctx context.Context, id string, handler SyncerHandler, options ...SyncerOption) *Syncer {
closeCtx, closeCtxCancel := context.WithCancel(context.WithoutCancel(ctx))
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
@ -120,7 +120,7 @@ func (syncer *Syncer) Run(ctx context.Context) error {
log.Ctx(ctx).Error().Err(err).Msg("sync")
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case <-time.After(syncer.backoff.NextBackOff()):
}
}
@ -133,6 +133,9 @@ func (syncer *Syncer) init(ctx context.Context) error {
Type: syncer.cfg.typeURL,
})
if err != nil {
if status.Code(err) == codes.Canceled && ctx.Err() != nil {
err = fmt.Errorf("%w: %w", err, context.Cause(ctx))
}
return fmt.Errorf("error during initial sync: %w", err)
}
syncer.backoff.Reset()
@ -167,6 +170,9 @@ func (syncer *Syncer) sync(ctx context.Context) error {
syncer.serverVersion = 0
return nil
} else if err != nil {
if status.Code(err) == codes.Canceled && ctx.Err() != nil {
err = fmt.Errorf("%w: %w", err, context.Cause(ctx))
}
return fmt.Errorf("error receiving sync record: %w", err)
}

View file

@ -157,7 +157,7 @@ func TestSyncer(t *testing.T) {
clearCh := make(chan struct{})
updateCh := make(chan []*Record)
syncer := NewSyncer("test", testSyncerHandler{
syncer := NewSyncer(ctx, "test", testSyncerHandler{
getDataBrokerServiceClient: func() DataBrokerServiceClient {
return NewDataBrokerServiceClient(gc)
},

View file

@ -124,13 +124,13 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
// wait for initial sync
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case <-clear:
mgr.reset()
}
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case msg := <-update:
mgr.onUpdateRecords(ctx, msg)
}
@ -150,7 +150,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
for {
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case <-clear:
mgr.reset()
case msg := <-update:

View file

@ -17,7 +17,7 @@ type dataBrokerSyncer struct {
}
func newDataBrokerSyncer(
_ context.Context,
ctx context.Context,
cfg *atomicutil.Value[*config],
update chan<- updateRecordsMessage,
clear chan<- struct{},
@ -28,7 +28,7 @@ func newDataBrokerSyncer(
update: update,
clear: clear,
}
syncer.syncer = databroker.NewSyncer("identity_manager", syncer)
syncer.syncer = databroker.NewSyncer(ctx, "identity_manager", syncer)
return syncer
}

View file

@ -16,7 +16,7 @@ type sessionSyncerHandler struct {
}
func newSessionSyncer(ctx context.Context, mgr *Manager) *databroker.Syncer {
return databroker.NewSyncer("identity_manager/sessions", sessionSyncerHandler{baseCtx: ctx, mgr: mgr},
return databroker.NewSyncer(ctx, "identity_manager/sessions", sessionSyncerHandler{baseCtx: ctx, mgr: mgr},
databroker.WithTypeURL(grpcutil.GetTypeURL(new(session.Session))))
}
@ -50,7 +50,7 @@ type userSyncerHandler struct {
}
func newUserSyncer(ctx context.Context, mgr *Manager) *databroker.Syncer {
return databroker.NewSyncer("identity_manager/users", userSyncerHandler{baseCtx: ctx, mgr: mgr},
return databroker.NewSyncer(ctx, "identity_manager/users", userSyncerHandler{baseCtx: ctx, mgr: mgr},
databroker.WithTypeURL(grpcutil.GetTypeURL(new(user.User))))
}

View file

@ -37,14 +37,14 @@ type Backend struct {
}
// New creates a new Backend.
func New(dsn string, options ...Option) *Backend {
func New(ctx context.Context, dsn string, options ...Option) *Backend {
backend := &Backend{
cfg: getConfig(options...),
dsn: dsn,
onRecordChange: signal.New(),
onServiceChange: signal.New(),
}
backend.closeCtx, backend.close = context.WithCancel(context.Background())
backend.closeCtx, backend.close = context.WithCancel(ctx)
go backend.doPeriodically(func(ctx context.Context) error {
_, pool, err := backend.init(ctx)

View file

@ -34,7 +34,7 @@ func TestBackend(t *testing.T) {
defer clearTimeout()
require.NoError(t, testutil.WithTestPostgres(func(dsn string) error {
backend := New(dsn)
backend := New(ctx, dsn)
defer backend.Close()
t.Run("put", func(t *testing.T) {

View file

@ -42,7 +42,7 @@ func TestRegistry(t *testing.T) {
defer clearTimeout()
require.NoError(t, testutil.WithTestPostgres(func(dsn string) error {
backend := New(dsn)
backend := New(ctx, dsn)
defer backend.Close()
eg, ctx := errgroup.WithContext(ctx)
@ -53,7 +53,7 @@ func TestRegistry(t *testing.T) {
send: func(res *registry.ServiceList) error {
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case listResults <- res:
}
return nil
@ -73,7 +73,7 @@ func TestRegistry(t *testing.T) {
eg.Go(func() error {
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case res := <-listResults:
testutil.AssertProtoEqual(t, &registry.ServiceList{}, res)
}
@ -92,7 +92,7 @@ func TestRegistry(t *testing.T) {
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case res := <-listResults:
testutil.AssertProtoEqual(t, &registry.ServiceList{
Services: []*registry.Service{

View file

@ -32,14 +32,14 @@ func Test_getUserInfoData(t *testing.T) {
defer clearTimeout()
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New())
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx))
})
t.Cleanup(func() { cc.Close() })
client := databrokerpb.NewDataBrokerServiceClient(cc)
opts := testOptions(t)
proxy, err := New(&config.Config{Options: opts})
proxy, err := New(ctx, &config.Config{Options: opts})
require.NoError(t, err)
proxy.state.Load().dataBrokerClient = client

View file

@ -2,6 +2,7 @@ package proxy
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
@ -36,7 +37,7 @@ func TestProxy_SignOut(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := testOptions(t)
p, err := New(&config.Config{Options: opts})
p, err := New(context.Background(), &config.Config{Options: opts})
if err != nil {
t.Fatal(err)
}
@ -129,7 +130,7 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.options})
p, err := New(context.Background(), &config.Config{Options: tt.options})
if err != nil {
t.Fatal(err)
}
@ -270,7 +271,7 @@ func TestLoadSessionState(t *testing.T) {
t.Parallel()
opts := testOptions(t)
proxy, err := New(&config.Config{Options: opts})
proxy, err := New(context.Background(), &config.Config{Options: opts})
require.NoError(t, err)
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
@ -285,7 +286,7 @@ func TestLoadSessionState(t *testing.T) {
t.Parallel()
opts := testOptions(t)
proxy, err := New(&config.Config{Options: opts})
proxy, err := New(context.Background(), &config.Config{Options: opts})
require.NoError(t, err)
session := encodeSession(t, opts, &sessions.State{
@ -308,7 +309,7 @@ func TestLoadSessionState(t *testing.T) {
t.Parallel()
opts := testOptions(t)
proxy, err := New(&config.Config{Options: opts})
proxy, err := New(context.Background(), &config.Config{Options: opts})
require.NoError(t, err)
session := encodeSession(t, opts, &sessions.State{

View file

@ -60,8 +60,8 @@ type Proxy struct {
// New takes a Proxy service from options and a validation function.
// Function returns an error if options fail to validate.
func New(cfg *config.Config) (*Proxy, error) {
state, err := newProxyStateFromConfig(cfg)
func New(ctx context.Context, cfg *config.Config) (*Proxy, error) {
state, err := newProxyStateFromConfig(ctx, cfg)
if err != nil {
return nil, err
}
@ -71,7 +71,7 @@ func New(cfg *config.Config) (*Proxy, error) {
currentOptions: config.NewAtomicOptions(),
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
}
p.OnConfigChange(context.Background(), cfg)
p.OnConfigChange(ctx, cfg)
p.webauthn = webauthn.New(p.getWebauthnState)
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
@ -87,25 +87,25 @@ func (p *Proxy) Mount(r *mux.Router) {
}
// OnConfigChange updates internal structures based on config.Options
func (p *Proxy) OnConfigChange(_ context.Context, cfg *config.Config) {
func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
if p == nil {
return
}
p.currentOptions.Store(cfg.Options)
if err := p.setHandlers(cfg.Options); err != nil {
log.Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
if err := p.setHandlers(ctx, cfg.Options); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
}
if state, err := newProxyStateFromConfig(cfg); err != nil {
log.Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings")
if state, err := newProxyStateFromConfig(ctx, cfg); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings")
} else {
p.state.Store(state)
}
}
func (p *Proxy) setHandlers(opts *config.Options) error {
func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error {
if opts.NumPolicies() == 0 {
log.Info().Msg("proxy: configuration has no policies")
log.Ctx(ctx).Info().Msg("proxy: configuration has no policies")
}
r := httputil.NewRouter()
r.NotFoundHandler = httputil.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) error {

View file

@ -104,7 +104,7 @@ func TestNew(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := New(&config.Config{Options: tt.opts})
got, err := New(context.Background(), &config.Config{Options: tt.opts})
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
@ -197,7 +197,7 @@ func Test_UpdateOptions(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.originalOptions})
p, err := New(context.Background(), &config.Config{Options: tt.originalOptions})
if err != nil {
t.Fatal(err)
}

View file

@ -31,7 +31,7 @@ type proxyState struct {
authenticateFlow authenticateFlow
}
func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
func newProxyStateFromConfig(ctx context.Context, cfg *config.Config) (*proxyState, error) {
err := ValidateOptions(cfg.Options)
if err != nil {
return nil, err
@ -57,7 +57,7 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
return nil, err
}
dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{
dataBrokerConn, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{
OutboundPort: cfg.OutboundPort,
InstallationID: cfg.Options.InstallationID,
ServiceName: cfg.Options.Services,
@ -71,10 +71,10 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist
if cfg.Options.UseStatelessAuthenticateFlow() {
state.authenticateFlow, err = authenticateflow.NewStateless(
state.authenticateFlow, err = authenticateflow.NewStateless(ctx,
cfg, state.sessionStore, nil, nil, nil)
} else {
state.authenticateFlow, err = authenticateflow.NewStateful(cfg, state.sessionStore)
state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, state.sessionStore)
}
if err != nil {
return nil, err