mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-05 05:16:04 +02:00
Fix many instances of contexts and loggers not being propagated (#5340)
This also replaces instances where we manually write "return ctx.Err()" with "return context.Cause(ctx)" which is functionally identical, but will also correctly propagate cause errors if present.
This commit is contained in:
parent
e1880ba20f
commit
fe31799eb5
77 changed files with 297 additions and 221 deletions
|
@ -44,7 +44,7 @@ type Authenticate struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New validates and creates a new authenticate service from a set of Options.
|
// New validates and creates a new authenticate service from a set of Options.
|
||||||
func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
|
func New(ctx context.Context, cfg *config.Config, options ...Option) (*Authenticate, error) {
|
||||||
authenticateConfig := getAuthenticateConfig(options...)
|
authenticateConfig := getAuthenticateConfig(options...)
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
cfg: authenticateConfig,
|
cfg: authenticateConfig,
|
||||||
|
@ -54,7 +54,7 @@ func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
|
||||||
|
|
||||||
a.options.Store(cfg.Options)
|
a.options.Store(cfg.Options)
|
||||||
|
|
||||||
state, err := newAuthenticateStateFromConfig(cfg, authenticateConfig)
|
state, err := newAuthenticateStateFromConfig(ctx, cfg, authenticateConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -70,7 +70,7 @@ func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||||
}
|
}
|
||||||
|
|
||||||
a.options.Store(cfg.Options)
|
a.options.Store(cfg.Options)
|
||||||
if state, err := newAuthenticateStateFromConfig(cfg, a.cfg); err != nil {
|
if state, err := newAuthenticateStateFromConfig(ctx, cfg, a.cfg); err != nil {
|
||||||
log.Ctx(ctx).Error().Err(err).Msg("authenticate: failed to update state")
|
log.Ctx(ctx).Error().Err(err).Msg("authenticate: failed to update state")
|
||||||
} else {
|
} else {
|
||||||
a.state.Store(state)
|
a.state.Store(state)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package authenticate
|
package authenticate
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
@ -106,7 +107,7 @@ func TestNew(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
tt := tt
|
tt := tt
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
_, err := New(&config.Config{Options: tt.opts})
|
_, err := New(context.Background(), &config.Config{Options: tt.opts})
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
|
|
@ -64,6 +64,7 @@ func newAuthenticateState() *authenticateState {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAuthenticateStateFromConfig(
|
func newAuthenticateStateFromConfig(
|
||||||
|
ctx context.Context,
|
||||||
cfg *config.Config, authenticateConfig *authenticateConfig,
|
cfg *config.Config, authenticateConfig *authenticateConfig,
|
||||||
) (*authenticateState, error) {
|
) (*authenticateState, error) {
|
||||||
err := ValidateOptions(cfg.Options)
|
err := ValidateOptions(cfg.Options)
|
||||||
|
@ -145,7 +146,7 @@ func newAuthenticateStateFromConfig(
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.Options.UseStatelessAuthenticateFlow() {
|
if cfg.Options.UseStatelessAuthenticateFlow() {
|
||||||
state.flow, err = authenticateflow.NewStateless(
|
state.flow, err = authenticateflow.NewStateless(ctx,
|
||||||
cfg,
|
cfg,
|
||||||
cookieStore,
|
cookieStore,
|
||||||
authenticateConfig.getIdentityProvider,
|
authenticateConfig.getIdentityProvider,
|
||||||
|
@ -153,7 +154,7 @@ func newAuthenticateStateFromConfig(
|
||||||
authenticateConfig.authEventFn,
|
authenticateConfig.authEventFn,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
state.flow, err = authenticateflow.NewStateful(cfg, cookieStore)
|
state.flow, err = authenticateflow.NewStateful(ctx, cfg, cookieStore)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -40,7 +40,7 @@ type Authorize struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New validates and creates a new Authorize service from a set of config options.
|
// New validates and creates a new Authorize service from a set of config options.
|
||||||
func New(cfg *config.Config) (*Authorize, error) {
|
func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
||||||
a := &Authorize{
|
a := &Authorize{
|
||||||
currentOptions: config.NewAtomicOptions(),
|
currentOptions: config.NewAtomicOptions(),
|
||||||
store: store.New(),
|
store: store.New(),
|
||||||
|
@ -48,7 +48,7 @@ func New(cfg *config.Config) (*Authorize, error) {
|
||||||
}
|
}
|
||||||
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
|
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
|
||||||
|
|
||||||
state, err := newAuthorizeStateFromConfig(cfg, a.store, nil)
|
state, err := newAuthorizeStateFromConfig(ctx, cfg, a.store, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -89,12 +89,13 @@ func validateOptions(o *config.Options) error {
|
||||||
|
|
||||||
// newPolicyEvaluator returns an policy evaluator.
|
// newPolicyEvaluator returns an policy evaluator.
|
||||||
func newPolicyEvaluator(
|
func newPolicyEvaluator(
|
||||||
|
ctx context.Context,
|
||||||
opts *config.Options, store *store.Store, previous *evaluator.Evaluator,
|
opts *config.Options, store *store.Store, previous *evaluator.Evaluator,
|
||||||
) (*evaluator.Evaluator, error) {
|
) (*evaluator.Evaluator, error) {
|
||||||
metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 {
|
metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 {
|
||||||
return int64(opts.NumPolicies())
|
return int64(opts.NumPolicies())
|
||||||
})
|
})
|
||||||
ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context {
|
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||||
return c.Str("service", "authorize")
|
return c.Str("service", "authorize")
|
||||||
})
|
})
|
||||||
ctx, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator")
|
ctx, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator")
|
||||||
|
@ -150,7 +151,7 @@ func newPolicyEvaluator(
|
||||||
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||||
currentState := a.state.Load()
|
currentState := a.state.Load()
|
||||||
a.currentOptions.Store(cfg.Options)
|
a.currentOptions.Store(cfg.Options)
|
||||||
if state, err := newAuthorizeStateFromConfig(cfg, a.store, currentState.evaluator); err != nil {
|
if state, err := newAuthorizeStateFromConfig(ctx, cfg, a.store, currentState.evaluator); err != nil {
|
||||||
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
|
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
|
||||||
} else {
|
} else {
|
||||||
a.state.Store(state)
|
a.state.Store(state)
|
||||||
|
|
|
@ -82,7 +82,7 @@ func TestNew(t *testing.T) {
|
||||||
tt := tt
|
tt := tt
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
_, err := New(&config.Config{Options: &tt.config})
|
_, err := New(context.Background(), &config.Config{Options: &tt.config})
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -114,7 +114,7 @@ func TestAuthorize_OnConfigChange(t *testing.T) {
|
||||||
SharedKey: tc.SharedKey,
|
SharedKey: tc.SharedKey,
|
||||||
Policies: tc.Policies,
|
Policies: tc.Policies,
|
||||||
}
|
}
|
||||||
a, err := New(&config.Config{Options: o})
|
a, err := New(context.Background(), &config.Config{Options: o})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, a)
|
require.NotNil(t, a)
|
||||||
|
|
||||||
|
@ -185,7 +185,7 @@ func TestNewPolicyEvaluator_addDefaultClientCertificateRule(t *testing.T) {
|
||||||
c.opts.Policies = []config.Policy{{
|
c.opts.Policies = []config.Policy{{
|
||||||
To: mustParseWeightedURLs(t, "http://example.com"),
|
To: mustParseWeightedURLs(t, "http://example.com"),
|
||||||
}}
|
}}
|
||||||
e, err := newPolicyEvaluator(c.opts, store, nil)
|
e, err := newPolicyEvaluator(context.Background(), c.opts, store, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
r, err := e.Evaluate(context.Background(), &evaluator.Request{
|
r, err := e.Evaluate(context.Background(), &evaluator.Request{
|
||||||
|
|
|
@ -34,7 +34,7 @@ func TestAuthorize_handleResult(t *testing.T) {
|
||||||
t.Cleanup(authnSrv.Close)
|
t.Cleanup(authnSrv.Close)
|
||||||
opt.AuthenticateURLString = authnSrv.URL
|
opt.AuthenticateURLString = authnSrv.URL
|
||||||
|
|
||||||
a, err := New(&config.Config{Options: opt})
|
a, err := New(context.Background(), &config.Config{Options: opt})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
t.Run("user-unauthenticated", func(t *testing.T) {
|
t.Run("user-unauthenticated", func(t *testing.T) {
|
||||||
|
@ -129,7 +129,7 @@ func TestAuthorize_okResponse(t *testing.T) {
|
||||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||||
a.currentOptions.Store(opt)
|
a.currentOptions.Store(opt)
|
||||||
a.store = store.New()
|
a.store = store.New()
|
||||||
pe, err := newPolicyEvaluator(opt, a.store, nil)
|
pe, err := newPolicyEvaluator(context.Background(), opt, a.store, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
a.state.Load().evaluator = pe
|
a.state.Load().evaluator = pe
|
||||||
|
|
||||||
|
@ -327,7 +327,7 @@ func TestRequireLogin(t *testing.T) {
|
||||||
t.Cleanup(authnSrv.Close)
|
t.Cleanup(authnSrv.Close)
|
||||||
opt.AuthenticateURLString = authnSrv.URL
|
opt.AuthenticateURLString = authnSrv.URL
|
||||||
|
|
||||||
a, err := New(&config.Config{Options: opt})
|
a, err := New(context.Background(), &config.Config{Options: opt})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
t.Run("accept empty", func(t *testing.T) {
|
t.Run("accept empty", func(t *testing.T) {
|
||||||
|
|
|
@ -65,7 +65,7 @@ func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) {
|
||||||
t.Cleanup(clearTimeout)
|
t.Cleanup(clearTimeout)
|
||||||
|
|
||||||
opt := config.NewDefaultOptions()
|
opt := config.NewDefaultOptions()
|
||||||
a, err := New(&config.Config{Options: opt})
|
a, err := New(context.Background(), &config.Config{Options: opt})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
s1 := &session.Session{Id: "s1", ExpiresAt: timestamppb.New(time.Now().Add(-time.Second))}
|
s1 := &session.Session{Id: "s1", ExpiresAt: timestamppb.New(time.Now().Add(-time.Second))}
|
||||||
|
|
|
@ -108,7 +108,7 @@ func New(
|
||||||
) (*Evaluator, error) {
|
) (*Evaluator, error) {
|
||||||
cfg := getConfig(options...)
|
cfg := getConfig(options...)
|
||||||
|
|
||||||
err := updateStore(store, cfg)
|
err := updateStore(ctx, store, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -325,8 +325,8 @@ func (e *Evaluator) getClientCA(policy *config.Policy) (string, error) {
|
||||||
return string(e.clientCA), nil
|
return string(e.clientCA), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateStore(store *store.Store, cfg *evaluatorConfig) error {
|
func updateStore(ctx context.Context, store *store.Store, cfg *evaluatorConfig) error {
|
||||||
jwk, err := getJWK(cfg)
|
jwk, err := getJWK(ctx, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("authorize: couldn't create signer: %w", err)
|
return fmt.Errorf("authorize: couldn't create signer: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -341,7 +341,7 @@ func updateStore(store *store.Store, cfg *evaluatorConfig) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getJWK(cfg *evaluatorConfig) (*jose.JSONWebKey, error) {
|
func getJWK(ctx context.Context, cfg *evaluatorConfig) (*jose.JSONWebKey, error) {
|
||||||
var decodedCert []byte
|
var decodedCert []byte
|
||||||
// if we don't have a signing key, generate one
|
// if we don't have a signing key, generate one
|
||||||
if len(cfg.SigningKey) == 0 {
|
if len(cfg.SigningKey) == 0 {
|
||||||
|
@ -361,7 +361,7 @@ func getJWK(cfg *evaluatorConfig) (*jose.JSONWebKey, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("couldn't generate signing key: %w", err)
|
return nil, fmt.Errorf("couldn't generate signing key: %w", err)
|
||||||
}
|
}
|
||||||
log.Info().Str("Algorithm", jwk.Algorithm).
|
log.Ctx(ctx).Info().Str("Algorithm", jwk.Algorithm).
|
||||||
Str("KeyID", jwk.KeyID).
|
Str("KeyID", jwk.KeyID).
|
||||||
Interface("Public Key", jwk.Public()).
|
Interface("Public Key", jwk.Public()).
|
||||||
Msg("authorize: signing key")
|
Msg("authorize: signing key")
|
||||||
|
|
|
@ -147,7 +147,8 @@ func NewPolicyEvaluator(
|
||||||
|
|
||||||
// for each script, create a rego and prepare a query.
|
// for each script, create a rego and prepare a query.
|
||||||
for i := range e.queries {
|
for i := range e.queries {
|
||||||
log.Ctx(ctx).Debug().
|
log.Ctx(ctx).
|
||||||
|
Trace().
|
||||||
Str("script", e.queries[i].script).
|
Str("script", e.queries[i].script).
|
||||||
Str("from", configPolicy.From).
|
Str("from", configPolicy.From).
|
||||||
Interface("to", configPolicy.To).
|
Interface("to", configPolicy.To).
|
||||||
|
|
|
@ -33,6 +33,7 @@ type authorizeState struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAuthorizeStateFromConfig(
|
func newAuthorizeStateFromConfig(
|
||||||
|
ctx context.Context,
|
||||||
cfg *config.Config, store *store.Store, previousPolicyEvaluator *evaluator.Evaluator,
|
cfg *config.Config, store *store.Store, previousPolicyEvaluator *evaluator.Evaluator,
|
||||||
) (*authorizeState, error) {
|
) (*authorizeState, error) {
|
||||||
if err := validateOptions(cfg.Options); err != nil {
|
if err := validateOptions(cfg.Options); err != nil {
|
||||||
|
@ -43,7 +44,7 @@ func newAuthorizeStateFromConfig(
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
state.evaluator, err = newPolicyEvaluator(cfg.Options, store, previousPolicyEvaluator)
|
state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousPolicyEvaluator)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
|
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -58,7 +59,7 @@ func newAuthorizeStateFromConfig(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cc, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{
|
cc, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{
|
||||||
OutboundPort: cfg.OutboundPort,
|
OutboundPort: cfg.OutboundPort,
|
||||||
InstallationID: cfg.Options.InstallationID,
|
InstallationID: cfg.Options.InstallationID,
|
||||||
ServiceName: cfg.Options.Services,
|
ServiceName: cfg.Options.Services,
|
||||||
|
@ -84,9 +85,9 @@ func newAuthorizeStateFromConfig(
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.Options.UseStatelessAuthenticateFlow() {
|
if cfg.Options.UseStatelessAuthenticateFlow() {
|
||||||
state.authenticateFlow, err = authenticateflow.NewStateless(cfg, nil, nil, nil, nil)
|
state.authenticateFlow, err = authenticateflow.NewStateless(ctx, cfg, nil, nil, nil, nil)
|
||||||
} else {
|
} else {
|
||||||
state.authenticateFlow, err = authenticateflow.NewStateful(cfg, nil)
|
state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, nil)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -54,7 +54,7 @@ func run(ctx context.Context, configFile string) error {
|
||||||
|
|
||||||
var src config.Source
|
var src config.Source
|
||||||
|
|
||||||
src, err := config.NewFileOrEnvironmentSource(configFile, files.FullVersion())
|
src, err := config.NewFileOrEnvironmentSource(ctx, configFile, files.FullVersion())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -103,9 +103,10 @@ type FileOrEnvironmentSource struct {
|
||||||
|
|
||||||
// NewFileOrEnvironmentSource creates a new FileOrEnvironmentSource.
|
// NewFileOrEnvironmentSource creates a new FileOrEnvironmentSource.
|
||||||
func NewFileOrEnvironmentSource(
|
func NewFileOrEnvironmentSource(
|
||||||
|
ctx context.Context,
|
||||||
configFile, envoyVersion string,
|
configFile, envoyVersion string,
|
||||||
) (*FileOrEnvironmentSource, error) {
|
) (*FileOrEnvironmentSource, error) {
|
||||||
ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context {
|
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||||
return c.Str("config_file_source", configFile)
|
return c.Str("config_file_source", configFile)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -137,7 +137,7 @@ runtime_flags:
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var src Source
|
var src Source
|
||||||
src, err = NewFileOrEnvironmentSource(configFilePath, "")
|
src, err = NewFileOrEnvironmentSource(context.Background(), configFilePath, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
src = NewFileWatcherSource(context.Background(), src)
|
src = NewFileWatcherSource(context.Background(), src)
|
||||||
|
|
||||||
|
|
|
@ -247,7 +247,7 @@ func (b *Builder) buildInternalTransportSocket(
|
||||||
b.buildSubjectAltNameMatcher(endpoint, cfg.Options.OverrideCertificateName),
|
b.buildSubjectAltNameMatcher(endpoint, cfg.Options.OverrideCertificateName),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
bs, err := getCombinedCertificateAuthority(cfg)
|
bs, err := getCombinedCertificateAuthority(ctx, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Ctx(ctx).Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
log.Ctx(ctx).Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||||
} else {
|
} else {
|
||||||
|
@ -347,7 +347,7 @@ func (b *Builder) buildPolicyValidationContext(
|
||||||
}
|
}
|
||||||
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
|
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
|
||||||
} else {
|
} else {
|
||||||
bs, err := getCombinedCertificateAuthority(cfg)
|
bs, err := getCombinedCertificateAuthority(ctx, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Ctx(ctx).Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
log.Ctx(ctx).Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -43,14 +43,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-57394a4e5157303436544830.pem")
|
customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-57394a4e5157303436544830.pem")
|
||||||
|
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||||
rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}})
|
rootCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{}})
|
||||||
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
||||||
|
|
||||||
o1 := config.NewDefaultOptions()
|
o1 := config.NewDefaultOptions()
|
||||||
o2 := config.NewDefaultOptions()
|
o2 := config.NewDefaultOptions()
|
||||||
o2.CA = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0})
|
o2.CA = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0})
|
||||||
|
|
||||||
combinedCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{CA: o2.CA}})
|
combinedCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{CA: o2.CA}})
|
||||||
combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename()
|
combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename()
|
||||||
|
|
||||||
t.Run("insecure", func(t *testing.T) {
|
t.Run("insecure", func(t *testing.T) {
|
||||||
|
@ -522,7 +522,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
func Test_buildCluster(t *testing.T) {
|
func Test_buildCluster(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||||
rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}})
|
rootCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{}})
|
||||||
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
||||||
o1 := config.NewDefaultOptions()
|
o1 := config.NewDefaultOptions()
|
||||||
t.Run("insecure", func(t *testing.T) {
|
t.Run("insecure", func(t *testing.T) {
|
||||||
|
|
|
@ -189,7 +189,7 @@ var rootCABundle struct {
|
||||||
value string
|
value string
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRootCertificateAuthority() (string, error) {
|
func getRootCertificateAuthority(ctx context.Context) (string, error) {
|
||||||
rootCABundle.Do(func() {
|
rootCABundle.Do(func() {
|
||||||
// from https://www.envoyproxy.io/docs/envoy/latest/intro/arch_overview/security/ssl#arch-overview-ssl-enabling-verification
|
// from https://www.envoyproxy.io/docs/envoy/latest/intro/arch_overview/security/ssl#arch-overview-ssl-enabling-verification
|
||||||
knownRootLocations := []string{
|
knownRootLocations := []string{
|
||||||
|
@ -207,10 +207,10 @@ func getRootCertificateAuthority() (string, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if rootCABundle.value == "" {
|
if rootCABundle.value == "" {
|
||||||
log.Error().Strs("known-locations", knownRootLocations).
|
log.Ctx(ctx).Error().Strs("known-locations", knownRootLocations).
|
||||||
Msgf("no root certificates were found in any of the known locations")
|
Msgf("no root certificates were found in any of the known locations")
|
||||||
} else {
|
} else {
|
||||||
log.Info().Msgf("using %s as the system root certificate authority bundle", rootCABundle.value)
|
log.Ctx(ctx).Info().Msgf("using %s as the system root certificate authority bundle", rootCABundle.value)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if rootCABundle.value == "" {
|
if rootCABundle.value == "" {
|
||||||
|
@ -219,8 +219,8 @@ func getRootCertificateAuthority() (string, error) {
|
||||||
return rootCABundle.value, nil
|
return rootCABundle.value, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCombinedCertificateAuthority(cfg *config.Config) ([]byte, error) {
|
func getCombinedCertificateAuthority(ctx context.Context, cfg *config.Config) ([]byte, error) {
|
||||||
rootFile, err := getRootCertificateAuthority()
|
rootFile, err := getRootCertificateAuthority(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ type DataBroker struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new databroker service.
|
// New creates a new databroker service.
|
||||||
func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) {
|
func New(ctx context.Context, cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) {
|
||||||
localListener, err := net.Listen("tcp", "127.0.0.1:0")
|
localListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -61,8 +61,8 @@ func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) {
|
||||||
// No metrics handler because we have one in the control plane. Add one
|
// No metrics handler because we have one in the control plane. Add one
|
||||||
// if we no longer register with that grpc Server
|
// if we no longer register with that grpc Server
|
||||||
localGRPCServer := grpc.NewServer(
|
localGRPCServer := grpc.NewServer(
|
||||||
grpc.StreamInterceptor(si),
|
grpc.ChainStreamInterceptor(log.StreamServerInterceptor(log.Ctx(ctx)), si),
|
||||||
grpc.UnaryInterceptor(ui),
|
grpc.ChainUnaryInterceptor(log.UnaryServerInterceptor(log.Ctx(ctx)), ui),
|
||||||
)
|
)
|
||||||
|
|
||||||
sharedKey, err := cfg.Options.GetSharedKey()
|
sharedKey, err := cfg.Options.GetSharedKey()
|
||||||
|
@ -79,7 +79,7 @@ func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) {
|
||||||
grpc.WithStatsHandler(clientStatsHandler.Handler),
|
grpc.WithStatsHandler(clientStatsHandler.Handler),
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context {
|
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||||
return c.Str("service", "databroker").Str("config_source", "bootstrap")
|
return c.Str("service", "databroker").Str("config_source", "bootstrap")
|
||||||
})
|
})
|
||||||
localGRPCConnection, err := grpc.DialContext(
|
localGRPCConnection, err := grpc.DialContext(
|
||||||
|
@ -91,7 +91,7 @@ func New(cfg *config.Config, eventsMgr *events.Manager) (*DataBroker, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
dataBrokerServer, err := newDataBrokerServer(cfg)
|
dataBrokerServer, err := newDataBrokerServer(ctx, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package databroker
|
package databroker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
@ -20,7 +21,7 @@ func TestNew(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tt.opts.Provider = "google"
|
tt.opts.Provider = "google"
|
||||||
_, err := New(&config.Config{Options: &tt.opts}, events.New())
|
_, err := New(context.Background(), &config.Config{Options: &tt.opts}, events.New())
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
|
|
@ -23,7 +23,7 @@ type dataBrokerServer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// newDataBrokerServer creates a new databroker service server.
|
// newDataBrokerServer creates a new databroker service server.
|
||||||
func newDataBrokerServer(cfg *config.Config) (*dataBrokerServer, error) {
|
func newDataBrokerServer(ctx context.Context, cfg *config.Config) (*dataBrokerServer, error) {
|
||||||
srv := &dataBrokerServer{
|
srv := &dataBrokerServer{
|
||||||
sharedKey: atomicutil.NewValue([]byte{}),
|
sharedKey: atomicutil.NewValue([]byte{}),
|
||||||
}
|
}
|
||||||
|
@ -33,7 +33,7 @@ func newDataBrokerServer(cfg *config.Config) (*dataBrokerServer, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.server = databroker.New(opts...)
|
srv.server = databroker.New(ctx, opts...)
|
||||||
srv.setKey(cfg)
|
srv.setKey(cfg)
|
||||||
return srv, nil
|
return srv, nil
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,7 @@ func (srv *dataBrokerServer) OnConfigChange(ctx context.Context, cfg *config.Con
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.server.UpdateConfig(opts...)
|
srv.server.UpdateConfig(ctx, opts...)
|
||||||
srv.setKey(cfg)
|
srv.setKey(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ var lis *bufconn.Listener
|
||||||
func init() {
|
func init() {
|
||||||
lis = bufconn.Listen(bufSize)
|
lis = bufconn.Listen(bufSize)
|
||||||
s := grpc.NewServer()
|
s := grpc.NewServer()
|
||||||
internalSrv := internal_databroker.New()
|
internalSrv := internal_databroker.New(context.Background())
|
||||||
srv := &dataBrokerServer{server: internalSrv, sharedKey: atomicutil.NewValue([]byte{})}
|
srv := &dataBrokerServer{server: internalSrv, sharedKey: atomicutil.NewValue([]byte{})}
|
||||||
databroker.RegisterDataBrokerServiceServer(s, srv)
|
databroker.RegisterDataBrokerServiceServer(s, srv)
|
||||||
|
|
||||||
|
|
|
@ -163,7 +163,7 @@ func waitForHealthy(ctx context.Context) error {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,7 +56,7 @@ type Stateful struct {
|
||||||
|
|
||||||
// NewStateful initializes the authentication flow for the given configuration
|
// NewStateful initializes the authentication flow for the given configuration
|
||||||
// and session store.
|
// and session store.
|
||||||
func NewStateful(cfg *config.Config, sessionStore sessions.SessionStore) (*Stateful, error) {
|
func NewStateful(ctx context.Context, cfg *config.Config, sessionStore sessions.SessionStore) (*Stateful, error) {
|
||||||
s := &Stateful{
|
s := &Stateful{
|
||||||
sessionDuration: cfg.Options.CookieExpire,
|
sessionDuration: cfg.Options.CookieExpire,
|
||||||
sessionStore: sessionStore,
|
sessionStore: sessionStore,
|
||||||
|
@ -88,7 +88,7 @@ func NewStateful(cfg *config.Config, sessionStore sessions.SessionStore) (*State
|
||||||
s.defaultIdentityProviderID = idp.GetId()
|
s.defaultIdentityProviderID = idp.GetId()
|
||||||
}
|
}
|
||||||
|
|
||||||
dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(),
|
dataBrokerConn, err := outboundGRPCConnection.Get(ctx,
|
||||||
&grpc.OutboundOptions{
|
&grpc.OutboundOptions{
|
||||||
OutboundPort: cfg.OutboundPort,
|
OutboundPort: cfg.OutboundPort,
|
||||||
InstallationID: cfg.Options.InstallationID,
|
InstallationID: cfg.Options.InstallationID,
|
||||||
|
|
|
@ -69,7 +69,7 @@ func TestStatefulSignIn(t *testing.T) {
|
||||||
tt := tt
|
tt := tt
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
sessionStore := &mstore.Store{SaveError: tt.saveError}
|
sessionStore := &mstore.Store{SaveError: tt.saveError}
|
||||||
flow, err := NewStateful(&config.Config{Options: opts}, sessionStore)
|
flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, sessionStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -123,7 +123,7 @@ func TestStatefulAuthenticateSignInURL(t *testing.T) {
|
||||||
opts.AuthenticateURLString = "https://authenticate.example.com"
|
opts.AuthenticateURLString = "https://authenticate.example.com"
|
||||||
key := cryptutil.NewKey()
|
key := cryptutil.NewKey()
|
||||||
opts.SharedKey = base64.StdEncoding.EncodeToString(key)
|
opts.SharedKey = base64.StdEncoding.EncodeToString(key)
|
||||||
flow, err := NewStateful(&config.Config{Options: opts}, nil)
|
flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
t.Run("NilQueryParams", func(t *testing.T) {
|
t.Run("NilQueryParams", func(t *testing.T) {
|
||||||
|
@ -238,7 +238,7 @@ func TestStatefulCallback(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
flow, err := NewStateful(&config.Config{Options: opts}, tt.sessionStore)
|
flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, tt.sessionStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -289,7 +289,7 @@ func TestStatefulCallback(t *testing.T) {
|
||||||
|
|
||||||
func TestStatefulRevokeSession(t *testing.T) {
|
func TestStatefulRevokeSession(t *testing.T) {
|
||||||
opts := config.NewDefaultOptions()
|
opts := config.NewDefaultOptions()
|
||||||
flow, err := NewStateful(&config.Config{Options: opts}, nil)
|
flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
|
@ -367,7 +367,7 @@ func TestPersistSession(t *testing.T) {
|
||||||
|
|
||||||
opts := config.NewDefaultOptions()
|
opts := config.NewDefaultOptions()
|
||||||
opts.CookieExpire = 4 * time.Hour
|
opts.CookieExpire = 4 * time.Hour
|
||||||
flow, err := NewStateful(&config.Config{Options: opts}, nil)
|
flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
|
|
|
@ -64,6 +64,7 @@ type Stateless struct {
|
||||||
// NewStateless initializes the authentication flow for the given
|
// NewStateless initializes the authentication flow for the given
|
||||||
// configuration, session store, and additional options.
|
// configuration, session store, and additional options.
|
||||||
func NewStateless(
|
func NewStateless(
|
||||||
|
ctx context.Context,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
sessionStore sessions.SessionStore,
|
sessionStore sessions.SessionStore,
|
||||||
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error),
|
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error),
|
||||||
|
@ -131,7 +132,7 @@ func NewStateless(
|
||||||
return nil, fmt.Errorf("authorize: get authenticate JWKS key fetcher: %w", err)
|
return nil, fmt.Errorf("authorize: get authenticate JWKS key fetcher: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{
|
dataBrokerConn, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{
|
||||||
OutboundPort: cfg.OutboundPort,
|
OutboundPort: cfg.OutboundPort,
|
||||||
InstallationID: cfg.Options.InstallationID,
|
InstallationID: cfg.Options.InstallationID,
|
||||||
ServiceName: cfg.Options.Services,
|
ServiceName: cfg.Options.Services,
|
||||||
|
|
|
@ -63,11 +63,12 @@ type Manager struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new autocert manager.
|
// New creates a new autocert manager.
|
||||||
func New(src config.Source) (*Manager, error) {
|
func New(ctx context.Context, src config.Source) (*Manager, error) {
|
||||||
return newManager(context.Background(), src, certmagic.DefaultACME, renewalInterval)
|
return newManager(ctx, src, certmagic.DefaultACME, renewalInterval)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newManager(ctx context.Context,
|
func newManager(
|
||||||
|
ctx context.Context,
|
||||||
src config.Source,
|
src config.Source,
|
||||||
acmeTemplate certmagic.ACMEIssuer,
|
acmeTemplate certmagic.ACMEIssuer,
|
||||||
checkInterval time.Duration,
|
checkInterval time.Duration,
|
||||||
|
@ -96,12 +97,13 @@ func newManager(ctx context.Context,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mgr.certmagic = certmagic.New(certmagic.NewCache(certmagic.CacheOptions{
|
cache := certmagic.NewCache(certmagic.CacheOptions{
|
||||||
GetConfigForCert: func(_ certmagic.Certificate) (*certmagic.Config, error) {
|
GetConfigForCert: func(_ certmagic.Certificate) (*certmagic.Config, error) {
|
||||||
return mgr.certmagic, nil
|
return mgr.certmagic, nil
|
||||||
},
|
},
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
}), certmagic.Config{
|
})
|
||||||
|
mgr.certmagic = certmagic.New(cache, certmagic.Config{
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
Storage: certmagicStorage,
|
Storage: certmagicStorage,
|
||||||
})
|
})
|
||||||
|
|
|
@ -316,7 +316,7 @@ func TestRedirect(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
_, err = New(src)
|
_, err = New(context.Background(), src)
|
||||||
if !assert.NoError(t, err) {
|
if !assert.NoError(t, err) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,7 +47,7 @@ func (l *locker) Lock(ctx context.Context, name string) error {
|
||||||
// wait
|
// wait
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case <-time.After(lockPollInterval):
|
case <-time.After(lockPollInterval):
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -71,7 +71,7 @@ func (srv *Server) getDataBrokerClient(ctx context.Context) (databrokerpb.DataBr
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cc, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{
|
cc, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{
|
||||||
OutboundPort: cfg.OutboundPort,
|
OutboundPort: cfg.OutboundPort,
|
||||||
InstallationID: cfg.Options.InstallationID,
|
InstallationID: cfg.Options.InstallationID,
|
||||||
ServiceName: cfg.Options.Services,
|
ServiceName: cfg.Options.Services,
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
|
|
||||||
"github.com/CAFxX/httpcompression"
|
"github.com/CAFxX/httpcompression"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/handlers"
|
"github.com/pomerium/pomerium/internal/handlers"
|
||||||
|
@ -19,7 +20,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (srv *Server) addHTTPMiddleware(root *mux.Router, _ *config.Config) {
|
func (srv *Server) addHTTPMiddleware(root *mux.Router, logger *zerolog.Logger, _ *config.Config) {
|
||||||
compressor, err := httpcompression.DefaultAdapter()
|
compressor, err := httpcompression.DefaultAdapter()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
@ -28,7 +29,7 @@ func (srv *Server) addHTTPMiddleware(root *mux.Router, _ *config.Config) {
|
||||||
root.Use(compressor)
|
root.Use(compressor)
|
||||||
root.Use(srv.reproxy.Middleware)
|
root.Use(srv.reproxy.Middleware)
|
||||||
root.Use(requestid.HTTPMiddleware())
|
root.Use(requestid.HTTPMiddleware())
|
||||||
root.Use(log.NewHandler(log.Logger))
|
root.Use(log.NewHandler(func() *zerolog.Logger { return logger }))
|
||||||
root.Use(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
root.Use(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
||||||
log.FromRequest(r).Debug().
|
log.FromRequest(r).Debug().
|
||||||
Dur("duration", duration).
|
Dur("duration", duration).
|
||||||
|
|
|
@ -69,10 +69,16 @@ type Server struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new Server. Listener ports are chosen by the OS.
|
// NewServer creates a new Server. Listener ports are chosen by the OS.
|
||||||
func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr *events.Manager) (*Server, error) {
|
func NewServer(
|
||||||
|
ctx context.Context,
|
||||||
|
cfg *config.Config,
|
||||||
|
metricsMgr *config.MetricsManager,
|
||||||
|
eventsMgr *events.Manager,
|
||||||
|
) (*Server, error) {
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
metricsMgr: metricsMgr,
|
metricsMgr: metricsMgr,
|
||||||
EventsMgr: eventsMgr,
|
EventsMgr: eventsMgr,
|
||||||
|
filemgr: filemgr.NewManager(),
|
||||||
reproxy: reproxy.New(),
|
reproxy: reproxy.New(),
|
||||||
haveSetCapacity: map[string]bool{},
|
haveSetCapacity: map[string]bool{},
|
||||||
updateConfig: make(chan *config.Config, 1),
|
updateConfig: make(chan *config.Config, 1),
|
||||||
|
@ -80,6 +86,10 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
|
||||||
httpRouter: atomicutil.NewValue(mux.NewRouter()),
|
httpRouter: atomicutil.NewValue(mux.NewRouter()),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||||
|
return c.Str("server_name", cfg.Options.Services)
|
||||||
|
})
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// setup gRPC
|
// setup gRPC
|
||||||
|
@ -95,8 +105,16 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
|
||||||
)
|
)
|
||||||
srv.GRPCServer = grpc.NewServer(
|
srv.GRPCServer = grpc.NewServer(
|
||||||
grpc.StatsHandler(telemetry.NewGRPCServerStatsHandler(cfg.Options.Services)),
|
grpc.StatsHandler(telemetry.NewGRPCServerStatsHandler(cfg.Options.Services)),
|
||||||
grpc.ChainUnaryInterceptor(requestid.UnaryServerInterceptor(), ui),
|
grpc.ChainUnaryInterceptor(
|
||||||
grpc.ChainStreamInterceptor(requestid.StreamServerInterceptor(), si),
|
log.UnaryServerInterceptor(log.Ctx(ctx)),
|
||||||
|
requestid.UnaryServerInterceptor(),
|
||||||
|
ui,
|
||||||
|
),
|
||||||
|
grpc.ChainStreamInterceptor(
|
||||||
|
log.StreamServerInterceptor(log.Ctx(ctx)),
|
||||||
|
requestid.StreamServerInterceptor(),
|
||||||
|
si,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
reflection.Register(srv.GRPCServer)
|
reflection.Register(srv.GRPCServer)
|
||||||
srv.registerAccessLogHandlers()
|
srv.registerAccessLogHandlers()
|
||||||
|
@ -125,7 +143,7 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := srv.updateRouter(cfg); err != nil {
|
if err := srv.updateRouter(ctx, cfg); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
srv.DebugRouter = mux.NewRouter()
|
srv.DebugRouter = mux.NewRouter()
|
||||||
|
@ -141,7 +159,6 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
|
||||||
// metrics
|
// metrics
|
||||||
srv.MetricsRouter.Handle("/metrics", srv.metricsMgr)
|
srv.MetricsRouter.Handle("/metrics", srv.metricsMgr)
|
||||||
|
|
||||||
srv.filemgr = filemgr.NewManager()
|
|
||||||
srv.filemgr.ClearCache()
|
srv.filemgr.ClearCache()
|
||||||
|
|
||||||
srv.Builder = envoyconfig.New(
|
srv.Builder = envoyconfig.New(
|
||||||
|
@ -152,10 +169,6 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
|
||||||
srv.reproxy,
|
srv.reproxy,
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context {
|
|
||||||
return c.Str("server_name", cfg.Options.Services)
|
|
||||||
})
|
|
||||||
|
|
||||||
res, err := srv.buildDiscoveryResources(ctx)
|
res, err := srv.buildDiscoveryResources(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -211,7 +224,7 @@ func (srv *Server) Run(ctx context.Context) error {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case cfg := <-srv.updateConfig:
|
case cfg := <-srv.updateConfig:
|
||||||
err := srv.update(ctx, cfg)
|
err := srv.update(ctx, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -232,29 +245,29 @@ func (srv *Server) OnConfigChange(ctx context.Context, cfg *config.Config) error
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case srv.updateConfig <- cfg:
|
case srv.updateConfig <- cfg:
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnableAuthenticate enables the authenticate service.
|
// EnableAuthenticate enables the authenticate service.
|
||||||
func (srv *Server) EnableAuthenticate(svc Service) error {
|
func (srv *Server) EnableAuthenticate(ctx context.Context, svc Service) error {
|
||||||
srv.authenticateSvc = svc
|
srv.authenticateSvc = svc
|
||||||
return srv.updateRouter(srv.currentConfig.Load())
|
return srv.updateRouter(ctx, srv.currentConfig.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnableProxy enables the proxy service.
|
// EnableProxy enables the proxy service.
|
||||||
func (srv *Server) EnableProxy(svc Service) error {
|
func (srv *Server) EnableProxy(ctx context.Context, svc Service) error {
|
||||||
srv.proxySvc = svc
|
srv.proxySvc = svc
|
||||||
return srv.updateRouter(srv.currentConfig.Load())
|
return srv.updateRouter(ctx, srv.currentConfig.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) update(ctx context.Context, cfg *config.Config) error {
|
func (srv *Server) update(ctx context.Context, cfg *config.Config) error {
|
||||||
ctx, span := trace.StartSpan(ctx, "controlplane.Server.update")
|
ctx, span := trace.StartSpan(ctx, "controlplane.Server.update")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
if err := srv.updateRouter(cfg); err != nil {
|
if err := srv.updateRouter(ctx, cfg); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
srv.reproxy.Update(ctx, cfg)
|
srv.reproxy.Update(ctx, cfg)
|
||||||
|
@ -267,9 +280,9 @@ func (srv *Server) update(ctx context.Context, cfg *config.Config) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) updateRouter(cfg *config.Config) error {
|
func (srv *Server) updateRouter(ctx context.Context, cfg *config.Config) error {
|
||||||
httpRouter := mux.NewRouter()
|
httpRouter := mux.NewRouter()
|
||||||
srv.addHTTPMiddleware(httpRouter, cfg)
|
srv.addHTTPMiddleware(httpRouter, log.Ctx(ctx), cfg)
|
||||||
if err := srv.mountCommonEndpoints(httpRouter, cfg); err != nil {
|
if err := srv.mountCommonEndpoints(httpRouter, cfg); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ func TestServerHTTP(t *testing.T) {
|
||||||
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
|
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
|
||||||
|
|
||||||
src := config.NewStaticSource(cfg)
|
src := config.NewStaticSource(cfg)
|
||||||
srv, err := NewServer(cfg, config.NewMetricsManager(ctx, src), events.New())
|
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
go srv.Run(ctx)
|
go srv.Run(ctx)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package xdsmgr
|
package xdsmgr
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
|
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
|
||||||
|
@ -19,8 +20,8 @@ var (
|
||||||
routeConfigurationTypeURL = protoutil.GetTypeURL((*envoy_config_route_v3.RouteConfiguration)(nil))
|
routeConfigurationTypeURL = protoutil.GetTypeURL((*envoy_config_route_v3.RouteConfiguration)(nil))
|
||||||
)
|
)
|
||||||
|
|
||||||
func logNACK(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
|
func logNACK(ctx context.Context, req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
|
||||||
log.Debug().
|
log.Ctx(ctx).Debug().
|
||||||
Str("type-url", req.GetTypeUrl()).
|
Str("type-url", req.GetTypeUrl()).
|
||||||
Any("error-detail", req.GetErrorDetail()).
|
Any("error-detail", req.GetErrorDetail()).
|
||||||
Msg("xdsmgr: nack")
|
Msg("xdsmgr: nack")
|
||||||
|
@ -28,8 +29,8 @@ func logNACK(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
|
||||||
health.ReportError(getHealthCheck(req.GetTypeUrl()), errors.New(req.GetErrorDetail().GetMessage()))
|
health.ReportError(getHealthCheck(req.GetTypeUrl()), errors.New(req.GetErrorDetail().GetMessage()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func logACK(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
|
func logACK(ctx context.Context, req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
|
||||||
log.Debug().
|
log.Ctx(ctx).Debug().
|
||||||
Str("type-url", req.GetTypeUrl()).
|
Str("type-url", req.GetTypeUrl()).
|
||||||
Msg("xdsmgr: ack")
|
Msg("xdsmgr: ack")
|
||||||
|
|
||||||
|
|
|
@ -113,7 +113,7 @@ func (mgr *Manager) DeltaAggregatedResources(
|
||||||
for _, resource := range mgr.resources[req.GetTypeUrl()] {
|
for _, resource := range mgr.resources[req.GetTypeUrl()] {
|
||||||
state.clientResourceVersions[resource.Name] = resource.Version
|
state.clientResourceVersions[resource.Name] = resource.Version
|
||||||
}
|
}
|
||||||
logNACK(req)
|
logNACK(ctx, req)
|
||||||
case req.GetResponseNonce() == mgr.nonce:
|
case req.GetResponseNonce() == mgr.nonce:
|
||||||
// an ACK for the last response
|
// an ACK for the last response
|
||||||
// - set the client resource versions to the current resource versions
|
// - set the client resource versions to the current resource versions
|
||||||
|
@ -121,10 +121,11 @@ func (mgr *Manager) DeltaAggregatedResources(
|
||||||
for _, resource := range mgr.resources[req.GetTypeUrl()] {
|
for _, resource := range mgr.resources[req.GetTypeUrl()] {
|
||||||
state.clientResourceVersions[resource.Name] = resource.Version
|
state.clientResourceVersions[resource.Name] = resource.Version
|
||||||
}
|
}
|
||||||
logACK(req)
|
logACK(ctx, req)
|
||||||
default:
|
default:
|
||||||
// an ACK for a response that's not the last response
|
// an ACK for a response that's not the last response
|
||||||
log.Ctx(ctx).Debug().
|
log.Ctx(ctx).
|
||||||
|
Debug().
|
||||||
Str("type-url", req.GetTypeUrl()).
|
Str("type-url", req.GetTypeUrl()).
|
||||||
Msg("xdsmgr: ack")
|
Msg("xdsmgr: ack")
|
||||||
}
|
}
|
||||||
|
@ -161,7 +162,7 @@ func (mgr *Manager) DeltaAggregatedResources(
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case incoming <- req:
|
case incoming <- req:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -173,7 +174,7 @@ func (mgr *Manager) DeltaAggregatedResources(
|
||||||
var typeURLs []string
|
var typeURLs []string
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case req := <-incoming:
|
case req := <-incoming:
|
||||||
handleDeltaRequest(changeCtx, req)
|
handleDeltaRequest(changeCtx, req)
|
||||||
typeURLs = []string{req.GetTypeUrl()}
|
typeURLs = []string{req.GetTypeUrl()}
|
||||||
|
@ -193,7 +194,7 @@ func (mgr *Manager) DeltaAggregatedResources(
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case outgoing <- res:
|
case outgoing <- res:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -204,9 +205,10 @@ func (mgr *Manager) DeltaAggregatedResources(
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case res := <-outgoing:
|
case res := <-outgoing:
|
||||||
log.Ctx(ctx).Debug().
|
log.Ctx(ctx).
|
||||||
|
Debug().
|
||||||
Str("type-url", res.GetTypeUrl()).
|
Str("type-url", res.GetTypeUrl()).
|
||||||
Int("resource-count", len(res.GetResources())).
|
Int("resource-count", len(res.GetResources())).
|
||||||
Int("removed-resource-count", len(res.GetRemovedResources())).
|
Int("removed-resource-count", len(res.GetRemovedResources())).
|
||||||
|
|
|
@ -96,7 +96,7 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
|
||||||
cfg := src.underlyingConfig.Clone()
|
cfg := src.underlyingConfig.Clone()
|
||||||
|
|
||||||
// start the updater
|
// start the updater
|
||||||
src.runUpdater(cfg)
|
src.runUpdater(ctx, cfg)
|
||||||
|
|
||||||
now = time.Now()
|
now = time.Now()
|
||||||
err := src.buildNewConfigLocked(ctx, cfg)
|
err := src.buildNewConfigLocked(ctx, cfg)
|
||||||
|
@ -234,7 +234,7 @@ func (src *ConfigSource) addPolicies(ctx context.Context, cfg *config.Config, po
|
||||||
cfg.Options.AdditionalPolicies = append(cfg.Options.AdditionalPolicies, additionalPolicies...)
|
cfg.Options.AdditionalPolicies = append(cfg.Options.AdditionalPolicies, additionalPolicies...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (src *ConfigSource) runUpdater(cfg *config.Config) {
|
func (src *ConfigSource) runUpdater(ctx context.Context, cfg *config.Config) {
|
||||||
sharedKey, _ := cfg.Options.GetSharedKey()
|
sharedKey, _ := cfg.Options.GetSharedKey()
|
||||||
connectionOptions := &grpc.OutboundOptions{
|
connectionOptions := &grpc.OutboundOptions{
|
||||||
OutboundPort: cfg.OutboundPort,
|
OutboundPort: cfg.OutboundPort,
|
||||||
|
@ -257,7 +257,6 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
|
||||||
src.cancel = nil
|
src.cancel = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
ctx, src.cancel = context.WithCancel(ctx)
|
ctx, src.cancel = context.WithCancel(ctx)
|
||||||
|
|
||||||
cc, err := src.outboundGRPCConnection.Get(ctx, connectionOptions)
|
cc, err := src.outboundGRPCConnection.Get(ctx, connectionOptions)
|
||||||
|
@ -268,7 +267,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
|
||||||
|
|
||||||
client := databroker.NewDataBrokerServiceClient(cc)
|
client := databroker.NewDataBrokerServiceClient(cc)
|
||||||
|
|
||||||
syncer := databroker.NewSyncer("databroker", &syncerHandler{
|
syncer := databroker.NewSyncer(ctx, "databroker", &syncerHandler{
|
||||||
client: client,
|
client: client,
|
||||||
src: src,
|
src: src,
|
||||||
}, databroker.WithTypeURL(grpcutil.GetTypeURL(new(configpb.Config))),
|
}, databroker.WithTypeURL(grpcutil.GetTypeURL(new(configpb.Config))),
|
||||||
|
|
|
@ -41,7 +41,7 @@ func TestConfigSource(t *testing.T) {
|
||||||
defer func() { _ = li.Close() }()
|
defer func() { _ = li.Close() }()
|
||||||
_, outboundPort, _ := net.SplitHostPort(li.Addr().String())
|
_, outboundPort, _ := net.SplitHostPort(li.Addr().String())
|
||||||
|
|
||||||
dataBrokerServer := New()
|
dataBrokerServer := New(ctx)
|
||||||
srv := grpc.NewServer()
|
srv := grpc.NewServer()
|
||||||
databroker.RegisterDataBrokerServiceServer(srv, dataBrokerServer)
|
databroker.RegisterDataBrokerServiceServer(srv, dataBrokerServer)
|
||||||
go func() { _ = srv.Serve(li) }()
|
go func() { _ = srv.Serve(li) }()
|
||||||
|
|
|
@ -28,7 +28,7 @@ func (srv *Server) Report(ctx context.Context, req *registrypb.RegisterRequest)
|
||||||
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Report")
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Report")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
r, err := srv.getRegistry()
|
r, err := srv.getRegistry(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -41,7 +41,7 @@ func (srv *Server) List(ctx context.Context, req *registrypb.ListRequest) (*regi
|
||||||
ctx, span := trace.StartSpan(ctx, "databroker.grpc.List")
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.List")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
r, err := srv.getRegistry()
|
r, err := srv.getRegistry(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,7 @@ func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry
|
||||||
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Watch")
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Watch")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
r, err := srv.getRegistry()
|
r, err := srv.getRegistry(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -66,8 +66,8 @@ func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) getRegistry() (registry.Interface, error) {
|
func (srv *Server) getRegistry(ctx context.Context) (registry.Interface, error) {
|
||||||
backend, err := srv.getBackend()
|
backend, err := srv.getBackend(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -81,7 +81,7 @@ func (srv *Server) getRegistry() (registry.Interface, error) {
|
||||||
r = srv.registry
|
r = srv.registry
|
||||||
var err error
|
var err error
|
||||||
if r == nil {
|
if r == nil {
|
||||||
r, err = srv.newRegistryLocked(backend)
|
r, err = srv.newRegistryLocked(ctx, backend)
|
||||||
srv.registry = r
|
srv.registry = r
|
||||||
}
|
}
|
||||||
srv.mu.Unlock()
|
srv.mu.Unlock()
|
||||||
|
@ -92,9 +92,7 @@ func (srv *Server) getRegistry() (registry.Interface, error) {
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) newRegistryLocked(backend storage.Backend) (registry.Interface, error) {
|
func (srv *Server) newRegistryLocked(ctx context.Context, backend storage.Backend) (registry.Interface, error) {
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
if hasRegistryServer, ok := backend.(interface {
|
if hasRegistryServer, ok := backend.(interface {
|
||||||
RegistryServer() registrypb.RegistryServer
|
RegistryServer() registrypb.RegistryServer
|
||||||
}); ok {
|
}); ok {
|
||||||
|
|
|
@ -28,25 +28,26 @@ import (
|
||||||
type Server struct {
|
type Server struct {
|
||||||
cfg *serverConfig
|
cfg *serverConfig
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
backend storage.Backend
|
backend storage.Backend
|
||||||
registry registry.Interface
|
backendCtx context.Context
|
||||||
|
registry registry.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new server.
|
// New creates a new server.
|
||||||
func New(options ...ServerOption) *Server {
|
func New(ctx context.Context, options ...ServerOption) *Server {
|
||||||
srv := &Server{}
|
srv := &Server{
|
||||||
srv.UpdateConfig(options...)
|
backendCtx: ctx,
|
||||||
|
}
|
||||||
|
srv.UpdateConfig(ctx, options...)
|
||||||
return srv
|
return srv
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateConfig updates the server with the new options.
|
// UpdateConfig updates the server with the new options.
|
||||||
func (srv *Server) UpdateConfig(options ...ServerOption) {
|
func (srv *Server) UpdateConfig(ctx context.Context, options ...ServerOption) {
|
||||||
srv.mu.Lock()
|
srv.mu.Lock()
|
||||||
defer srv.mu.Unlock()
|
defer srv.mu.Unlock()
|
||||||
|
|
||||||
ctx := context.TODO()
|
|
||||||
|
|
||||||
cfg := newServerConfig(options...)
|
cfg := newServerConfig(options...)
|
||||||
if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) {
|
if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) {
|
||||||
log.Ctx(ctx).Debug().Msg("databroker: no changes detected, re-using existing DBs")
|
log.Ctx(ctx).Debug().Msg("databroker: no changes detected, re-using existing DBs")
|
||||||
|
@ -80,7 +81,7 @@ func (srv *Server) AcquireLease(ctx context.Context, req *databroker.AcquireLeas
|
||||||
Dur("duration", req.GetDuration().AsDuration()).
|
Dur("duration", req.GetDuration().AsDuration()).
|
||||||
Msg("acquire lease")
|
Msg("acquire lease")
|
||||||
|
|
||||||
db, err := srv.getBackend()
|
db, err := srv.getBackend(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -107,7 +108,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
|
||||||
Str("id", req.GetId()).
|
Str("id", req.GetId()).
|
||||||
Msg("get")
|
Msg("get")
|
||||||
|
|
||||||
db, err := srv.getBackend()
|
db, err := srv.getBackend(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -131,7 +132,7 @@ func (srv *Server) ListTypes(ctx context.Context, _ *emptypb.Empty) (*databroker
|
||||||
defer span.End()
|
defer span.End()
|
||||||
log.Ctx(ctx).Debug().Msg("list types")
|
log.Ctx(ctx).Debug().Msg("list types")
|
||||||
|
|
||||||
db, err := srv.getBackend()
|
db, err := srv.getBackend(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -156,7 +157,7 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
|
||||||
|
|
||||||
query := strings.ToLower(req.GetQuery())
|
query := strings.ToLower(req.GetQuery())
|
||||||
|
|
||||||
db, err := srv.getBackend()
|
db, err := srv.getBackend(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -217,7 +218,7 @@ func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databr
|
||||||
Msg("put")
|
Msg("put")
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := srv.getBackend()
|
db, err := srv.getBackend(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -256,7 +257,7 @@ func (srv *Server) Patch(ctx context.Context, req *databroker.PatchRequest) (*da
|
||||||
Msg("patch")
|
Msg("patch")
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := srv.getBackend()
|
db, err := srv.getBackend(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -282,7 +283,7 @@ func (srv *Server) ReleaseLease(ctx context.Context, req *databroker.ReleaseLeas
|
||||||
Str("id", req.GetId()).
|
Str("id", req.GetId()).
|
||||||
Msg("release lease")
|
Msg("release lease")
|
||||||
|
|
||||||
db, err := srv.getBackend()
|
db, err := srv.getBackend(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -305,7 +306,7 @@ func (srv *Server) RenewLease(ctx context.Context, req *databroker.RenewLeaseReq
|
||||||
Dur("duration", req.GetDuration().AsDuration()).
|
Dur("duration", req.GetDuration().AsDuration()).
|
||||||
Msg("renew lease")
|
Msg("renew lease")
|
||||||
|
|
||||||
db, err := srv.getBackend()
|
db, err := srv.getBackend(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -325,7 +326,7 @@ func (srv *Server) SetOptions(ctx context.Context, req *databroker.SetOptionsReq
|
||||||
ctx, span := trace.StartSpan(ctx, "databroker.grpc.SetOptions")
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.SetOptions")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
backend, err := srv.getBackend()
|
backend, err := srv.getBackend(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -351,12 +352,13 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
log.Ctx(ctx).Debug().
|
log.Ctx(ctx).
|
||||||
|
Debug().
|
||||||
Uint64("server_version", req.GetServerVersion()).
|
Uint64("server_version", req.GetServerVersion()).
|
||||||
Uint64("record_version", req.GetRecordVersion()).
|
Uint64("record_version", req.GetRecordVersion()).
|
||||||
Msg("sync")
|
Msg("sync")
|
||||||
|
|
||||||
backend, err := srv.getBackend()
|
backend, err := srv.getBackend(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -392,7 +394,7 @@ func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databrok
|
||||||
Str("type", req.GetType()).
|
Str("type", req.GetType()).
|
||||||
Msg("sync latest")
|
Msg("sync latest")
|
||||||
|
|
||||||
backend, err := srv.getBackend()
|
backend, err := srv.getBackend(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -430,7 +432,7 @@ func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databrok
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) getBackend() (backend storage.Backend, err error) {
|
func (srv *Server) getBackend(ctx context.Context) (backend storage.Backend, err error) {
|
||||||
// double-checked locking:
|
// double-checked locking:
|
||||||
// first try the read lock, then re-try with the write lock, and finally create a new backend if nil
|
// first try the read lock, then re-try with the write lock, and finally create a new backend if nil
|
||||||
srv.mu.RLock()
|
srv.mu.RLock()
|
||||||
|
@ -441,7 +443,7 @@ func (srv *Server) getBackend() (backend storage.Backend, err error) {
|
||||||
backend = srv.backend
|
backend = srv.backend
|
||||||
var err error
|
var err error
|
||||||
if backend == nil {
|
if backend == nil {
|
||||||
backend, err = srv.newBackendLocked()
|
backend, err = srv.newBackendLocked(ctx)
|
||||||
srv.backend = backend
|
srv.backend = backend
|
||||||
}
|
}
|
||||||
srv.mu.Unlock()
|
srv.mu.Unlock()
|
||||||
|
@ -452,18 +454,18 @@ func (srv *Server) getBackend() (backend storage.Backend, err error) {
|
||||||
return backend, nil
|
return backend, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
|
func (srv *Server) newBackendLocked(ctx context.Context) (storage.Backend, error) {
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
switch srv.cfg.storageType {
|
switch srv.cfg.storageType {
|
||||||
case config.StorageInMemoryName:
|
case config.StorageInMemoryName:
|
||||||
log.Ctx(ctx).Info().Msg("using in-memory store")
|
log.Ctx(ctx).Info().Msg("initializing new in-memory store")
|
||||||
return inmemory.New(), nil
|
return inmemory.New(), nil
|
||||||
case config.StoragePostgresName:
|
case config.StoragePostgresName:
|
||||||
log.Ctx(ctx).Info().Msg("using postgres store")
|
log.Ctx(ctx).Info().Msg("initializing new postgres store")
|
||||||
backend = postgres.New(srv.cfg.storageConnectionString)
|
// NB: the context passed to postgres.New here is a separate context scoped
|
||||||
|
// to the lifetime of the server itself. 'ctx' may be a short-lived request
|
||||||
|
// context, since the backend is lazy-initialized.
|
||||||
|
return postgres.New(srv.backendCtx, srv.cfg.storageConnectionString), nil
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
|
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
|
||||||
}
|
}
|
||||||
return backend, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,8 @@ func (h testSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint
|
||||||
|
|
||||||
func newServer(cfg *serverConfig) *Server {
|
func newServer(cfg *serverConfig) *Server {
|
||||||
return &Server{
|
return &Server{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
backendCtx: context.Background(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -277,7 +278,7 @@ func TestServer_Sync(t *testing.T) {
|
||||||
updateRecords := make(chan uint64, 10)
|
updateRecords := make(chan uint64, 10)
|
||||||
|
|
||||||
client := databroker.NewDataBrokerServiceClient(cc)
|
client := databroker.NewDataBrokerServiceClient(cc)
|
||||||
syncer := databroker.NewSyncer("TEST", testSyncerHandler{
|
syncer := databroker.NewSyncer(ctx, "TEST", testSyncerHandler{
|
||||||
getDataBrokerServiceClient: func() databroker.DataBrokerServiceClient {
|
getDataBrokerServiceClient: func() databroker.DataBrokerServiceClient {
|
||||||
return client
|
return client
|
||||||
},
|
},
|
||||||
|
@ -292,12 +293,12 @@ func TestServer_Sync(t *testing.T) {
|
||||||
select {
|
select {
|
||||||
case <-clearRecords:
|
case <-clearRecords:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-updateRecords:
|
case <-updateRecords:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -313,7 +314,7 @@ func TestServer_Sync(t *testing.T) {
|
||||||
select {
|
select {
|
||||||
case <-updateRecords:
|
case <-updateRecords:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -45,7 +45,7 @@ func TestEnabler(t *testing.T) {
|
||||||
started.Add(1)
|
started.Add(1)
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
stopped.Add(1)
|
stopped.Add(1)
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
}), true)
|
}), true)
|
||||||
time.AfterFunc(time.Millisecond*10, e.Disable)
|
time.AfterFunc(time.Millisecond*10, e.Disable)
|
||||||
go e.Run(ctx)
|
go e.Run(ctx)
|
||||||
|
|
|
@ -80,6 +80,9 @@ func (watcher *Watcher) initLocked(ctx context.Context) {
|
||||||
|
|
||||||
if watcher.pollingWatcher == nil {
|
if watcher.pollingWatcher == nil {
|
||||||
watcher.pollingWatcher = filenotify.NewPollingWatcher(nil)
|
watcher.pollingWatcher = filenotify.NewPollingWatcher(nil)
|
||||||
|
context.AfterFunc(ctx, func() {
|
||||||
|
watcher.pollingWatcher.Close()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
errors := watcher.pollingWatcher.Errors()
|
errors := watcher.pollingWatcher.Errors()
|
||||||
|
|
10
internal/log/debug.go
Normal file
10
internal/log/debug.go
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
package log
|
||||||
|
|
||||||
|
import "sync/atomic"
|
||||||
|
|
||||||
|
var (
|
||||||
|
// Debug option to disable the Zap log shim
|
||||||
|
DebugDisableZapLogger atomic.Bool
|
||||||
|
// Debug option to suppress global warnings
|
||||||
|
DebugDisableGlobalWarnings atomic.Bool
|
||||||
|
)
|
|
@ -52,6 +52,9 @@ func Logger() *zerolog.Logger {
|
||||||
|
|
||||||
// ZapLogger returns the global zap logger.
|
// ZapLogger returns the global zap logger.
|
||||||
func ZapLogger() *zap.Logger {
|
func ZapLogger() *zap.Logger {
|
||||||
|
if DebugDisableZapLogger.Load() {
|
||||||
|
return zap.NewNop()
|
||||||
|
}
|
||||||
return zapLogger.Load()
|
return zapLogger.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,14 @@
|
||||||
package log
|
package log
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/protoutil/streams"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/middleware/responsewriter"
|
"github.com/pomerium/pomerium/internal/middleware/responsewriter"
|
||||||
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
||||||
|
@ -121,3 +124,17 @@ func HeadersHandler(headers []string) func(next http.Handler) http.Handler {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func StreamServerInterceptor(lg *zerolog.Logger) grpc.StreamServerInterceptor {
|
||||||
|
return func(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||||
|
s := streams.NewServerStreamWithContext(ss)
|
||||||
|
s.SetContext(lg.WithContext(s.Ctx))
|
||||||
|
return handler(srv, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnaryServerInterceptor(lg *zerolog.Logger) grpc.UnaryServerInterceptor {
|
||||||
|
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||||
|
return handler(lg.WithContext(ctx), req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -11,6 +11,9 @@ var warnCookieSecretOnce sync.Once
|
||||||
// WarnCookieSecret warns about the cookie secret.
|
// WarnCookieSecret warns about the cookie secret.
|
||||||
func WarnCookieSecret() {
|
func WarnCookieSecret() {
|
||||||
warnCookieSecretOnce.Do(func() {
|
warnCookieSecretOnce.Do(func() {
|
||||||
|
if DebugDisableGlobalWarnings.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
Info().
|
Info().
|
||||||
Msg("using a generated COOKIE_SECRET. " +
|
Msg("using a generated COOKIE_SECRET. " +
|
||||||
"Set the COOKIE_SECRET to avoid users being logged out on restart. " +
|
"Set the COOKIE_SECRET to avoid users being logged out on restart. " +
|
||||||
|
@ -23,6 +26,9 @@ var warnNoTLSCertificateOnce syncutil.OnceMap[string]
|
||||||
// WarnNoTLSCertificate warns about no TLS certificate.
|
// WarnNoTLSCertificate warns about no TLS certificate.
|
||||||
func WarnNoTLSCertificate(domain string) {
|
func WarnNoTLSCertificate(domain string) {
|
||||||
warnNoTLSCertificateOnce.Do(domain, func() {
|
warnNoTLSCertificateOnce.Do(domain, func() {
|
||||||
|
if DebugDisableGlobalWarnings.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
Info().
|
Info().
|
||||||
Str("domain", domain).
|
Str("domain", domain).
|
||||||
Msg("no TLS certificate found for domain, using a self-signed certificate")
|
Msg("no TLS certificate found for domain, using a self-signed certificate")
|
||||||
|
@ -34,6 +40,9 @@ var warnWebSocketHTTP1_1Once syncutil.OnceMap[string]
|
||||||
// WarnWebSocketHTTP1_1 warns about falling back to http 1.1 due to web socket support.
|
// WarnWebSocketHTTP1_1 warns about falling back to http 1.1 due to web socket support.
|
||||||
func WarnWebSocketHTTP1_1(clusterID string) {
|
func WarnWebSocketHTTP1_1(clusterID string) {
|
||||||
warnWebSocketHTTP1_1Once.Do(clusterID, func() {
|
warnWebSocketHTTP1_1Once.Do(clusterID, func() {
|
||||||
|
if DebugDisableGlobalWarnings.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
Info().
|
Info().
|
||||||
Str("cluster-id", clusterID).
|
Str("cluster-id", clusterID).
|
||||||
Msg("forcing http/1.1 due to web socket support")
|
Msg("forcing http/1.1 due to web socket support")
|
||||||
|
|
|
@ -103,7 +103,7 @@ func makeSelect(
|
||||||
fn: func(ctx context.Context) error {
|
fn: func(ctx context.Context) error {
|
||||||
// unreachable, the context handler will never be called
|
// unreachable, the context handler will never be called
|
||||||
// as its channel can only be closed
|
// as its channel can only be closed
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
},
|
},
|
||||||
ch: reflect.ValueOf(ctx.Done()),
|
ch: reflect.ValueOf(ctx.Done()),
|
||||||
},
|
},
|
||||||
|
|
|
@ -22,7 +22,7 @@ func WaitForHealthy(ctx context.Context, client *http.Client, routes []*config.R
|
||||||
healthy++
|
healthy++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkHealth(ctx context.Context, client *http.Client, addr string) error {
|
func checkHealth(ctx context.Context, client *http.Client, addr string) error {
|
||||||
|
|
|
@ -74,7 +74,7 @@ func (svc *Source) updateLoop(ctx context.Context) error {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case <-svc.checkForUpdate:
|
case <-svc.checkForUpdate:
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,7 +34,7 @@ type source struct {
|
||||||
func (src *source) WaitReady(ctx context.Context) error {
|
func (src *source) WaitReady(ctx context.Context) error {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case <-src.ready:
|
case <-src.ready:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ func BuildImportCmd() *cobra.Command {
|
||||||
return fmt.Errorf("no config file provided")
|
return fmt.Errorf("no config file provided")
|
||||||
}
|
}
|
||||||
log.SetLevel(zerolog.ErrorLevel)
|
log.SetLevel(zerolog.ErrorLevel)
|
||||||
src, err := config.NewFileOrEnvironmentSource(configFile, files.FullVersion())
|
src, err := config.NewFileOrEnvironmentSource(cmd.Context(), configFile, files.FullVersion())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,7 @@ import (
|
||||||
func (svc *Mux) Watch(ctx context.Context, opts ...WatchOption) error {
|
func (svc *Mux) Watch(ctx context.Context, opts ...WatchOption) error {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case <-svc.ready:
|
case <-svc.ready:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -102,7 +102,7 @@ func TestDatabrokerRestart(t *testing.T) {
|
||||||
cl(context.Background(), newConfig())
|
cl(context.Background(), newConfig())
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
require.ErrorIs(t, context.Cause(ctx), controller.ErrBootstrapConfigurationChanged)
|
require.ErrorIs(t, context.Cause(ctx), controller.ErrBootstrapConfigurationChanged)
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
|
@ -56,5 +56,5 @@ func (w *LeaseStatus) MonitorLease(ctx context.Context, _ databroker.DataBrokerS
|
||||||
w.v.Store(true)
|
w.v.Store(true)
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
w.v.Store(false)
|
w.v.Store(false)
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ func TestUsageReporter(t *testing.T) {
|
||||||
t.Cleanup(cancel)
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
|
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
|
||||||
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New())
|
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx))
|
||||||
})
|
})
|
||||||
t.Cleanup(func() { cc.Close() })
|
t.Cleanup(func() { cc.Close() })
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Checker) ConfigSyncer(ctx context.Context) error {
|
func (c *Checker) ConfigSyncer(ctx context.Context) error {
|
||||||
syncer := databroker.NewSyncer("zero-health-check", c, databroker.WithTypeURL(protoutil.GetTypeURL(new(configpb.Config))))
|
syncer := databroker.NewSyncer(ctx, "zero-health-check", c, databroker.WithTypeURL(protoutil.GetTypeURL(new(configpb.Config))))
|
||||||
return syncer.Run(ctx)
|
return syncer.Run(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ func (c *service) SyncLoop(ctx context.Context) error {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case <-c.bundleSyncRequest:
|
case <-c.bundleSyncRequest:
|
||||||
log.Ctx(ctx).Debug().Msg("bundle sync triggered")
|
log.Ctx(ctx).Debug().Msg("bundle sync triggered")
|
||||||
err := c.syncBundles(ctx)
|
err := c.syncBundles(ctx)
|
||||||
|
|
|
@ -105,7 +105,7 @@ func (srv *Telemetry) handleRequests(ctx context.Context) error {
|
||||||
case req := <-requests:
|
case req := <-requests:
|
||||||
srv.handleRequest(ctx, req)
|
srv.handleRequest(ctx, req)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -32,7 +32,7 @@ import (
|
||||||
|
|
||||||
// Run runs the main pomerium application.
|
// Run runs the main pomerium application.
|
||||||
func Run(ctx context.Context, src config.Source) error {
|
func Run(ctx context.Context, src config.Source) error {
|
||||||
_, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Debug().Msgf(s, i...) }))
|
_, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Ctx(ctx).Debug().Msgf(s, i...) }))
|
||||||
|
|
||||||
evt := log.Ctx(ctx).Info().
|
evt := log.Ctx(ctx).Info().
|
||||||
Str("envoy_version", files.FullVersion()).
|
Str("envoy_version", files.FullVersion()).
|
||||||
|
@ -53,7 +53,7 @@ func Run(ctx context.Context, src config.Source) error {
|
||||||
// trigger changes when underlying files are changed
|
// trigger changes when underlying files are changed
|
||||||
src = config.NewFileWatcherSource(ctx, src)
|
src = config.NewFileWatcherSource(ctx, src)
|
||||||
|
|
||||||
src, err = autocert.New(src)
|
src, err = autocert.New(ctx, src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -71,7 +71,7 @@ func Run(ctx context.Context, src config.Source) error {
|
||||||
cfg := src.GetConfig()
|
cfg := src.GetConfig()
|
||||||
|
|
||||||
// setup the control plane
|
// setup the control plane
|
||||||
controlPlane, err := controlplane.NewServer(cfg, metricsMgr, eventsMgr)
|
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating control plane: %w", err)
|
return fmt.Errorf("error creating control plane: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -166,11 +166,11 @@ func setupAuthenticate(ctx context.Context, src config.Source, controlPlane *con
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
svc, err := authenticate.New(src.GetConfig())
|
svc, err := authenticate.New(ctx, src.GetConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating authenticate service: %w", err)
|
return fmt.Errorf("error creating authenticate service: %w", err)
|
||||||
}
|
}
|
||||||
err = controlPlane.EnableAuthenticate(svc)
|
err = controlPlane.EnableAuthenticate(ctx, svc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error adding authenticate service to control plane: %w", err)
|
return fmt.Errorf("error adding authenticate service to control plane: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -183,7 +183,7 @@ func setupAuthenticate(ctx context.Context, src config.Source, controlPlane *con
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupAuthorize(ctx context.Context, src config.Source, controlPlane *controlplane.Server) (*authorize.Authorize, error) {
|
func setupAuthorize(ctx context.Context, src config.Source, controlPlane *controlplane.Server) (*authorize.Authorize, error) {
|
||||||
svc, err := authorize.New(src.GetConfig())
|
svc, err := authorize.New(ctx, src.GetConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating authorize service: %w", err)
|
return nil, fmt.Errorf("error creating authorize service: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -200,7 +200,7 @@ func setupDataBroker(ctx context.Context,
|
||||||
controlPlane *controlplane.Server,
|
controlPlane *controlplane.Server,
|
||||||
eventsMgr *events.Manager,
|
eventsMgr *events.Manager,
|
||||||
) (*databroker_service.DataBroker, error) {
|
) (*databroker_service.DataBroker, error) {
|
||||||
svc, err := databroker_service.New(src.GetConfig(), eventsMgr)
|
svc, err := databroker_service.New(ctx, src.GetConfig(), eventsMgr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating databroker service: %w", err)
|
return nil, fmt.Errorf("error creating databroker service: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -223,11 +223,11 @@ func setupProxy(ctx context.Context, src config.Source, controlPlane *controlpla
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
svc, err := proxy.New(src.GetConfig())
|
svc, err := proxy.New(ctx, src.GetConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating proxy service: %w", err)
|
return fmt.Errorf("error creating proxy service: %w", err)
|
||||||
}
|
}
|
||||||
err = controlPlane.EnableProxy(svc)
|
err = controlPlane.EnableProxy(ctx, svc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error adding proxy service to control plane: %w", err)
|
return fmt.Errorf("error adding proxy service to control plane: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -185,7 +185,7 @@ func (srv *Server) run(ctx context.Context, cfg *config.Config) error {
|
||||||
|
|
||||||
// monitor the process so we exit if it prematurely exits
|
// monitor the process so we exit if it prematurely exits
|
||||||
var monitorProcessCtx context.Context
|
var monitorProcessCtx context.Context
|
||||||
monitorProcessCtx, srv.monitorProcessCancel = context.WithCancel(context.Background())
|
monitorProcessCtx, srv.monitorProcessCancel = context.WithCancel(context.WithoutCancel(ctx))
|
||||||
go srv.monitorProcess(monitorProcessCtx, int32(cmd.Process.Pid))
|
go srv.monitorProcess(monitorProcessCtx, int32(cmd.Process.Pid))
|
||||||
|
|
||||||
if srv.resourceMonitor != nil {
|
if srv.resourceMonitor != nil {
|
||||||
|
@ -251,7 +251,7 @@ func (srv *Server) parseLog(line string) (name string, logLevel string, msg stri
|
||||||
func (srv *Server) handleLogs(ctx context.Context, rc io.ReadCloser) {
|
func (srv *Server) handleLogs(ctx context.Context, rc io.ReadCloser) {
|
||||||
defer rc.Close()
|
defer rc.Close()
|
||||||
|
|
||||||
l := log.With().Str("service", "envoy").Logger()
|
l := log.Ctx(ctx).With().Str("service", "envoy").Logger()
|
||||||
bo := backoff.NewExponentialBackOff()
|
bo := backoff.NewExponentialBackOff()
|
||||||
|
|
||||||
s := bufio.NewReader(rc)
|
s := bufio.NewReader(rc)
|
||||||
|
|
|
@ -10,7 +10,7 @@ func (f *FanOut[T]) Publish(ctx context.Context, msg T) error {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case <-f.done:
|
case <-f.done:
|
||||||
return ErrStopped
|
return ErrStopped
|
||||||
case f.messages <- msg:
|
case f.messages <- msg:
|
||||||
|
|
|
@ -141,7 +141,7 @@ func WaitForReady(ctx context.Context, cc *grpc.ClientConn, timeout time.Duratio
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -71,13 +71,13 @@ func (locker *Leaser) Run(ctx context.Context) error {
|
||||||
case err == nil:
|
case err == nil:
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case <-retryTicker.C:
|
case <-retryTicker.C:
|
||||||
}
|
}
|
||||||
case errors.Is(err, retryableError{}):
|
case errors.Is(err, retryableError{}):
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case <-time.After(bo.NextBackOff()):
|
case <-time.After(bo.NextBackOff()):
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -169,7 +169,7 @@ func TestLeasers(t *testing.T) {
|
||||||
fn2 := func(ctx context.Context) error {
|
fn2 := func(ctx context.Context) error {
|
||||||
atomic.AddInt64(&counter, 10)
|
atomic.AddInt64(&counter, 10)
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
}
|
}
|
||||||
leaser := databroker.NewLeasers("TEST", time.Second*30, client, fn1, fn2)
|
leaser := databroker.NewLeasers("TEST", time.Second*30, client, fn1, fn2)
|
||||||
err := leaser.Run(context.Background())
|
err := leaser.Run(context.Background())
|
||||||
|
|
|
@ -110,7 +110,7 @@ func (r *Reconciler) reconcileLoop(ctx context.Context) error {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case <-r.trigger:
|
case <-r.trigger:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,7 +24,7 @@ func Test_SyncLatestRecords(t *testing.T) {
|
||||||
defer clearTimeout()
|
defer clearTimeout()
|
||||||
|
|
||||||
cc := testutil.NewGRPCServer(t, func(s *grpc.Server) {
|
cc := testutil.NewGRPCServer(t, func(s *grpc.Server) {
|
||||||
databrokerpb.RegisterDataBrokerServiceServer(s, databroker.New())
|
databrokerpb.RegisterDataBrokerServiceServer(s, databroker.New(ctx))
|
||||||
})
|
})
|
||||||
|
|
||||||
c := databrokerpb.NewDataBrokerServiceClient(cc)
|
c := databrokerpb.NewDataBrokerServiceClient(cc)
|
||||||
|
|
|
@ -72,8 +72,8 @@ type Syncer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSyncer creates a new Syncer.
|
// NewSyncer creates a new Syncer.
|
||||||
func NewSyncer(id string, handler SyncerHandler, options ...SyncerOption) *Syncer {
|
func NewSyncer(ctx context.Context, id string, handler SyncerHandler, options ...SyncerOption) *Syncer {
|
||||||
closeCtx, closeCtxCancel := context.WithCancel(context.Background())
|
closeCtx, closeCtxCancel := context.WithCancel(context.WithoutCancel(ctx))
|
||||||
|
|
||||||
bo := backoff.NewExponentialBackOff()
|
bo := backoff.NewExponentialBackOff()
|
||||||
bo.MaxElapsedTime = 0
|
bo.MaxElapsedTime = 0
|
||||||
|
@ -120,7 +120,7 @@ func (syncer *Syncer) Run(ctx context.Context) error {
|
||||||
log.Ctx(ctx).Error().Err(err).Msg("sync")
|
log.Ctx(ctx).Error().Err(err).Msg("sync")
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case <-time.After(syncer.backoff.NextBackOff()):
|
case <-time.After(syncer.backoff.NextBackOff()):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -133,6 +133,9 @@ func (syncer *Syncer) init(ctx context.Context) error {
|
||||||
Type: syncer.cfg.typeURL,
|
Type: syncer.cfg.typeURL,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if status.Code(err) == codes.Canceled && ctx.Err() != nil {
|
||||||
|
err = fmt.Errorf("%w: %w", err, context.Cause(ctx))
|
||||||
|
}
|
||||||
return fmt.Errorf("error during initial sync: %w", err)
|
return fmt.Errorf("error during initial sync: %w", err)
|
||||||
}
|
}
|
||||||
syncer.backoff.Reset()
|
syncer.backoff.Reset()
|
||||||
|
@ -167,6 +170,9 @@ func (syncer *Syncer) sync(ctx context.Context) error {
|
||||||
syncer.serverVersion = 0
|
syncer.serverVersion = 0
|
||||||
return nil
|
return nil
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
|
if status.Code(err) == codes.Canceled && ctx.Err() != nil {
|
||||||
|
err = fmt.Errorf("%w: %w", err, context.Cause(ctx))
|
||||||
|
}
|
||||||
return fmt.Errorf("error receiving sync record: %w", err)
|
return fmt.Errorf("error receiving sync record: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -157,7 +157,7 @@ func TestSyncer(t *testing.T) {
|
||||||
|
|
||||||
clearCh := make(chan struct{})
|
clearCh := make(chan struct{})
|
||||||
updateCh := make(chan []*Record)
|
updateCh := make(chan []*Record)
|
||||||
syncer := NewSyncer("test", testSyncerHandler{
|
syncer := NewSyncer(ctx, "test", testSyncerHandler{
|
||||||
getDataBrokerServiceClient: func() DataBrokerServiceClient {
|
getDataBrokerServiceClient: func() DataBrokerServiceClient {
|
||||||
return NewDataBrokerServiceClient(gc)
|
return NewDataBrokerServiceClient(gc)
|
||||||
},
|
},
|
||||||
|
|
|
@ -124,13 +124,13 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
|
||||||
// wait for initial sync
|
// wait for initial sync
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case <-clear:
|
case <-clear:
|
||||||
mgr.reset()
|
mgr.reset()
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case msg := <-update:
|
case msg := <-update:
|
||||||
mgr.onUpdateRecords(ctx, msg)
|
mgr.onUpdateRecords(ctx, msg)
|
||||||
}
|
}
|
||||||
|
@ -150,7 +150,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case <-clear:
|
case <-clear:
|
||||||
mgr.reset()
|
mgr.reset()
|
||||||
case msg := <-update:
|
case msg := <-update:
|
||||||
|
|
|
@ -17,7 +17,7 @@ type dataBrokerSyncer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDataBrokerSyncer(
|
func newDataBrokerSyncer(
|
||||||
_ context.Context,
|
ctx context.Context,
|
||||||
cfg *atomicutil.Value[*config],
|
cfg *atomicutil.Value[*config],
|
||||||
update chan<- updateRecordsMessage,
|
update chan<- updateRecordsMessage,
|
||||||
clear chan<- struct{},
|
clear chan<- struct{},
|
||||||
|
@ -28,7 +28,7 @@ func newDataBrokerSyncer(
|
||||||
update: update,
|
update: update,
|
||||||
clear: clear,
|
clear: clear,
|
||||||
}
|
}
|
||||||
syncer.syncer = databroker.NewSyncer("identity_manager", syncer)
|
syncer.syncer = databroker.NewSyncer(ctx, "identity_manager", syncer)
|
||||||
return syncer
|
return syncer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ type sessionSyncerHandler struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSessionSyncer(ctx context.Context, mgr *Manager) *databroker.Syncer {
|
func newSessionSyncer(ctx context.Context, mgr *Manager) *databroker.Syncer {
|
||||||
return databroker.NewSyncer("identity_manager/sessions", sessionSyncerHandler{baseCtx: ctx, mgr: mgr},
|
return databroker.NewSyncer(ctx, "identity_manager/sessions", sessionSyncerHandler{baseCtx: ctx, mgr: mgr},
|
||||||
databroker.WithTypeURL(grpcutil.GetTypeURL(new(session.Session))))
|
databroker.WithTypeURL(grpcutil.GetTypeURL(new(session.Session))))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ type userSyncerHandler struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUserSyncer(ctx context.Context, mgr *Manager) *databroker.Syncer {
|
func newUserSyncer(ctx context.Context, mgr *Manager) *databroker.Syncer {
|
||||||
return databroker.NewSyncer("identity_manager/users", userSyncerHandler{baseCtx: ctx, mgr: mgr},
|
return databroker.NewSyncer(ctx, "identity_manager/users", userSyncerHandler{baseCtx: ctx, mgr: mgr},
|
||||||
databroker.WithTypeURL(grpcutil.GetTypeURL(new(user.User))))
|
databroker.WithTypeURL(grpcutil.GetTypeURL(new(user.User))))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,14 +37,14 @@ type Backend struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Backend.
|
// New creates a new Backend.
|
||||||
func New(dsn string, options ...Option) *Backend {
|
func New(ctx context.Context, dsn string, options ...Option) *Backend {
|
||||||
backend := &Backend{
|
backend := &Backend{
|
||||||
cfg: getConfig(options...),
|
cfg: getConfig(options...),
|
||||||
dsn: dsn,
|
dsn: dsn,
|
||||||
onRecordChange: signal.New(),
|
onRecordChange: signal.New(),
|
||||||
onServiceChange: signal.New(),
|
onServiceChange: signal.New(),
|
||||||
}
|
}
|
||||||
backend.closeCtx, backend.close = context.WithCancel(context.Background())
|
backend.closeCtx, backend.close = context.WithCancel(ctx)
|
||||||
|
|
||||||
go backend.doPeriodically(func(ctx context.Context) error {
|
go backend.doPeriodically(func(ctx context.Context) error {
|
||||||
_, pool, err := backend.init(ctx)
|
_, pool, err := backend.init(ctx)
|
||||||
|
|
|
@ -34,7 +34,7 @@ func TestBackend(t *testing.T) {
|
||||||
defer clearTimeout()
|
defer clearTimeout()
|
||||||
|
|
||||||
require.NoError(t, testutil.WithTestPostgres(func(dsn string) error {
|
require.NoError(t, testutil.WithTestPostgres(func(dsn string) error {
|
||||||
backend := New(dsn)
|
backend := New(ctx, dsn)
|
||||||
defer backend.Close()
|
defer backend.Close()
|
||||||
|
|
||||||
t.Run("put", func(t *testing.T) {
|
t.Run("put", func(t *testing.T) {
|
||||||
|
|
|
@ -42,7 +42,7 @@ func TestRegistry(t *testing.T) {
|
||||||
defer clearTimeout()
|
defer clearTimeout()
|
||||||
|
|
||||||
require.NoError(t, testutil.WithTestPostgres(func(dsn string) error {
|
require.NoError(t, testutil.WithTestPostgres(func(dsn string) error {
|
||||||
backend := New(dsn)
|
backend := New(ctx, dsn)
|
||||||
defer backend.Close()
|
defer backend.Close()
|
||||||
|
|
||||||
eg, ctx := errgroup.WithContext(ctx)
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
|
@ -53,7 +53,7 @@ func TestRegistry(t *testing.T) {
|
||||||
send: func(res *registry.ServiceList) error {
|
send: func(res *registry.ServiceList) error {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case listResults <- res:
|
case listResults <- res:
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -73,7 +73,7 @@ func TestRegistry(t *testing.T) {
|
||||||
eg.Go(func() error {
|
eg.Go(func() error {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case res := <-listResults:
|
case res := <-listResults:
|
||||||
testutil.AssertProtoEqual(t, ®istry.ServiceList{}, res)
|
testutil.AssertProtoEqual(t, ®istry.ServiceList{}, res)
|
||||||
}
|
}
|
||||||
|
@ -92,7 +92,7 @@ func TestRegistry(t *testing.T) {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return context.Cause(ctx)
|
||||||
case res := <-listResults:
|
case res := <-listResults:
|
||||||
testutil.AssertProtoEqual(t, ®istry.ServiceList{
|
testutil.AssertProtoEqual(t, ®istry.ServiceList{
|
||||||
Services: []*registry.Service{
|
Services: []*registry.Service{
|
||||||
|
|
|
@ -32,14 +32,14 @@ func Test_getUserInfoData(t *testing.T) {
|
||||||
defer clearTimeout()
|
defer clearTimeout()
|
||||||
|
|
||||||
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
|
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
|
||||||
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New())
|
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New(ctx))
|
||||||
})
|
})
|
||||||
t.Cleanup(func() { cc.Close() })
|
t.Cleanup(func() { cc.Close() })
|
||||||
|
|
||||||
client := databrokerpb.NewDataBrokerServiceClient(cc)
|
client := databrokerpb.NewDataBrokerServiceClient(cc)
|
||||||
|
|
||||||
opts := testOptions(t)
|
opts := testOptions(t)
|
||||||
proxy, err := New(&config.Config{Options: opts})
|
proxy, err := New(ctx, &config.Config{Options: opts})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
proxy.state.Load().dataBrokerClient = client
|
proxy.state.Load().dataBrokerClient = client
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -36,7 +37,7 @@ func TestProxy_SignOut(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
opts := testOptions(t)
|
opts := testOptions(t)
|
||||||
p, err := New(&config.Config{Options: opts})
|
p, err := New(context.Background(), &config.Config{Options: opts})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -129,7 +130,7 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
p, err := New(&config.Config{Options: tt.options})
|
p, err := New(context.Background(), &config.Config{Options: tt.options})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -270,7 +271,7 @@ func TestLoadSessionState(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
opts := testOptions(t)
|
opts := testOptions(t)
|
||||||
proxy, err := New(&config.Config{Options: opts})
|
proxy, err := New(context.Background(), &config.Config{Options: opts})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
|
r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil)
|
||||||
|
@ -285,7 +286,7 @@ func TestLoadSessionState(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
opts := testOptions(t)
|
opts := testOptions(t)
|
||||||
proxy, err := New(&config.Config{Options: opts})
|
proxy, err := New(context.Background(), &config.Config{Options: opts})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
session := encodeSession(t, opts, &sessions.State{
|
session := encodeSession(t, opts, &sessions.State{
|
||||||
|
@ -308,7 +309,7 @@ func TestLoadSessionState(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
opts := testOptions(t)
|
opts := testOptions(t)
|
||||||
proxy, err := New(&config.Config{Options: opts})
|
proxy, err := New(context.Background(), &config.Config{Options: opts})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
session := encodeSession(t, opts, &sessions.State{
|
session := encodeSession(t, opts, &sessions.State{
|
||||||
|
|
|
@ -60,8 +60,8 @@ type Proxy struct {
|
||||||
|
|
||||||
// New takes a Proxy service from options and a validation function.
|
// New takes a Proxy service from options and a validation function.
|
||||||
// Function returns an error if options fail to validate.
|
// Function returns an error if options fail to validate.
|
||||||
func New(cfg *config.Config) (*Proxy, error) {
|
func New(ctx context.Context, cfg *config.Config) (*Proxy, error) {
|
||||||
state, err := newProxyStateFromConfig(cfg)
|
state, err := newProxyStateFromConfig(ctx, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -71,7 +71,7 @@ func New(cfg *config.Config) (*Proxy, error) {
|
||||||
currentOptions: config.NewAtomicOptions(),
|
currentOptions: config.NewAtomicOptions(),
|
||||||
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
|
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
|
||||||
}
|
}
|
||||||
p.OnConfigChange(context.Background(), cfg)
|
p.OnConfigChange(ctx, cfg)
|
||||||
p.webauthn = webauthn.New(p.getWebauthnState)
|
p.webauthn = webauthn.New(p.getWebauthnState)
|
||||||
|
|
||||||
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
|
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
|
||||||
|
@ -87,25 +87,25 @@ func (p *Proxy) Mount(r *mux.Router) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnConfigChange updates internal structures based on config.Options
|
// OnConfigChange updates internal structures based on config.Options
|
||||||
func (p *Proxy) OnConfigChange(_ context.Context, cfg *config.Config) {
|
func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.currentOptions.Store(cfg.Options)
|
p.currentOptions.Store(cfg.Options)
|
||||||
if err := p.setHandlers(cfg.Options); err != nil {
|
if err := p.setHandlers(ctx, cfg.Options); err != nil {
|
||||||
log.Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
|
log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
|
||||||
}
|
}
|
||||||
if state, err := newProxyStateFromConfig(cfg); err != nil {
|
if state, err := newProxyStateFromConfig(ctx, cfg); err != nil {
|
||||||
log.Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings")
|
log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings")
|
||||||
} else {
|
} else {
|
||||||
p.state.Store(state)
|
p.state.Store(state)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) setHandlers(opts *config.Options) error {
|
func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error {
|
||||||
if opts.NumPolicies() == 0 {
|
if opts.NumPolicies() == 0 {
|
||||||
log.Info().Msg("proxy: configuration has no policies")
|
log.Ctx(ctx).Info().Msg("proxy: configuration has no policies")
|
||||||
}
|
}
|
||||||
r := httputil.NewRouter()
|
r := httputil.NewRouter()
|
||||||
r.NotFoundHandler = httputil.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) error {
|
r.NotFoundHandler = httputil.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) error {
|
||||||
|
|
|
@ -104,7 +104,7 @@ func TestNew(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := New(&config.Config{Options: tt.opts})
|
got, err := New(context.Background(), &config.Config{Options: tt.opts})
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -197,7 +197,7 @@ func Test_UpdateOptions(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
p, err := New(&config.Config{Options: tt.originalOptions})
|
p, err := New(context.Background(), &config.Config{Options: tt.originalOptions})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ type proxyState struct {
|
||||||
authenticateFlow authenticateFlow
|
authenticateFlow authenticateFlow
|
||||||
}
|
}
|
||||||
|
|
||||||
func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
|
func newProxyStateFromConfig(ctx context.Context, cfg *config.Config) (*proxyState, error) {
|
||||||
err := ValidateOptions(cfg.Options)
|
err := ValidateOptions(cfg.Options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -57,7 +57,7 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{
|
dataBrokerConn, err := outboundGRPCConnection.Get(ctx, &grpc.OutboundOptions{
|
||||||
OutboundPort: cfg.OutboundPort,
|
OutboundPort: cfg.OutboundPort,
|
||||||
InstallationID: cfg.Options.InstallationID,
|
InstallationID: cfg.Options.InstallationID,
|
||||||
ServiceName: cfg.Options.Services,
|
ServiceName: cfg.Options.Services,
|
||||||
|
@ -71,10 +71,10 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
|
||||||
state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist
|
state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist
|
||||||
|
|
||||||
if cfg.Options.UseStatelessAuthenticateFlow() {
|
if cfg.Options.UseStatelessAuthenticateFlow() {
|
||||||
state.authenticateFlow, err = authenticateflow.NewStateless(
|
state.authenticateFlow, err = authenticateflow.NewStateless(ctx,
|
||||||
cfg, state.sessionStore, nil, nil, nil)
|
cfg, state.sessionStore, nil, nil, nil)
|
||||||
} else {
|
} else {
|
||||||
state.authenticateFlow, err = authenticateflow.NewStateful(cfg, state.sessionStore)
|
state.authenticateFlow, err = authenticateflow.NewStateful(ctx, cfg, state.sessionStore)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
Loading…
Add table
Reference in a new issue