log context (#2107)

This commit is contained in:
wasaga 2021-04-22 10:58:13 -04:00 committed by GitHub
parent e7995954ff
commit e0c09a0998
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
87 changed files with 714 additions and 524 deletions

View file

@ -3,6 +3,7 @@
package authenticate
import (
"context"
"errors"
"fmt"
"html/template"
@ -76,19 +77,19 @@ func New(cfg *config.Config) (*Authenticate, error) {
}
// OnConfigChange updates internal structures based on config.Options
func (a *Authenticate) OnConfigChange(cfg *config.Config) {
func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) {
if a == nil {
return
}
a.options.Store(cfg.Options)
if state, err := newAuthenticateStateFromConfig(cfg); err != nil {
log.Error().Err(err).Msg("authenticate: failed to update state")
log.Error(ctx).Err(err).Msg("authenticate: failed to update state")
} else {
a.state.Store(state)
}
if err := a.updateProvider(cfg); err != nil {
log.Error().Err(err).Msg("authenticate: failed to update identity provider")
log.Error(ctx).Err(err).Msg("authenticate: failed to update identity provider")
}
}

View file

@ -275,7 +275,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
endSessionURL.RawQuery = params.Encode()
redirectString = endSessionURL.String()
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
log.Warn().Err(err).Msg("authenticate.SignOut: failed getting session")
log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session")
}
if redirectString != "" {
httputil.Redirect(w, r, redirectString, http.StatusFound)
@ -558,7 +558,7 @@ func (a *Authenticate) saveSessionToDataBroker(
AccessToken: accessToken.AccessToken,
})
if err != nil {
log.Error().Err(err).Msg("directory: failed to refresh user data")
log.Error(ctx).Err(err).Msg("directory: failed to refresh user data")
}
return nil

View file

@ -56,7 +56,7 @@ func (a *Authorize) WaitForInitialSync(ctx context.Context) error {
return ctx.Err()
case <-a.dataBrokerInitialSync:
}
log.Info().Msg("initial sync from databroker complete")
log.Info(ctx).Msg("initial sync from databroker complete")
return nil
}
@ -82,10 +82,10 @@ func newPolicyEvaluator(opts *config.Options, store *evaluator.Store) (*evaluato
}
// OnConfigChange updates internal structures based on config.Options
func (a *Authorize) OnConfigChange(cfg *config.Config) {
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
a.currentOptions.Store(cfg.Options)
if state, err := newAuthorizeStateFromConfig(cfg, a.store); err != nil {
log.Error().Err(err).Msg("authorize: error updating state")
log.Error(ctx).Err(err).Msg("authorize: error updating state")
} else {
a.state.Store(state)
}

View file

@ -1,6 +1,7 @@
package authorize
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
@ -116,7 +117,7 @@ func TestAuthorize_OnConfigChange(t *testing.T) {
o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUhHNHZDWlJxUFgwNGtmSFQxeVVDM1pUQkF6MFRYWkNtZ043clpDcFE3cHJvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFbzQzdjAwQlR4c3pKZWpmdHhBOWNtVGVUSmtQQXVtOGt1b0UwVlRUZnlId2k3SHJlN2FRUgpHQVJ6Nm0wMjVRdGFiRGxqeDd5MjIyY1gxblhCQXo3MlF3PT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
assertFunc = assert.False
}
a.OnConfigChange(cfg)
a.OnConfigChange(context.Background(), cfg)
assertFunc(t, oldPe == a.state.Load().evaluator)
})
}

View file

@ -2,6 +2,7 @@ package authorize
import (
"bytes"
"context"
"net/http"
"net/url"
"sort"
@ -103,7 +104,7 @@ func (a *Authorize) htmlDeniedResponse(
})
if err != nil {
buf.WriteString(reason)
log.Error().Err(err).Msg("error executing error template")
log.Error(context.TODO()).Err(err).Msg("error executing error template")
}
envoyHeaders := []*envoy_config_core_v3.HeaderValueOption{

View file

@ -162,7 +162,7 @@ func getJWK(options *config.Options) (*jose.JSONWebKey, error) {
if err != nil {
return nil, fmt.Errorf("couldn't generate signing key: %w", err)
}
log.Info().Str("Algorithm", jwk.Algorithm).
log.Info(context.TODO()).Str("Algorithm", jwk.Algorithm).
Str("KeyID", jwk.KeyID).
Interface("Public Key", jwk.Public()).
Msg("authorize: signing key")

View file

@ -1,6 +1,7 @@
package evaluator
import (
"context"
"crypto/x509"
"encoding/pem"
"fmt"
@ -46,7 +47,7 @@ func isValidClientCertificate(ca, cert string) (bool, error) {
valid := verifyErr == nil
if verifyErr != nil {
log.Debug().Err(verifyErr).Msg("client certificate failed verification: %w")
log.Debug(context.Background()).Err(verifyErr).Msg("client certificate failed verification: %w")
}
isValidClientCertificateCache.Add(cacheKey, valid)

View file

@ -1,6 +1,7 @@
package evaluator
import (
"context"
"fmt"
"strconv"
@ -71,7 +72,7 @@ func getDenyVar(vars rego.Vars) []Result {
status, err := strconv.Atoi(fmt.Sprint(denial[0]))
if err != nil {
log.Error().Err(err).Msg("invalid type in deny")
log.Error(context.TODO()).Err(err).Msg("invalid type in deny")
continue
}
msg := fmt.Sprint(denial[1])

View file

@ -156,11 +156,12 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
}
func (s *Store) write(rawPath string, value interface{}) {
err := storage.Txn(context.Background(), s.Store, storage.WriteParams, func(txn storage.Transaction) error {
ctx := context.TODO()
err := storage.Txn(ctx, s.Store, storage.WriteParams, func(txn storage.Transaction) error {
return s.writeTxn(txn, rawPath, value)
})
if err != nil {
log.Error().Err(err).Msg("opa-store: error writing data")
log.Error(ctx).Err(err).Msg("opa-store: error writing data")
return
}
}

View file

@ -46,19 +46,19 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
u, err := a.forceSync(ctx, sessionState)
if err != nil {
log.Warn().Err(err).Msg("clearing session due to force sync failed")
log.Warn(ctx).Err(err).Msg("clearing session due to force sync failed")
sessionState = nil
}
req, err := a.getEvaluatorRequestFromCheckRequest(in, sessionState)
if err != nil {
log.Warn().Err(err).Msg("error building evaluator request")
log.Warn(ctx).Err(err).Msg("error building evaluator request")
return nil, err
}
reply, err := state.evaluator.Evaluate(ctx, req)
if err != nil {
log.Error().Err(err).Msg("error during OPA evaluation")
log.Error(ctx).Err(err).Msg("error during OPA evaluation")
return nil, err
}
defer func() {

View file

@ -26,7 +26,7 @@ func (a *Authorize) logAuthorizeCheck(
hdrs := getCheckRequestHeaders(in)
hattrs := in.GetAttributes().GetRequest().GetHttp()
evt := log.Info().Str("service", "authorize")
evt := log.Info(ctx).Str("service", "authorize")
// request
evt = evt.Str("request-id", requestid.FromContext(ctx))
evt = evt.Str("check-request-id", hdrs["X-Request-Id"])
@ -66,10 +66,10 @@ func (a *Authorize) logAuthorizeCheck(
}
sealed, err := enc.Encrypt(record)
if err != nil {
log.Warn().Err(err).Msg("authorize: error encrypting audit record")
log.Warn(ctx).Err(err).Msg("authorize: error encrypting audit record")
return
}
log.Info().
log.Info(ctx).
Str("request-id", requestid.FromContext(ctx)).
EmbedObject(sealed).
Msg("audit log")

View file

@ -150,7 +150,7 @@ func (a *Authorize) waitForRecordSync(ctx context.Context, recordTypeURL, record
// record not found, so no need to wait
return nil, nil
} else if err != nil {
log.Error().
log.Error(ctx).
Err(err).
Str("type", recordTypeURL).
Str("id", recordID).
@ -160,7 +160,7 @@ func (a *Authorize) waitForRecordSync(ctx context.Context, recordTypeURL, record
select {
case <-ctx.Done():
log.Warn().
log.Warn(ctx).
Str("type", recordTypeURL).
Str("id", recordID).
Msg("authorize: first sync of record did not complete")

View file

@ -17,10 +17,11 @@ var (
)
func main() {
if err := run(context.Background()); !errors.Is(err, context.Canceled) {
ctx := context.Background()
if err := run(ctx); !errors.Is(err, context.Canceled) {
log.Fatal().Err(err).Msg("cmd/pomerium")
}
log.Info().Msg("cmd/pomerium: exiting")
log.Info(ctx).Msg("cmd/pomerium: exiting")
}
func run(ctx context.Context) error {

View file

@ -1,11 +1,14 @@
package config
import (
"context"
"crypto/sha256"
"io/ioutil"
"sync"
"github.com/fsnotify/fsnotify"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/internal/fileutil"
"github.com/pomerium/pomerium/internal/log"
@ -13,7 +16,7 @@ import (
)
// A ChangeListener is called when configuration changes.
type ChangeListener = func(*Config)
type ChangeListener = func(context.Context, *Config)
// A ChangeDispatcher manages listeners on config changes.
type ChangeDispatcher struct {
@ -22,17 +25,17 @@ type ChangeDispatcher struct {
}
// Trigger triggers a change.
func (dispatcher *ChangeDispatcher) Trigger(cfg *Config) {
func (dispatcher *ChangeDispatcher) Trigger(ctx context.Context, cfg *Config) {
dispatcher.Lock()
defer dispatcher.Unlock()
for _, li := range dispatcher.onConfigChangeListeners {
li(cfg)
li(ctx, cfg)
}
}
// OnConfigChange adds a listener.
func (dispatcher *ChangeDispatcher) OnConfigChange(li ChangeListener) {
func (dispatcher *ChangeDispatcher) OnConfigChange(ctx context.Context, li ChangeListener) {
dispatcher.Lock()
defer dispatcher.Unlock()
dispatcher.onConfigChangeListeners = append(dispatcher.onConfigChangeListeners, li)
@ -41,7 +44,7 @@ func (dispatcher *ChangeDispatcher) OnConfigChange(li ChangeListener) {
// A Source gets configuration.
type Source interface {
GetConfig() *Config
OnConfigChange(ChangeListener)
OnConfigChange(context.Context, ChangeListener)
}
// A StaticSource always returns the same config. Useful for testing.
@ -65,18 +68,18 @@ func (src *StaticSource) GetConfig() *Config {
}
// SetConfig sets the config.
func (src *StaticSource) SetConfig(cfg *Config) {
func (src *StaticSource) SetConfig(ctx context.Context, cfg *Config) {
src.mu.Lock()
defer src.mu.Unlock()
src.cfg = cfg
for _, li := range src.lis {
li(cfg)
li(ctx, cfg)
}
}
// OnConfigChange is ignored for the StaticSource.
func (src *StaticSource) OnConfigChange(li ChangeListener) {
func (src *StaticSource) OnConfigChange(ctx context.Context, li ChangeListener) {
src.mu.Lock()
defer src.mu.Unlock()
@ -95,38 +98,48 @@ type FileOrEnvironmentSource struct {
// NewFileOrEnvironmentSource creates a new FileOrEnvironmentSource.
func NewFileOrEnvironmentSource(configFile string) (*FileOrEnvironmentSource, error) {
ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context {
return c.Str("config_file_source", configFile)
})
options, err := newOptionsFromConfig(configFile)
if err != nil {
return nil, err
}
cfg := &Config{Options: options}
metrics.SetConfigInfo(cfg.Options.Services, "local", cfg.Checksum(), true)
metrics.SetConfigInfo(ctx, cfg.Options.Services, "local", cfg.Checksum(), true)
src := &FileOrEnvironmentSource{
configFile: configFile,
config: cfg,
}
options.viper.OnConfigChange(src.onConfigChange)
options.viper.OnConfigChange(src.onConfigChange(ctx))
go options.viper.WatchConfig()
return src, nil
}
func (src *FileOrEnvironmentSource) onConfigChange(evt fsnotify.Event) {
src.mu.Lock()
cfg := src.config
options, err := newOptionsFromConfig(src.configFile)
if err == nil {
cfg = &Config{Options: options}
metrics.SetConfigInfo(cfg.Options.Services, "local", cfg.Checksum(), true)
} else {
log.Error().Err(err).Msg("config: error updating config")
metrics.SetConfigInfo(cfg.Options.Services, "local", cfg.Checksum(), false)
}
src.mu.Unlock()
func (src *FileOrEnvironmentSource) onConfigChange(ctx context.Context) func(fsnotify.Event) {
return func(evt fsnotify.Event) {
ctx := log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("config_change_id", uuid.New().String())
})
log.Info(ctx).Msg("config: file updated, reconfiguring...")
src.mu.Lock()
cfg := src.config
options, err := newOptionsFromConfig(src.configFile)
if err == nil {
cfg = &Config{Options: options}
metrics.SetConfigInfo(ctx, cfg.Options.Services, "local", cfg.Checksum(), true)
} else {
log.Error(ctx).Err(err).Msg("config: error updating config")
metrics.SetConfigInfo(ctx, cfg.Options.Services, "local", cfg.Checksum(), false)
}
src.mu.Unlock()
src.Trigger(cfg)
src.Trigger(ctx, cfg)
}
}
// GetConfig gets the config.
@ -148,7 +161,7 @@ type FileWatcherSource struct {
ChangeDispatcher
}
// NewFileWatcherSource creates a new FileWatcherSource.
// NewFileWatcherSource creates a new FileWatcherSource
func NewFileWatcherSource(underlying Source) *FileWatcherSource {
src := &FileWatcherSource{
underlying: underlying,
@ -158,13 +171,13 @@ func NewFileWatcherSource(underlying Source) *FileWatcherSource {
ch := src.watcher.Bind()
go func() {
for range ch {
src.check(underlying.GetConfig())
src.check(context.TODO(), underlying.GetConfig())
}
}()
underlying.OnConfigChange(func(cfg *Config) {
src.check(cfg)
underlying.OnConfigChange(context.TODO(), func(ctx context.Context, cfg *Config) {
src.check(ctx, cfg)
})
src.check(underlying.GetConfig())
src.check(context.TODO(), underlying.GetConfig())
return src
}
@ -176,7 +189,7 @@ func (src *FileWatcherSource) GetConfig() *Config {
return src.computedConfig
}
func (src *FileWatcherSource) check(cfg *Config) {
func (src *FileWatcherSource) check(ctx context.Context, cfg *Config) {
if cfg == nil || cfg.Options == nil {
return
}
@ -218,5 +231,5 @@ func (src *FileWatcherSource) check(cfg *Config) {
src.computedConfig = cfg.Clone()
// trigger a change
src.Trigger(src.computedConfig)
src.Trigger(ctx, src.computedConfig)
}

View file

@ -1,6 +1,7 @@
package config
import (
"context"
"io/ioutil"
"os"
"path/filepath"
@ -13,6 +14,8 @@ import (
)
func TestFileWatcherSource(t *testing.T) {
ctx := context.Background()
tmpdir := filepath.Join(os.TempDir(), uuid.New().String())
err := os.MkdirAll(tmpdir, 0o755)
if !assert.NoError(t, err) {
@ -33,7 +36,7 @@ func TestFileWatcherSource(t *testing.T) {
src := NewFileWatcherSource(ssrc)
var closeOnce sync.Once
ch := make(chan struct{})
src.OnConfigChange(func(cfg *Config) {
src.OnConfigChange(context.Background(), func(ctx context.Context, cfg *Config) {
closeOnce.Do(func() {
close(ch)
})
@ -50,7 +53,7 @@ func TestFileWatcherSource(t *testing.T) {
t.Error("expected OnConfigChange to be fired after modifying a file")
}
ssrc.SetConfig(&Config{
ssrc.SetConfig(ctx, &Config{
Options: &Options{
CAFile: filepath.Join(tmpdir, "example.txt"),
},

View file

@ -1,6 +1,7 @@
package envoyconfig
import (
"context"
"encoding/base64"
"fmt"
"net"
@ -22,7 +23,7 @@ import (
)
// BuildClusters builds envoy clusters from the given config.
func (b *Builder) BuildClusters(cfg *config.Config) ([]*envoy_config_cluster_v3.Cluster, error) {
func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*envoy_config_cluster_v3.Cluster, error) {
grpcURL := &url.URL{
Scheme: "http",
Host: b.localGRPCAddress,
@ -36,15 +37,15 @@ func (b *Builder) BuildClusters(cfg *config.Config) ([]*envoy_config_cluster_v3.
return nil, err
}
controlGRPC, err := b.buildInternalCluster(cfg.Options, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, true)
controlGRPC, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, true)
if err != nil {
return nil, err
}
controlHTTP, err := b.buildInternalCluster(cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, false)
controlHTTP, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, false)
if err != nil {
return nil, err
}
authZ, err := b.buildInternalCluster(cfg.Options, "pomerium-authorize", authzURLs, true)
authZ, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-authorize", authzURLs, true)
if err != nil {
return nil, err
}
@ -62,7 +63,7 @@ func (b *Builder) BuildClusters(cfg *config.Config) ([]*envoy_config_cluster_v3.
policy.EnvoyOpts = newDefaultEnvoyClusterConfig()
}
if len(policy.To) > 0 {
cluster, err := b.buildPolicyCluster(cfg.Options, &policy)
cluster, err := b.buildPolicyCluster(ctx, cfg.Options, &policy)
if err != nil {
return nil, fmt.Errorf("policy #%d: %w", i, err)
}
@ -78,12 +79,18 @@ func (b *Builder) BuildClusters(cfg *config.Config) ([]*envoy_config_cluster_v3.
return clusters, nil
}
func (b *Builder) buildInternalCluster(options *config.Options, name string, dsts []*url.URL, forceHTTP2 bool) (*envoy_config_cluster_v3.Cluster, error) {
func (b *Builder) buildInternalCluster(
ctx context.Context,
options *config.Options,
name string,
dsts []*url.URL,
forceHTTP2 bool,
) (*envoy_config_cluster_v3.Cluster, error) {
cluster := newDefaultEnvoyClusterConfig()
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
var endpoints []Endpoint
for _, dst := range dsts {
ts, err := b.buildInternalTransportSocket(options, dst)
ts, err := b.buildInternalTransportSocket(ctx, options, dst)
if err != nil {
return nil, err
}
@ -95,14 +102,14 @@ func (b *Builder) buildInternalCluster(options *config.Options, name string, dst
return cluster, nil
}
func (b *Builder) buildPolicyCluster(options *config.Options, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) {
func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Options, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) {
cluster := new(envoy_config_cluster_v3.Cluster)
proto.Merge(cluster, policy.EnvoyOpts)
cluster.AltStatName = getClusterStatsName(policy)
name := getClusterID(policy)
endpoints, err := b.buildPolicyEndpoints(policy)
endpoints, err := b.buildPolicyEndpoints(ctx, policy)
if err != nil {
return nil, err
}
@ -122,10 +129,10 @@ func (b *Builder) buildPolicyCluster(options *config.Options, policy *config.Pol
return cluster, nil
}
func (b *Builder) buildPolicyEndpoints(policy *config.Policy) ([]Endpoint, error) {
func (b *Builder) buildPolicyEndpoints(ctx context.Context, policy *config.Policy) ([]Endpoint, error) {
var endpoints []Endpoint
for _, dst := range policy.To {
ts, err := b.buildPolicyTransportSocket(policy, dst.URL)
ts, err := b.buildPolicyTransportSocket(ctx, policy, dst.URL)
if err != nil {
return nil, err
}
@ -134,7 +141,7 @@ func (b *Builder) buildPolicyEndpoints(policy *config.Policy) ([]Endpoint, error
return endpoints, nil
}
func (b *Builder) buildInternalTransportSocket(options *config.Options, endpoint *url.URL) (*envoy_config_core_v3.TransportSocket, error) {
func (b *Builder) buildInternalTransportSocket(ctx context.Context, options *config.Options, endpoint *url.URL) (*envoy_config_core_v3.TransportSocket, error) {
if endpoint.Scheme != "https" {
return nil, nil
}
@ -154,13 +161,13 @@ func (b *Builder) buildInternalTransportSocket(options *config.Options, endpoint
} else if options.CA != "" {
bs, err := base64.StdEncoding.DecodeString(options.CA)
if err != nil {
log.Error().Err(err).Msg("invalid custom CA certificate")
log.Error(ctx).Err(err).Msg("invalid custom CA certificate")
}
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
} else {
rootCA, err := getRootCertificateAuthority()
if err != nil {
log.Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else {
validationContext.TrustedCa = b.filemgr.FileDataSource(rootCA)
}
@ -183,12 +190,12 @@ func (b *Builder) buildInternalTransportSocket(options *config.Options, endpoint
}, nil
}
func (b *Builder) buildPolicyTransportSocket(policy *config.Policy, dst url.URL) (*envoy_config_core_v3.TransportSocket, error) {
func (b *Builder) buildPolicyTransportSocket(ctx context.Context, policy *config.Policy, dst url.URL) (*envoy_config_core_v3.TransportSocket, error) {
if dst.Scheme != "https" {
return nil, nil
}
vc, err := b.buildPolicyValidationContext(policy, dst)
vc, err := b.buildPolicyValidationContext(ctx, policy, dst)
if err != nil {
return nil, err
}
@ -216,7 +223,7 @@ func (b *Builder) buildPolicyTransportSocket(policy *config.Policy, dst url.URL)
}
if policy.ClientCertificate != nil {
tlsContext.CommonTlsContext.TlsCertificates = append(tlsContext.CommonTlsContext.TlsCertificates,
b.envoyTLSCertificateFromGoTLSCertificate(policy.ClientCertificate))
b.envoyTLSCertificateFromGoTLSCertificate(ctx, policy.ClientCertificate))
}
tlsConfig := marshalAny(tlsContext)
@ -229,6 +236,7 @@ func (b *Builder) buildPolicyTransportSocket(policy *config.Policy, dst url.URL)
}
func (b *Builder) buildPolicyValidationContext(
ctx context.Context,
policy *config.Policy, dst url.URL,
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
sni := dst.Hostname()
@ -247,13 +255,13 @@ func (b *Builder) buildPolicyValidationContext(
} else if policy.TLSCustomCA != "" {
bs, err := base64.StdEncoding.DecodeString(policy.TLSCustomCA)
if err != nil {
log.Error().Err(err).Msg("invalid custom CA certificate")
log.Error(ctx).Err(err).Msg("invalid custom CA certificate")
}
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
} else {
rootCA, err := getRootCertificateAuthority()
if err != nil {
log.Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else {
validationContext.TrustedCa = b.filemgr.FileDataSource(rootCA)
}

View file

@ -1,6 +1,7 @@
package envoyconfig
import (
"context"
"encoding/base64"
"os"
"path/filepath"
@ -18,6 +19,7 @@ import (
)
func Test_buildPolicyTransportSocket(t *testing.T) {
ctx := context.Background()
cacheDir, _ := os.UserCacheDir()
customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-32484c314b584447463735303142374c31414145374650305a525539554938594d524855353757313942494d473847535231.pem")
@ -26,14 +28,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
rootCA := b.filemgr.FileDataSource(rootCAPath).GetFilename()
t.Run("insecure", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(&config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "http://example.com"),
}, *mustParseURL(t, "http://example.com"))
require.NoError(t, err)
assert.Nil(t, ts)
})
t.Run("host as sni", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(&config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
}, *mustParseURL(t, "https://example.com"))
require.NoError(t, err)
@ -67,7 +69,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts)
})
t.Run("tls_server_name as sni", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(&config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
TLSServerName: "use-this-name.example.com",
}, *mustParseURL(t, "https://example.com"))
@ -102,7 +104,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts)
})
t.Run("tls_skip_verify", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(&config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
TLSSkipVerify: true,
}, *mustParseURL(t, "https://example.com"))
@ -138,7 +140,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts)
})
t.Run("custom ca", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(&config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}),
}, *mustParseURL(t, "https://example.com"))
@ -174,7 +176,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
})
t.Run("client certificate", func(t *testing.T) {
clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey)
ts, err := b.buildPolicyTransportSocket(&config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
ClientCertificate: clientCert,
}, *mustParseURL(t, "https://example.com"))
@ -219,11 +221,12 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
}
func Test_buildCluster(t *testing.T) {
ctx := context.Background()
b := New("local-grpc", "local-http", filemgr.NewManager(), nil)
rootCAPath, _ := getRootCertificateAuthority()
rootCA := b.filemgr.FileDataSource(rootCAPath).GetFilename()
t.Run("insecure", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(&config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"),
})
require.NoError(t, err)
@ -278,7 +281,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster)
})
t.Run("secure", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(&config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
To: mustParseWeightedURLs(t,
"https://example.com",
"https://example.com",
@ -406,7 +409,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster)
})
t.Run("ip addresses", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(&config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"),
})
require.NoError(t, err)
@ -459,7 +462,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster)
})
t.Run("weights", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(&config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"),
})
require.NoError(t, err)
@ -514,7 +517,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster)
})
t.Run("localhost", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(&config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "http://localhost"),
})
require.NoError(t, err)
@ -557,7 +560,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster)
})
t.Run("outlier", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(&config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "http://example.com"),
})
require.NoError(t, err)

View file

@ -3,6 +3,7 @@ package envoyconfig
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
@ -132,7 +133,10 @@ func buildAddress(hostport string, defaultPort int) *envoy_config_core_v3.Addres
}
}
func (b *Builder) envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate) *envoy_extensions_transport_sockets_tls_v3.TlsCertificate {
func (b *Builder) envoyTLSCertificateFromGoTLSCertificate(
ctx context.Context,
cert *tls.Certificate,
) *envoy_extensions_transport_sockets_tls_v3.TlsCertificate {
envoyCert := &envoy_extensions_transport_sockets_tls_v3.TlsCertificate{}
var chain bytes.Buffer
for _, cbs := range cert.Certificate {
@ -153,7 +157,7 @@ func (b *Builder) envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate)
},
))
} else {
log.Warn().Err(err).Msg("failed to marshal private key for tls config")
log.Warn(ctx).Err(err).Msg("failed to marshal private key for tls config")
}
for _, scts := range cert.SignedCertificateTimestamps {
envoyCert.SignedCertificateTimestamp = append(envoyCert.SignedCertificateTimestamp,
@ -185,10 +189,10 @@ func getRootCertificateAuthority() (string, error) {
}
}
if rootCABundle.value == "" {
log.Error().Strs("known-locations", knownRootLocations).
log.Error(context.TODO()).Strs("known-locations", knownRootLocations).
Msgf("no root certificates were found in any of the known locations")
} else {
log.Info().Msgf("using %s as the system root certificate authority bundle", rootCABundle.value)
log.Info(context.TODO()).Msgf("using %s as the system root certificate authority bundle", rootCABundle.value)
}
})
if rootCABundle.value == "" {

View file

@ -2,6 +2,7 @@
package filemgr
import (
"context"
"fmt"
"io/ioutil"
"os"
@ -34,7 +35,7 @@ func (mgr *Manager) BytesDataSource(fileName string, data []byte) *envoy_config_
fileName = fmt.Sprintf("%s-%x%s", fileName[:len(fileName)-len(ext)], h, ext)
if err := os.MkdirAll(mgr.cfg.cacheDir, 0o700); err != nil {
log.Error().Err(err).Msg("filemgr: error creating cache directory, falling back to inline bytes")
log.Error(context.TODO()).Err(err).Msg("filemgr: error creating cache directory, falling back to inline bytes")
return inlineBytes(data)
}
@ -42,11 +43,11 @@ func (mgr *Manager) BytesDataSource(fileName string, data []byte) *envoy_config_
if _, err := os.Stat(filePath); os.IsNotExist(err) {
err = ioutil.WriteFile(filePath, data, 0o600)
if err != nil {
log.Error().Err(err).Msg("filemgr: error writing cache file, falling back to inline bytes")
log.Error(context.TODO()).Err(err).Msg("filemgr: error writing cache file, falling back to inline bytes")
return inlineBytes(data)
}
} else if err != nil {
log.Error().Err(err).Msg("filemgr: error reading cache file, falling back to inline bytes")
log.Error(context.TODO()).Err(err).Msg("filemgr: error reading cache file, falling back to inline bytes")
return inlineBytes(data)
}
@ -62,7 +63,7 @@ func (mgr *Manager) ClearCache() {
return os.Remove(p)
})
if err != nil {
log.Error().Err(err).Msg("failed to clear envoy file cache")
log.Error(context.TODO()).Err(err).Msg("failed to clear envoy file cache")
}
}

View file

@ -1,6 +1,7 @@
package envoyconfig
import (
"context"
"encoding/base64"
"fmt"
"net"
@ -187,7 +188,7 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen
CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{
TlsParams: tlsParams,
TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{
b.envoyTLSCertificateFromGoTLSCertificate(cert),
b.envoyTLSCertificateFromGoTLSCertificate(context.TODO(), cert),
},
AlpnProtocols: []string{"h2", "http/1.1"},
},
@ -645,19 +646,21 @@ func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_con
}
func (b *Builder) buildDownstreamTLSContext(cfg *config.Config, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
ctx := context.TODO()
certs, err := cfg.AllCertificates()
if err != nil {
log.Warn().Str("domain", domain).Err(err).Msg("failed to get all certificates from config")
log.Warn(ctx).Str("domain", domain).Err(err).Msg("failed to get all certificates from config")
return nil
}
cert, err := cryptutil.GetCertificateForDomain(certs, domain)
if err != nil {
log.Warn().Str("domain", domain).Err(err).Msg("failed to get certificate for domain")
log.Warn(ctx).Str("domain", domain).Err(err).Msg("failed to get certificate for domain")
return nil
}
envoyCert := b.envoyTLSCertificateFromGoTLSCertificate(cert)
envoyCert := b.envoyTLSCertificateFromGoTLSCertificate(context.TODO(), cert)
return &envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext{
CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{
TlsParams: tlsParams,

View file

@ -1,6 +1,7 @@
package config
import (
"context"
"crypto/tls"
"net/http"
"sync/atomic"
@ -8,6 +9,8 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/tripper"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/rs/zerolog"
)
type httpTransport struct {
@ -18,16 +21,19 @@ type httpTransport struct {
// NewHTTPTransport creates a new http transport. If CA or CAFile is set, the transport will
// add the CA to system cert pool.
func NewHTTPTransport(src Source) http.RoundTripper {
ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context {
return c.Caller()
})
t := new(httpTransport)
t.underlying, _ = http.DefaultTransport.(*http.Transport)
src.OnConfigChange(func(cfg *Config) {
t.update(cfg.Options)
src.OnConfigChange(ctx, func(ctx context.Context, cfg *Config) {
t.update(ctx, cfg.Options)
})
t.update(src.GetConfig().Options)
t.update(ctx, src.GetConfig().Options)
return t
}
func (t *httpTransport) update(options *Options) {
func (t *httpTransport) update(ctx context.Context, options *Options) {
nt := new(http.Transport)
if t.underlying != nil {
nt = t.underlying.Clone()
@ -40,7 +46,7 @@ func (t *httpTransport) update(options *Options) {
MinVersion: tls.VersionTLS12,
}
} else {
log.Error().Err(err).Msg("config: error getting cert pool")
log.Error(ctx).Err(err).Msg("config: error getting cert pool")
}
}
t.transport.Store(nt)
@ -78,7 +84,7 @@ func NewPolicyHTTPTransport(options *Options, policy *Policy) http.RoundTripper
tlsClientConfig.MinVersion = tls.VersionTLS12
isCustomClientConfig = true
} else {
log.Error().Err(err).Msg("config: error getting ca cert pool")
log.Error(context.TODO()).Err(err).Msg("config: error getting ca cert pool")
}
}
@ -89,7 +95,7 @@ func NewPolicyHTTPTransport(options *Options, policy *Policy) http.RoundTripper
tlsClientConfig.MinVersion = tls.VersionTLS12
isCustomClientConfig = true
} else {
log.Error().Err(err).Msg("config: error getting custom ca cert pool")
log.Error(context.TODO()).Err(err).Msg("config: error getting custom ca cert pool")
}
}

View file

@ -1,6 +1,7 @@
package config
import (
"context"
"sync"
"github.com/pomerium/pomerium/internal/log"
@ -12,10 +13,10 @@ type LogManager struct {
}
// NewLogManager creates a new LogManager.
func NewLogManager(src Source) *LogManager {
func NewLogManager(ctx context.Context, src Source) *LogManager {
mgr := &LogManager{}
src.OnConfigChange(mgr.OnConfigChange)
mgr.OnConfigChange(src.GetConfig())
src.OnConfigChange(ctx, mgr.OnConfigChange)
mgr.OnConfigChange(ctx, src.GetConfig())
return mgr
}
@ -25,7 +26,7 @@ func (mgr *LogManager) Close() error {
}
// OnConfigChange is called whenever configuration changes.
func (mgr *LogManager) OnConfigChange(cfg *Config) {
func (mgr *LogManager) OnConfigChange(ctx context.Context, cfg *Config) {
if cfg == nil || cfg.Options == nil {
return
}

View file

@ -1,6 +1,7 @@
package config
import (
"context"
"net/http"
"os"
"sync"
@ -9,6 +10,8 @@ import (
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/telemetry"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/rs/zerolog"
)
// A MetricsManager manages metrics for a given configuration.
@ -22,11 +25,14 @@ type MetricsManager struct {
}
// NewMetricsManager creates a new MetricsManager.
func NewMetricsManager(src Source) *MetricsManager {
func NewMetricsManager(ctx context.Context, src Source) *MetricsManager {
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "metrics_manager")
})
mgr := &MetricsManager{}
metrics.RegisterInfoMetrics()
src.OnConfigChange(mgr.OnConfigChange)
mgr.OnConfigChange(src.GetConfig())
src.OnConfigChange(ctx, mgr.OnConfigChange)
mgr.OnConfigChange(ctx, src.GetConfig())
return mgr
}
@ -36,7 +42,7 @@ func (mgr *MetricsManager) Close() error {
}
// OnConfigChange updates the metrics manager when configuration is changed.
func (mgr *MetricsManager) OnConfigChange(cfg *Config) {
func (mgr *MetricsManager) OnConfigChange(ctx context.Context, cfg *Config) {
mgr.mu.Lock()
defer mgr.mu.Unlock()
@ -63,7 +69,7 @@ func (mgr *MetricsManager) updateInfo(cfg *Config) {
hostname, err := os.Hostname()
if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to get OS hostname")
log.Error(context.TODO()).Err(err).Msg("telemetry/metrics: failed to get OS hostname")
hostname = "__unknown__"
}
@ -84,13 +90,13 @@ func (mgr *MetricsManager) updateServer(cfg *Config) {
mgr.handler = nil
if mgr.addr == "" {
log.Info().Msg("metrics: http server disabled")
log.Info(context.TODO()).Msg("metrics: http server disabled")
return
}
handler, err := metrics.PrometheusHandler(EnvoyAdminURL, mgr.installationID)
if err != nil {
log.Error().Err(err).Msg("metrics: failed to create prometheus handler")
log.Error(context.TODO()).Err(err).Msg("metrics: failed to create prometheus handler")
return
}

View file

@ -1,6 +1,7 @@
package config
import (
"context"
"encoding/base64"
"fmt"
"net/http"
@ -12,12 +13,13 @@ import (
)
func TestMetricsManager(t *testing.T) {
ctx := context.Background()
src := NewStaticSource(&Config{
Options: &Options{
MetricsAddr: "ADDRESS",
},
})
mgr := NewMetricsManager(src)
mgr := NewMetricsManager(ctx, src)
srv1 := httptest.NewServer(mgr)
defer srv1.Close()
srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -42,7 +44,7 @@ func TestMetricsManagerBasicAuth(t *testing.T) {
MetricsBasicAuth: base64.StdEncoding.EncodeToString([]byte("x:y")),
},
})
mgr := NewMetricsManager(src)
mgr := NewMetricsManager(context.Background(), src)
srv1 := httptest.NewServer(mgr)
defer srv1.Close()

View file

@ -2,6 +2,7 @@ package config
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"errors"
@ -427,7 +428,7 @@ func (o *Options) viperIsSet(key string) bool {
// parseHeaders handles unmarshalling any custom headers correctly from the
// environment or viper's parsed keys
func (o *Options) parseHeaders() error {
func (o *Options) parseHeaders(ctx context.Context) error {
var headers map[string]string
if o.HeadersEnv != "" {
// Handle JSON by default via viper
@ -450,7 +451,7 @@ func (o *Options) parseHeaders() error {
}
if o.viperIsSet("headers") {
log.Warn().Msg("config: headers has been renamed to set_response_headers, it will be removed in v0.16")
log.Warn(ctx).Msg("config: headers has been renamed to set_response_headers, it will be removed in v0.16")
}
// option was renamed from `headers` to `set_response_headers`. Both config settings are supported.
@ -507,6 +508,7 @@ func bindEnvs(o *Options, v *viper.Viper) error {
// Validate ensures the Options fields are valid, and hydrated.
func (o *Options) Validate() error {
ctx := context.TODO()
if !IsValidService(o.Services) {
return fmt.Errorf("config: %s is an invalid service type", o.Services)
}
@ -591,7 +593,7 @@ func (o *Options) Validate() error {
return fmt.Errorf("config: failed to parse policy: %w", err)
}
if err := o.parseHeaders(); err != nil {
if err := o.parseHeaders(ctx); err != nil {
return fmt.Errorf("config: failed to parse headers: %w", err)
}
@ -669,7 +671,7 @@ func (o *Options) Validate() error {
// GoogleCloudServerlessAuthenticationServiceAccount
if o.Provider == "google" && o.GoogleCloudServerlessAuthenticationServiceAccount == "" {
o.GoogleCloudServerlessAuthenticationServiceAccount = o.ServiceAccount
log.Info().Msg("defaulting to idp_service_account for google_cloud_serverless_authentication_service_account")
log.Info(ctx).Msg("defaulting to idp_service_account for google_cloud_serverless_authentication_service_account")
}
// strip quotes from redirect address (#811)
@ -683,7 +685,7 @@ func (o *Options) Validate() error {
switch o.Provider {
case azure.Name, github.Name, gitlab.Name, google.Name, okta.Name, onelogin.Name:
if len(o.Scopes) > 0 {
log.Warn().Msg(idpCustomScopesWarnMsg)
log.Warn(ctx).Msg(idpCustomScopesWarnMsg)
}
default:
}

View file

@ -1,6 +1,7 @@
package config
import (
"context"
"encoding/base64"
"fmt"
"io/ioutil"
@ -193,7 +194,7 @@ func Test_parseHeaders(t *testing.T) {
o.viperSet("headers", tt.viperHeaders)
o.viperSet("HeadersEnv", tt.envHeaders)
o.HeadersEnv = tt.envHeaders
err := o.parseHeaders()
err := o.parseHeaders(context.Background())
if (err != nil) != tt.wantErr {
t.Errorf("Error condition unexpected: err=%s", err)

View file

@ -1,16 +1,18 @@
package config
import (
"context"
"fmt"
"reflect"
"sync"
octrace "go.opencensus.io/trace"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/rs/zerolog"
octrace "go.opencensus.io/trace"
)
// TracingOptions are the options for tracing.
@ -60,10 +62,13 @@ type TraceManager struct {
}
// NewTraceManager creates a new TraceManager.
func NewTraceManager(src Source) *TraceManager {
func NewTraceManager(ctx context.Context, src Source) *TraceManager {
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "trace_manager")
})
mgr := &TraceManager{}
src.OnConfigChange(mgr.OnConfigChange)
mgr.OnConfigChange(src.GetConfig())
src.OnConfigChange(ctx, mgr.OnConfigChange)
mgr.OnConfigChange(ctx, src.GetConfig())
return mgr
}
@ -79,18 +84,18 @@ func (mgr *TraceManager) Close() error {
}
// OnConfigChange updates the manager whenever the configuration is changed.
func (mgr *TraceManager) OnConfigChange(cfg *Config) {
func (mgr *TraceManager) OnConfigChange(ctx context.Context, cfg *Config) {
mgr.mu.Lock()
defer mgr.mu.Unlock()
traceOpts, err := NewTracingOptions(cfg.Options)
if err != nil {
log.Error().Err(err).Msg("trace: failed to build tracing options")
log.Error(ctx).Err(err).Msg("trace: failed to build tracing options")
return
}
if reflect.DeepEqual(traceOpts, mgr.traceOpts) {
log.Debug().Msg("no change detected in trace options")
log.Debug(ctx).Msg("no change detected in trace options")
return
}
mgr.traceOpts = traceOpts
@ -104,11 +109,11 @@ func (mgr *TraceManager) OnConfigChange(cfg *Config) {
return
}
log.Info().Interface("options", traceOpts).Msg("trace: starting exporter")
log.Info(ctx).Interface("options", traceOpts).Msg("trace: starting exporter")
mgr.exporter, err = trace.RegisterTracing(traceOpts)
if err != nil {
log.Error().Err(err).Msg("trace: failed to register exporter")
log.Error(ctx).Err(err).Msg("trace: failed to register exporter")
return
}
}

View file

@ -123,12 +123,12 @@ func TestTraceManager(t *testing.T) {
TracingSampleRate: 1,
}})
_ = NewTraceManager(src)
_ = NewTraceManager(ctx, src)
_, span := trace.StartSpan(ctx, "Example")
span.End()
src.SetConfig(&Config{Options: &Options{
src.SetConfig(ctx, &Config{Options: &Options{
TracingProvider: "zipkin",
ZipkinEndpoint: srv2.URL,
TracingSampleRate: 1,

View file

@ -104,13 +104,13 @@ func New(cfg *config.Config) (*DataBroker, error) {
}
// OnConfigChange is called whenever configuration is changed.
func (c *DataBroker) OnConfigChange(cfg *config.Config) {
func (c *DataBroker) OnConfigChange(ctx context.Context, cfg *config.Config) {
err := c.update(cfg)
if err != nil {
log.Error().Err(err).Msg("databroker: error updating configuration")
log.Error(ctx).Err(err).Msg("databroker: error updating configuration")
}
c.dataBrokerServer.OnConfigChange(cfg)
c.dataBrokerServer.OnConfigChange(ctx, cfg)
}
// Register registers all the gRPC services with the given server.

View file

@ -27,7 +27,7 @@ func newDataBrokerServer(cfg *config.Config) *dataBrokerServer {
}
// OnConfigChange updates the underlying databroker server whenever configuration is changed.
func (srv *dataBrokerServer) OnConfigChange(cfg *config.Config) {
func (srv *dataBrokerServer) OnConfigChange(ctx context.Context, cfg *config.Config) {
srv.server.UpdateConfig(srv.getOptions(cfg)...)
srv.setKey(cfg)
}

View file

@ -5,7 +5,8 @@ import (
"net/http"
"net/http/cookiejar"
"github.com/rs/zerolog/log"
"github.com/pomerium/pomerium/internal/log"
"golang.org/x/net/publicsuffix"
)
@ -51,6 +52,6 @@ type loggingRoundTripper struct {
func (rt *loggingRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err error) {
res, err = rt.RoundTripper.RoundTrip(req)
log.Debug().Str("method", req.Method).Str("url", req.URL.String()).Msg("http request")
log.Debug(req.Context()).Str("method", req.Method).Str("url", req.URL.String()).Msg("http request")
return res, err
}

View file

@ -8,7 +8,7 @@ import (
"os"
"os/exec"
"github.com/rs/zerolog/log"
"github.com/pomerium/pomerium/internal/log"
)
type cmdOption func(*exec.Cmd)
@ -53,7 +53,7 @@ func run(ctx context.Context, name string, options ...cmdOption) error {
if err != nil {
return fmt.Errorf("failed to create stderr pipe for %s: %w", name, err)
}
go cmdLogger(stderr)
go cmdLogger(ctx, stderr)
defer stderr.Close()
}
if cmd.Stdout == nil {
@ -61,17 +61,17 @@ func run(ctx context.Context, name string, options ...cmdOption) error {
if err != nil {
return fmt.Errorf("failed to create stdout pipe for %s: %w", name, err)
}
go cmdLogger(stdout)
go cmdLogger(ctx, stdout)
defer stdout.Close()
}
log.Debug().Strs("args", cmd.Args).Msgf("running %s", name)
log.Debug(ctx).Strs("args", cmd.Args).Msgf("running %s", name)
return cmd.Run()
}
func cmdLogger(rdr io.Reader) {
func cmdLogger(ctx context.Context, rdr io.Reader) {
s := bufio.NewScanner(rdr)
for s.Scan() {
log.Debug().Msg(s.Text())
log.Debug(ctx).Msg(s.Text())
}
}

View file

@ -15,9 +15,9 @@ import (
"time"
"github.com/google/go-jsonnet"
"github.com/rs/zerolog/log"
"github.com/pomerium/pomerium/integration/internal/netutil"
"github.com/pomerium/pomerium/internal/log"
)
var requiredDeployments = []string{
@ -180,7 +180,7 @@ func applyManifests(ctx context.Context, jsonsrc string) error {
return fmt.Errorf("error applying manifests: %w", err)
}
log.Info().Msg("waiting for deployments to come up")
log.Info(ctx).Msg("waiting for deployments to come up")
ctx, clearTimeout := context.WithTimeout(ctx, 15*time.Minute)
defer clearTimeout()
ticker := time.NewTicker(time.Second * 5)
@ -218,7 +218,7 @@ func applyManifests(ctx context.Context, jsonsrc string) error {
for _, dep := range requiredDeployments {
if byName[dep] < 1 {
done = false
log.Warn().Str("deployment", dep).Msg("deployment is not ready yet")
log.Warn(ctx).Str("deployment", dep).Msg("deployment is not ready yet")
}
}
if done {
@ -233,7 +233,7 @@ func applyManifests(ctx context.Context, jsonsrc string) error {
<-ticker.C
}
time.Sleep(time.Minute)
log.Info().Msg("all deployments are ready")
log.Info(ctx).Msg("all deployments are ready")
return nil
}

View file

@ -69,15 +69,15 @@ func newManager(ctx context.Context,
if err != nil {
return nil, err
}
mgr.src.OnConfigChange(func(cfg *config.Config) {
mgr.src.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) {
err := mgr.update(cfg)
if err != nil {
log.Error().Err(err).Msg("autocert: error updating config")
log.Error(ctx).Err(err).Msg("autocert: error updating config")
return
}
cfg = mgr.GetConfig()
mgr.Trigger(cfg)
mgr.Trigger(ctx, cfg)
})
go func() {
ticker := time.NewTicker(checkInterval)
@ -90,7 +90,7 @@ func newManager(ctx context.Context,
case <-ticker.C:
err := mgr.renewConfigCerts()
if err != nil {
log.Error().Err(err).Msg("autocert: error updating config")
log.Error(context.TODO()).Err(err).Msg("autocert: error updating config")
return
}
}
@ -153,7 +153,7 @@ func (mgr *Manager) renewConfigCerts() error {
}
mgr.config = cfg
mgr.Trigger(cfg)
mgr.Trigger(context.TODO(), cfg)
return nil
}
@ -172,10 +172,10 @@ func (mgr *Manager) update(cfg *config.Config) error {
func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.Certificate, error) {
cert, err := cm.CacheManagedCertificate(domain)
if err != nil {
log.Info().Str("domain", domain).Msg("obtaining certificate")
log.Info(context.TODO()).Str("domain", domain).Msg("obtaining certificate")
err = cm.ObtainCert(context.Background(), domain, false)
if err != nil {
log.Error().Err(err).Msg("autocert failed to obtain client certificate")
log.Error(context.TODO()).Err(err).Msg("autocert failed to obtain client certificate")
return certmagic.Certificate{}, errObtainCertFailed
}
metrics.RecordAutocertRenewal()
@ -187,13 +187,13 @@ func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.C
// renewCert attempts to renew given certificate.
func (mgr *Manager) renewCert(domain string, cert certmagic.Certificate, cm *certmagic.Config) (certmagic.Certificate, error) {
expired := time.Now().After(cert.Leaf.NotAfter)
log.Info().Str("domain", domain).Msg("renewing certificate")
log.Info(context.TODO()).Str("domain", domain).Msg("renewing certificate")
err := cm.RenewCert(context.Background(), domain, false)
if err != nil {
if expired {
return certmagic.Certificate{}, errRenewCertFailed
}
log.Warn().Err(err).Msg("renew client certificated failed, use existing cert")
log.Warn(context.TODO()).Err(err).Msg("renew client certificated failed, use existing cert")
}
return cm.CacheManagedCertificate(domain)
}
@ -220,11 +220,11 @@ func (mgr *Manager) updateAutocert(cfg *config.Config) error {
return fmt.Errorf("autocert: failed to renew client certificate: %w", err)
}
if err != nil {
log.Error().Err(err).Msg("autocert: failed to obtain client certificate")
log.Error(context.TODO()).Err(err).Msg("autocert: failed to obtain client certificate")
continue
}
log.Info().Strs("names", cert.Names).Msg("autocert: added certificate")
log.Info(context.TODO()).Strs("names", cert.Names).Msg("autocert: added certificate")
cfg.AutoCertificates = append(cfg.AutoCertificates, cert.Certificate)
}
@ -260,10 +260,10 @@ func (mgr *Manager) updateServer(cfg *config.Config) {
}),
}
go func() {
log.Info().Str("addr", hsrv.Addr).Msg("starting http redirect server")
log.Info(context.TODO()).Str("addr", hsrv.Addr).Msg("starting http redirect server")
err := hsrv.ListenAndServe()
if err != nil {
log.Error().Err(err).Msg("failed to run http redirect server")
log.Error(context.TODO()).Err(err).Msg("failed to run http redirect server")
}
}()
mgr.srv = hsrv

View file

@ -12,6 +12,7 @@ import (
"syscall"
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
"github.com/pomerium/pomerium/authenticate"
@ -32,7 +33,7 @@ import (
// Run runs the main pomerium application.
func Run(ctx context.Context, configFile string) error {
log.Info().Str("version", version.FullVersion()).Msg("cmd/pomerium")
log.Info(ctx).Str("version", version.FullVersion()).Msg("cmd/pomerium")
var src config.Source
@ -41,8 +42,8 @@ func Run(ctx context.Context, configFile string) error {
return err
}
src = databroker.NewConfigSource(src)
logMgr := config.NewLogManager(src)
src = databroker.NewConfigSource(ctx, src)
logMgr := config.NewLogManager(ctx, src)
defer logMgr.Close()
// trigger changes when underlying files are changed
@ -56,9 +57,9 @@ func Run(ctx context.Context, configFile string) error {
// override the default http transport so we can use the custom CA in the TLS client config (#1570)
http.DefaultTransport = config.NewHTTPTransport(src)
metricsMgr := config.NewMetricsManager(src)
metricsMgr := config.NewMetricsManager(ctx, src)
defer metricsMgr.Close()
traceMgr := config.NewTraceManager(src)
traceMgr := config.NewTraceManager(ctx, src)
defer traceMgr.Close()
// setup the control plane
@ -66,43 +67,46 @@ func Run(ctx context.Context, configFile string) error {
if err != nil {
return fmt.Errorf("error creating control plane: %w", err)
}
src.OnConfigChange(func(cfg *config.Config) {
if err := controlPlane.OnConfigChange(cfg); err != nil {
log.Error().Err(err).Msg("config change")
}
})
src.OnConfigChange(ctx,
func(ctx context.Context, cfg *config.Config) {
if err := controlPlane.OnConfigChange(ctx, cfg); err != nil {
log.Error(ctx).Err(err).Msg("config change")
}
})
if err = controlPlane.OnConfigChange(src.GetConfig()); err != nil {
if err = controlPlane.OnConfigChange(log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("config_file_source", configFile).Bool("bootstrap", true)
}), src.GetConfig()); err != nil {
return fmt.Errorf("applying config: %w", err)
}
_, grpcPort, _ := net.SplitHostPort(controlPlane.GRPCListener.Addr().String())
_, httpPort, _ := net.SplitHostPort(controlPlane.HTTPListener.Addr().String())
log.Info().Str("port", grpcPort).Msg("gRPC server started")
log.Info().Str("port", httpPort).Msg("HTTP server started")
log.Info(ctx).Str("port", grpcPort).Msg("gRPC server started")
log.Info(ctx).Str("port", httpPort).Msg("HTTP server started")
// create envoy server
envoyServer, err := envoy.NewServer(src, grpcPort, httpPort, controlPlane.Builder)
envoyServer, err := envoy.NewServer(ctx, src, grpcPort, httpPort, controlPlane.Builder)
if err != nil {
return fmt.Errorf("error creating envoy server: %w", err)
}
defer envoyServer.Close()
// add services
if err := setupAuthenticate(src, controlPlane); err != nil {
if err := setupAuthenticate(ctx, src, controlPlane); err != nil {
return err
}
var authorizeServer *authorize.Authorize
if config.IsAuthorize(src.GetConfig().Options.Services) {
authorizeServer, err = setupAuthorize(src, controlPlane)
authorizeServer, err = setupAuthorize(ctx, src, controlPlane)
if err != nil {
return err
}
}
var dataBrokerServer *databroker_service.DataBroker
if config.IsDataBroker(src.GetConfig().Options.Services) {
dataBrokerServer, err = setupDataBroker(src, controlPlane)
dataBrokerServer, err = setupDataBroker(ctx, src, controlPlane)
if err != nil {
return fmt.Errorf("setting up databroker: %w", err)
}
@ -112,10 +116,10 @@ func Run(ctx context.Context, configFile string) error {
}
}
if err = setupRegistryReporter(src); err != nil {
if err = setupRegistryReporter(ctx, src); err != nil {
return fmt.Errorf("setting up registry reporter: %w", err)
}
if err := setupProxy(src, controlPlane); err != nil {
if err := setupProxy(ctx, src, controlPlane); err != nil {
return err
}
@ -159,7 +163,7 @@ func Run(ctx context.Context, configFile string) error {
return eg.Wait()
}
func setupAuthenticate(src config.Source, controlPlane *controlplane.Server) error {
func setupAuthenticate(ctx context.Context, src config.Source, controlPlane *controlplane.Server) error {
if !config.IsAuthenticate(src.GetConfig().Options.Services) {
return nil
}
@ -174,56 +178,56 @@ func setupAuthenticate(src config.Source, controlPlane *controlplane.Server) err
return fmt.Errorf("error getting authenticate URL: %w", err)
}
src.OnConfigChange(svc.OnConfigChange)
svc.OnConfigChange(src.GetConfig())
src.OnConfigChange(ctx, svc.OnConfigChange)
svc.OnConfigChange(ctx, src.GetConfig())
host := urlutil.StripPort(authenticateURL.Host)
sr := controlPlane.HTTPRouter.Host(host).Subrouter()
svc.Mount(sr)
log.Info().Str("host", host).Msg("enabled authenticate service")
log.Info(context.TODO()).Str("host", host).Msg("enabled authenticate service")
return nil
}
func setupAuthorize(src config.Source, controlPlane *controlplane.Server) (*authorize.Authorize, error) {
func setupAuthorize(ctx context.Context, src config.Source, controlPlane *controlplane.Server) (*authorize.Authorize, error) {
svc, err := authorize.New(src.GetConfig())
if err != nil {
return nil, fmt.Errorf("error creating authorize service: %w", err)
}
envoy_service_auth_v3.RegisterAuthorizationServer(controlPlane.GRPCServer, svc)
log.Info().Msg("enabled authorize service")
src.OnConfigChange(svc.OnConfigChange)
svc.OnConfigChange(src.GetConfig())
log.Info(context.TODO()).Msg("enabled authorize service")
src.OnConfigChange(ctx, svc.OnConfigChange)
svc.OnConfigChange(ctx, src.GetConfig())
return svc, nil
}
func setupDataBroker(src config.Source, controlPlane *controlplane.Server) (*databroker_service.DataBroker, error) {
func setupDataBroker(ctx context.Context, src config.Source, controlPlane *controlplane.Server) (*databroker_service.DataBroker, error) {
svc, err := databroker_service.New(src.GetConfig())
if err != nil {
return nil, fmt.Errorf("error creating databroker service: %w", err)
}
svc.Register(controlPlane.GRPCServer)
log.Info().Msg("enabled databroker service")
src.OnConfigChange(svc.OnConfigChange)
svc.OnConfigChange(src.GetConfig())
log.Info(context.TODO()).Msg("enabled databroker service")
src.OnConfigChange(ctx, svc.OnConfigChange)
svc.OnConfigChange(ctx, src.GetConfig())
return svc, nil
}
func setupRegistryServer(src config.Source, controlPlane *controlplane.Server) error {
svc := registry.NewInMemoryServer(context.TODO(), registryTTL)
registry_pb.RegisterRegistryServer(controlPlane.GRPCServer, svc)
log.Info().Msg("enabled service discovery")
log.Info(context.TODO()).Msg("enabled service discovery")
return nil
}
func setupRegistryReporter(src config.Source) error {
func setupRegistryReporter(ctx context.Context, src config.Source) error {
reporter := new(registry.Reporter)
src.OnConfigChange(reporter.OnConfigChange)
reporter.OnConfigChange(src.GetConfig())
src.OnConfigChange(ctx, reporter.OnConfigChange)
reporter.OnConfigChange(ctx, src.GetConfig())
return nil
}
func setupProxy(src config.Source, controlPlane *controlplane.Server) error {
func setupProxy(ctx context.Context, src config.Source, controlPlane *controlplane.Server) error {
if !config.IsProxy(src.GetConfig().Options.Services) {
return nil
}
@ -234,9 +238,9 @@ func setupProxy(src config.Source, controlPlane *controlplane.Server) error {
}
controlPlane.HTTPRouter.PathPrefix("/").Handler(svc)
log.Info().Msg("enabled proxy service")
src.OnConfigChange(svc.OnConfigChange)
svc.OnConfigChange(src.GetConfig())
log.Info(context.TODO()).Msg("enabled proxy service")
src.OnConfigChange(ctx, svc.OnConfigChange)
svc.OnConfigChange(ctx, src.GetConfig())
return nil
}

View file

@ -19,7 +19,7 @@ func (srv *Server) StreamAccessLogs(stream envoy_service_accesslog_v3.AccessLogS
for {
msg, err := stream.Recv()
if err != nil {
log.Error().Err(err).Msg("access log stream error, disconnecting")
log.Error(stream.Context()).Err(err).Msg("access log stream error, disconnecting")
return err
}
@ -27,9 +27,9 @@ func (srv *Server) StreamAccessLogs(stream envoy_service_accesslog_v3.AccessLogS
reqPath := entry.GetRequest().GetPath()
var evt *zerolog.Event
if reqPath == "/ping" || reqPath == "/healthz" {
evt = log.Debug()
evt = log.Debug(stream.Context())
} else {
evt = log.Info()
evt = log.Info(stream.Context())
}
// common properties
evt = evt.Str("service", "envoy")

View file

@ -9,6 +9,7 @@ import (
envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
"github.com/gorilla/mux"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
@ -106,7 +107,11 @@ func NewServer(name string, metricsMgr *config.MetricsManager) (*Server, error)
srv.reproxy,
)
res, err := srv.buildDiscoveryResources()
ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context {
return c.Str("server_name", name)
})
res, err := srv.buildDiscoveryResources(ctx)
if err != nil {
return nil, err
}
@ -123,7 +128,7 @@ func (srv *Server) Run(ctx context.Context) error {
// start the gRPC server
eg.Go(func() error {
log.Info().Str("addr", srv.GRPCListener.Addr().String()).Msg("starting control-plane gRPC server")
log.Info(ctx).Str("addr", srv.GRPCListener.Addr().String()).Msg("starting control-plane gRPC server")
return srv.GRPCServer.Serve(srv.GRPCListener)
})
@ -160,7 +165,7 @@ func (srv *Server) Run(ctx context.Context) error {
// start the HTTP server
eg.Go(func() error {
log.Info().Str("addr", srv.HTTPListener.Addr().String()).Msg("starting control-plane HTTP server")
log.Info(ctx).Str("addr", srv.HTTPListener.Addr().String()).Msg("starting control-plane HTTP server")
return hsrv.Serve(srv.HTTPListener)
})
@ -178,17 +183,17 @@ func (srv *Server) Run(ctx context.Context) error {
}
// OnConfigChange updates the pomerium config options.
func (srv *Server) OnConfigChange(cfg *config.Config) error {
srv.reproxy.Update(cfg)
func (srv *Server) OnConfigChange(ctx context.Context, cfg *config.Config) error {
srv.reproxy.Update(ctx, cfg)
prev := srv.currentConfig.Load()
srv.currentConfig.Store(versionedConfig{
Config: cfg,
version: prev.version + 1,
})
res, err := srv.buildDiscoveryResources()
res, err := srv.buildDiscoveryResources(ctx)
if err != nil {
return err
}
srv.xdsmgr.Update(res)
srv.xdsmgr.Update(ctx, res)
return nil
}

View file

@ -1,6 +1,7 @@
package controlplane
import (
"context"
"encoding/hex"
envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
@ -14,11 +15,11 @@ const (
listenerTypeURL = "type.googleapis.com/envoy.config.listener.v3.Listener"
)
func (srv *Server) buildDiscoveryResources() (map[string][]*envoy_service_discovery_v3.Resource, error) {
func (srv *Server) buildDiscoveryResources(ctx context.Context) (map[string][]*envoy_service_discovery_v3.Resource, error) {
resources := map[string][]*envoy_service_discovery_v3.Resource{}
cfg := srv.currentConfig.Load()
clusters, err := srv.Builder.BuildClusters(cfg.Config)
clusters, err := srv.Builder.BuildClusters(ctx, cfg.Config)
if err != nil {
return nil, err
}

View file

@ -2,6 +2,7 @@
package xdsmgr
import (
"context"
"encoding/json"
"errors"
"sync"
@ -51,7 +52,7 @@ func (mgr *Manager) DeltaAggregatedResources(
stateByTypeURL := map[string]*streamState{}
getDeltaResponse := func(typeURL string) *envoy_service_discovery_v3.DeltaDiscoveryResponse {
getDeltaResponse := func(ctx context.Context, typeURL string) *envoy_service_discovery_v3.DeltaDiscoveryResponse {
mgr.mu.Lock()
defer mgr.mu.Unlock()
@ -85,7 +86,7 @@ func (mgr *Manager) DeltaAggregatedResources(
return res
}
handleDeltaRequest := func(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
handleDeltaRequest := func(ctx context.Context, req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
mgr.mu.Lock()
defer mgr.mu.Unlock()
@ -109,8 +110,9 @@ func (mgr *Manager) DeltaAggregatedResources(
case req.GetErrorDetail() != nil:
// a NACK
bs, _ := json.Marshal(req.ErrorDetail.Details)
log.Error().
log.Error(ctx).
Err(errors.New(req.ErrorDetail.Message)).
Str("nonce", req.ResponseNonce).
Int32("code", req.ErrorDetail.Code).
RawJSON("details", bs).Msg("error applying configuration")
// - set the client resource versions to the current resource versions
@ -121,12 +123,18 @@ func (mgr *Manager) DeltaAggregatedResources(
case req.GetResponseNonce() == mgr.nonce:
// an ACK for the last response
// - set the client resource versions to the current resource versions
log.Debug(ctx).
Str("nonce", req.ResponseNonce).
Msg("ACK")
state.clientResourceVersions = make(map[string]string)
for _, resource := range mgr.resources[req.GetTypeUrl()] {
state.clientResourceVersions[resource.Name] = resource.Version
}
default:
// an ACK for a response that's not the last response
log.Debug(ctx).
Str("nonce", req.ResponseNonce).
Msg("stale ACK")
}
// update subscriptions
@ -168,15 +176,16 @@ func (mgr *Manager) DeltaAggregatedResources(
})
// 2. handle incoming requests or resource changes
eg.Go(func() error {
changeCtx := ctx
for {
var typeURLs []string
select {
case <-ctx.Done():
return ctx.Err()
case req := <-incoming:
handleDeltaRequest(req)
handleDeltaRequest(changeCtx, req)
typeURLs = []string{req.GetTypeUrl()}
case <-ch:
case changeCtx = <-ch:
mgr.mu.Lock()
for typeURL := range mgr.resources {
typeURLs = append(typeURLs, typeURL)
@ -185,7 +194,7 @@ func (mgr *Manager) DeltaAggregatedResources(
}
for _, typeURL := range typeURLs {
res := getDeltaResponse(typeURL)
res := getDeltaResponse(changeCtx, typeURL)
if res == nil {
continue
}
@ -194,6 +203,10 @@ func (mgr *Manager) DeltaAggregatedResources(
case <-ctx.Done():
return ctx.Err()
case outgoing <- res:
log.Info(changeCtx).
Str("nounce", res.Nonce).
Str("type", res.TypeUrl).
Msg("send update")
}
}
}
@ -224,11 +237,11 @@ func (mgr *Manager) StreamAggregatedResources(
// Update updates the state of resources. If any changes are made they will be pushed to any listening
// streams. For each TypeURL the list of resources should be the complete list of resources.
func (mgr *Manager) Update(resources map[string][]*envoy_service_discovery_v3.Resource) {
func (mgr *Manager) Update(ctx context.Context, resources map[string][]*envoy_service_discovery_v3.Resource) {
mgr.mu.Lock()
mgr.nonce = uuid.New().String()
mgr.resources = resources
mgr.mu.Unlock()
mgr.signal.Broadcast()
mgr.signal.Broadcast(ctx)
}

View file

@ -28,7 +28,7 @@ func TestManager(t *testing.T) {
origOnHandleDeltaRequest := onHandleDeltaRequest
defer func() { onHandleDeltaRequest = origOnHandleDeltaRequest }()
onHandleDeltaRequest = func(state *streamState) {
stateChanged.Broadcast()
stateChanged.Broadcast(ctx)
}
srv := grpc.NewServer()
@ -94,7 +94,7 @@ func TestManager(t *testing.T) {
}, msg.GetResources())
ack(msg.Nonce)
mgr.Update(map[string][]*envoy_service_discovery_v3.Resource{
mgr.Update(ctx, map[string][]*envoy_service_discovery_v3.Resource{
typeURL: {{Name: "r1", Version: "2"}},
})
@ -105,7 +105,7 @@ func TestManager(t *testing.T) {
}, msg.GetResources())
ack(msg.Nonce)
mgr.Update(map[string][]*envoy_service_discovery_v3.Resource{
mgr.Update(ctx, map[string][]*envoy_service_discovery_v3.Resource{
typeURL: nil,
})

View file

@ -1,6 +1,7 @@
package databroker
import (
"context"
"crypto/tls"
"encoding/base64"
"time"
@ -65,7 +66,7 @@ func WithSharedKey(sharedKey string) ServerOption {
return func(cfg *serverConfig) {
key, err := base64.StdEncoding.DecodeString(sharedKey)
if err != nil || len(key) != cryptutil.DefaultKeySize {
log.Error().Err(err).Msgf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize)
log.Error(context.TODO()).Err(err).Msgf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize)
return
}
cfg.secret = key

View file

@ -35,22 +35,22 @@ type dbConfig struct {
}
// NewConfigSource creates a new ConfigSource.
func NewConfigSource(underlying config.Source, listeners ...config.ChangeListener) *ConfigSource {
func NewConfigSource(ctx context.Context, underlying config.Source, listeners ...config.ChangeListener) *ConfigSource {
src := &ConfigSource{
dbConfigs: map[string]dbConfig{},
}
for _, li := range listeners {
src.OnConfigChange(li)
src.OnConfigChange(ctx, li)
}
underlying.OnConfigChange(func(cfg *config.Config) {
underlying.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) {
src.mu.Lock()
src.underlyingConfig = cfg.Clone()
src.mu.Unlock()
src.rebuild(false)
src.rebuild(ctx, firstTime(false))
})
src.underlyingConfig = underlying.GetConfig()
src.rebuild(true)
src.rebuild(ctx, firstTime(true))
return src
}
@ -62,8 +62,10 @@ func (src *ConfigSource) GetConfig() *config.Config {
return src.computedConfig
}
func (src *ConfigSource) rebuild(firstTime bool) {
_, span := trace.StartSpan(context.Background(), "databroker.config_source.rebuild")
type firstTime bool
func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
_, span := trace.StartSpan(ctx, "databroker.config_source.rebuild")
defer span.End()
src.mu.Lock()
@ -78,7 +80,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
for _, policy := range cfg.Options.GetAllPolicies() {
id, err := policy.RouteID()
if err != nil {
log.Warn().Err(err).
log.Warn(ctx).Err(err).
Str("policy", policy.String()).
Msg("databroker: invalid policy config, ignoring")
return
@ -95,7 +97,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
err := cfg.Options.Validate()
if err != nil {
metrics.SetDBConfigRejected(cfg.Options.Services, id, cfgpb.version, err)
metrics.SetDBConfigRejected(ctx, cfg.Options.Services, id, cfgpb.version, err)
return
}
@ -103,7 +105,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
policy, err := config.NewPolicyFromProto(routepb)
if err != nil {
errCount++
log.Warn().Err(err).
log.Warn(ctx).Err(err).
Str("db_config_id", id).
Msg("databroker: error converting protobuf into policy")
continue
@ -112,7 +114,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
err = policy.Validate()
if err != nil {
errCount++
log.Warn().Err(err).
log.Warn(ctx).Err(err).
Str("db_config_id", id).
Str("policy", policy.String()).
Msg("databroker: invalid policy, ignoring")
@ -122,7 +124,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
routeID, err := policy.RouteID()
if err != nil {
errCount++
log.Warn().Err(err).
log.Warn(ctx).Err(err).
Str("db_config_id", id).
Str("policy", policy.String()).
Msg("databroker: cannot establish policy route ID, ignoring")
@ -131,7 +133,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
if _, ok := seen[routeID]; ok {
errCount++
log.Warn().Err(err).
log.Warn(ctx).Err(err).
Str("db_config_id", id).
Str("policy", policy.String()).
Msg("databroker: duplicate policy detected, ignoring")
@ -141,7 +143,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
additionalPolicies = append(additionalPolicies, *policy)
}
metrics.SetDBConfigInfo(cfg.Options.Services, id, cfgpb.version, int64(errCount))
metrics.SetDBConfigInfo(ctx, cfg.Options.Services, id, cfgpb.version, int64(errCount))
}
// add the additional policies here since calling `Validate` will reset them.
@ -149,10 +151,10 @@ func (src *ConfigSource) rebuild(firstTime bool) {
src.computedConfig = cfg
if !firstTime {
src.Trigger(cfg)
src.Trigger(ctx, cfg)
}
metrics.SetConfigInfo(cfg.Options.Services, "databroker", cfg.Checksum(), true)
metrics.SetConfigInfo(ctx, cfg.Options.Services, "databroker", cfg.Checksum(), true)
}
func (src *ConfigSource) runUpdater(cfg *config.Config) {
@ -191,7 +193,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
cc, err := grpc.NewGRPCClientConn(connectionOptions)
if err != nil {
log.Error().Err(err).Msg("databroker: failed to create gRPC connection to data broker")
log.Error(context.TODO()).Err(err).Msg("databroker: failed to create gRPC connection to data broker")
return
}
@ -237,7 +239,7 @@ func (s *syncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64,
var cfgpb configpb.Config
err := record.GetData().UnmarshalTo(&cfgpb)
if err != nil {
log.Warn().Err(err).Msg("databroker: error decoding config")
log.Warn(ctx).Err(err).Msg("databroker: error decoding config")
delete(s.src.dbConfigs, record.GetId())
continue
}
@ -246,5 +248,5 @@ func (s *syncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64,
}
s.src.mu.Unlock()
s.src.rebuild(false)
s.src.rebuild(ctx, firstTime(false))
}

View file

@ -37,9 +37,9 @@ func TestConfigSource(t *testing.T) {
base.InsecureServer = true
base.GRPCInsecure = true
src := NewConfigSource(config.NewStaticSource(&config.Config{
src := NewConfigSource(ctx, config.NewStaticSource(&config.Config{
Options: base,
}), func(cfg *config.Config) {
}), func(_ context.Context, cfg *config.Config) {
cfgs <- cfg
})
cfgs <- src.GetConfig()

View file

@ -10,7 +10,6 @@ import (
"sync"
"github.com/google/go-cmp/cmp"
"github.com/rs/zerolog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/anypb"
@ -35,7 +34,6 @@ const (
// Server implements the databroker service using an in memory database.
type Server struct {
cfg *serverConfig
log zerolog.Logger
mu sync.RWMutex
version uint64
@ -44,33 +42,31 @@ type Server struct {
// New creates a new server.
func New(options ...ServerOption) *Server {
srv := &Server{
log: log.With().Str("service", "databroker").Logger(),
}
srv := &Server{}
srv.UpdateConfig(options...)
return srv
}
func (srv *Server) initVersion() {
func (srv *Server) initVersion(ctx context.Context) {
db, _, err := srv.getBackendLocked()
if err != nil {
log.Error().Err(err).Msg("failed to init server version")
log.Error(ctx).Err(err).Msg("failed to init server version")
return
}
// Get version from storage first.
r, err := db.Get(context.Background(), recordTypeServerVersion, serverVersionKey)
r, err := db.Get(ctx, recordTypeServerVersion, serverVersionKey)
switch {
case err == nil:
var sv wrapperspb.UInt64Value
if err := r.GetData().UnmarshalTo(&sv); err == nil {
srv.log.Debug().Uint64("server_version", sv.Value).Msg("got db version from Backend")
log.Debug(ctx).Uint64("server_version", sv.Value).Msg("got db version from Backend")
srv.version = sv.Value
}
return
case errors.Is(err, storage.ErrNotFound): // no server version, so we'll create a new one
case err != nil:
log.Error().Err(err).Msg("failed to retrieve server version")
log.Error(ctx).Err(err).Msg("failed to retrieve server version")
return
}
@ -81,7 +77,7 @@ func (srv *Server) initVersion() {
Id: serverVersionKey,
Data: data,
}); err != nil {
srv.log.Warn().Err(err).Msg("failed to save server version.")
log.Warn(ctx).Err(err).Msg("failed to save server version.")
}
}
@ -90,9 +86,11 @@ func (srv *Server) UpdateConfig(options ...ServerOption) {
srv.mu.Lock()
defer srv.mu.Unlock()
ctx := context.TODO()
cfg := newServerConfig(options...)
if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) {
log.Debug().Msg("databroker: no changes detected, re-using existing DBs")
log.Debug(ctx).Msg("databroker: no changes detected, re-using existing DBs")
return
}
srv.cfg = cfg
@ -100,19 +98,19 @@ func (srv *Server) UpdateConfig(options ...ServerOption) {
if srv.backend != nil {
err := srv.backend.Close()
if err != nil {
log.Error().Err(err).Msg("databroker: error closing backend")
log.Error(ctx).Err(err).Msg("databroker: error closing backend")
}
srv.backend = nil
}
srv.initVersion()
srv.initVersion(ctx)
}
// Get gets a record from the in-memory list.
func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databroker.GetResponse, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.Get")
defer span.End()
srv.log.Info().
log.Info(ctx).
Str("peer", grpcutil.GetPeerAddr(ctx)).
Str("type", req.GetType()).
Str("id", req.GetId()).
@ -141,7 +139,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*databroker.QueryResponse, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.Query")
defer span.End()
srv.log.Info().
log.Info(ctx).
Str("peer", grpcutil.GetPeerAddr(ctx)).
Str("type", req.GetType()).
Str("query", req.GetQuery()).
@ -185,7 +183,7 @@ func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databr
defer span.End()
record := req.GetRecord()
srv.log.Info().
log.Info(ctx).
Str("peer", grpcutil.GetPeerAddr(ctx)).
Str("type", record.GetType()).
Str("id", record.GetId()).
@ -208,7 +206,7 @@ func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databr
func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error {
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.Sync")
defer span.End()
srv.log.Info().
log.Info(stream.Context()).
Str("peer", grpcutil.GetPeerAddr(stream.Context())).
Uint64("server_version", req.GetServerVersion()).
Uint64("record_version", req.GetRecordVersion()).
@ -251,7 +249,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error {
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.SyncLatest")
defer span.End()
srv.log.Info().
log.Info(stream.Context()).
Str("peer", grpcutil.GetPeerAddr(stream.Context())).
Str("type", req.GetType()).
Msg("sync latest")
@ -333,9 +331,10 @@ func (srv *Server) getBackendLocked() (backend storage.Backend, version uint64,
}
func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
ctx := context.Background()
caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile)
if err != nil {
log.Warn().Err(err).Msg("failed to read databroker CA file")
log.Warn(ctx).Err(err).Msg("failed to read databroker CA file")
}
tlsConfig := &tls.Config{
RootCAs: caCertPool,
@ -348,10 +347,10 @@ func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
switch srv.cfg.storageType {
case config.StorageInMemoryName:
srv.log.Info().Msg("using in-memory store")
log.Info(ctx).Msg("using in-memory store")
return inmemory.New(), nil
case config.StorageRedisName:
srv.log.Info().Msg("using redis store")
log.Info(ctx).Msg("using redis store")
backend, err = redis.New(
srv.cfg.storageConnectionString,
redis.WithTLSConfig(tlsConfig),

View file

@ -10,7 +10,6 @@ import (
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
)
@ -19,7 +18,6 @@ func newServer(cfg *serverConfig) *Server {
return &Server{
version: 11,
cfg: cfg,
log: log.With().Str("service", "databroker").Logger(),
}
}

View file

@ -69,19 +69,25 @@ func getConfig(options ...Option) *config {
// The Provider retrieves users and groups from gitlab.
type Provider struct {
cfg *config
log zerolog.Logger
}
// New creates a new Provider.
func New(options ...Option) *Provider {
return &Provider{
cfg: getConfig(options...),
log: log.With().Str("service", "directory").Str("provider", "gitlab").Logger(),
}
}
func withLog(ctx context.Context) context.Context {
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "directory").Str("provider", "gitlab")
})
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
ctx = withLog(ctx)
du := &directory.User{
Id: userID,
}
@ -107,11 +113,13 @@ func (p *Provider) User(ctx context.Context, userID, accessToken string) (*direc
// UserGroups gets the directory user groups for gitlab.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
ctx = withLog(ctx)
if p.cfg.serviceAccount == nil {
return nil, nil, fmt.Errorf("gitlab: service account not defined")
}
p.log.Info().Msg("getting user groups")
log.Info(ctx).Msg("getting user groups")
groups, err := p.listGroups(ctx, "")
if err != nil {

View file

@ -93,7 +93,6 @@ func getConfig(options ...Option) *config {
// A Provider is an Okta user group directory provider.
type Provider struct {
cfg *config
log zerolog.Logger
lastUpdated *time.Time
groups map[string]*directory.Group
}
@ -102,13 +101,20 @@ type Provider struct {
func New(options ...Option) *Provider {
return &Provider{
cfg: getConfig(options...),
log: log.With().Str("service", "directory").Str("provider", "okta").Logger(),
groups: make(map[string]*directory.Group),
}
}
func withLog(ctx context.Context) context.Context {
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "directory").Str("provider", "okta")
})
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
ctx = withLog(ctx)
if p.cfg.serviceAccount == nil {
return nil, ErrServiceAccountNotDefined
}
@ -139,11 +145,13 @@ func (p *Provider) User(ctx context.Context, userID, accessToken string) (*direc
// UserGroups fetches the groups of which the user is a member
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
ctx = withLog(ctx)
if p.cfg.serviceAccount == nil {
return nil, nil, ErrServiceAccountNotDefined
}
p.log.Info().Msg("getting user groups")
log.Info(ctx).Msg("getting user groups")
if p.cfg.providerURL == nil {
return nil, nil, ErrProviderURLNotDefined
@ -164,7 +172,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
// the cached lookup and the local groups list
var apiErr *APIError
if errors.As(err, &apiErr) && apiErr.HTTPStatusCode == http.StatusNotFound {
log.Debug().Str("group", group.Id).Msg("okta: group was removed")
log.Debug(ctx).Str("group", group.Id).Msg("okta: group was removed")
delete(p.groups, group.Id)
groups = append(groups[:i], groups[i+1:]...)
i--

View file

@ -78,7 +78,6 @@ func getConfig(options ...Option) *config {
// The Provider retrieves users and groups from onelogin.
type Provider struct {
cfg *config
log zerolog.Logger
mu sync.RWMutex
token *oauth2.Token
@ -89,10 +88,15 @@ func New(options ...Option) *Provider {
cfg := getConfig(options...)
return &Provider{
cfg: cfg,
log: log.With().Str("service", "directory").Str("provider", "onelogin").Logger(),
}
}
func withLog(ctx context.Context) context.Context {
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "directory").Str("provider", "onelogin")
})
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
if p.cfg.serviceAccount == nil {
@ -102,6 +106,8 @@ func (p *Provider) User(ctx context.Context, userID, accessToken string) (*direc
Id: userID,
}
ctx = withLog(ctx)
token, err := p.getToken(ctx)
if err != nil {
return nil, err
@ -124,7 +130,9 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
return nil, nil, fmt.Errorf("onelogin: service account not defined")
}
p.log.Info().Msg("getting user groups")
ctx = withLog(ctx)
log.Info(ctx).Msg("getting user groups")
token, err := p.getToken(ctx)
if err != nil {
@ -252,7 +260,7 @@ func (p *Provider) apiGet(ctx context.Context, accessToken string, uri string, o
return "", err
}
p.log.Info().
log.Info(ctx).
Str("url", uri).
Interface("result", result).
Msg("api request")

View file

@ -50,8 +50,9 @@ func GetProvider(options Options) (provider Provider) {
globalProvider.Lock()
defer globalProvider.Unlock()
ctx := context.TODO()
if globalProvider.provider != nil && cmp.Equal(globalProvider.options, options) {
log.Debug().Str("provider", options.Provider).Msg("directory: no change detected, reusing existing directory provider")
log.Debug(ctx).Str("provider", options.Provider).Msg("directory: no change detected, reusing existing directory provider")
return globalProvider.provider
}
defer func() {
@ -72,7 +73,7 @@ func GetProvider(options Options) (provider Provider) {
auth0.WithDomain(options.ProviderURL),
auth0.WithServiceAccount(serviceAccount))
}
log.Warn().
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
@ -82,7 +83,7 @@ func GetProvider(options Options) (provider Provider) {
if err == nil {
return azure.New(azure.WithServiceAccount(serviceAccount))
}
log.Warn().
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
@ -92,7 +93,7 @@ func GetProvider(options Options) (provider Provider) {
if err == nil {
return github.New(github.WithServiceAccount(serviceAccount))
}
log.Warn().
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
@ -107,7 +108,7 @@ func GetProvider(options Options) (provider Provider) {
gitlab.WithURL(providerURL),
gitlab.WithServiceAccount(serviceAccount))
}
log.Warn().
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
@ -117,7 +118,7 @@ func GetProvider(options Options) (provider Provider) {
if err == nil {
return google.New(google.WithServiceAccount(serviceAccount))
}
log.Warn().
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
@ -129,7 +130,7 @@ func GetProvider(options Options) (provider Provider) {
okta.WithProviderURL(providerURL),
okta.WithServiceAccount(serviceAccount))
}
log.Warn().
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
@ -139,7 +140,7 @@ func GetProvider(options Options) (provider Provider) {
if err == nil {
return onelogin.New(onelogin.WithServiceAccount(serviceAccount))
}
log.Warn().
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
@ -151,14 +152,14 @@ func GetProvider(options Options) (provider Provider) {
ping.WithProviderURL(providerURL),
ping.WithServiceAccount(serviceAccount))
}
log.Warn().
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
Msg("invalid service account for ping directory provider")
}
log.Warn().
log.Warn(ctx).
Str("provider", options.Provider).
Msg("no directory provider implementation found, disabling support for groups")
return nullProvider{}

View file

@ -67,7 +67,7 @@ type Server struct {
}
// NewServer creates a new server with traffic routed by envoy.
func NewServer(src config.Source, grpcPort, httpPort string, builder *envoyconfig.Builder) (*Server, error) {
func NewServer(ctx context.Context, src config.Source, grpcPort, httpPort string, builder *envoyconfig.Builder) (*Server, error) {
wd := filepath.Join(os.TempDir(), workingDirectoryName)
err := os.MkdirAll(wd, embeddedEnvoyPermissions)
if err != nil {
@ -76,7 +76,7 @@ func NewServer(src config.Source, grpcPort, httpPort string, builder *envoyconfi
envoyPath, err := extractEmbeddedEnvoy()
if err != nil {
log.Warn().Err(err).Send()
log.Warn(ctx).Err(err).Send()
envoyPath = "envoy"
}
@ -98,7 +98,7 @@ func NewServer(src config.Source, grpcPort, httpPort string, builder *envoyconfi
return nil, fmt.Errorf("invalid envoy binary, expected %s but got %s", Checksum, s)
}
} else {
log.Info().Msg("no checksum defined, envoy binary will not be verified!")
log.Info(ctx).Msg("no checksum defined, envoy binary will not be verified!")
}
srv := &Server{
@ -108,12 +108,12 @@ func NewServer(src config.Source, grpcPort, httpPort string, builder *envoyconfi
httpPort: httpPort,
envoyPath: envoyPath,
}
go srv.runProcessCollector()
go srv.runProcessCollector(ctx)
src.OnConfigChange(srv.onConfigChange)
srv.onConfigChange(src.GetConfig())
src.OnConfigChange(ctx, srv.onConfigChange)
srv.onConfigChange(ctx, src.GetConfig())
log.Info().
log.Info(ctx).
Str("path", envoyPath).
Str("checksum", Checksum).
Msg("running envoy")
@ -130,7 +130,7 @@ func (srv *Server) Close() error {
if srv.cmd != nil && srv.cmd.Process != nil {
err = srv.cmd.Process.Kill()
if err != nil {
log.Error().Err(err).Str("service", "envoy").Msg("envoy: failed to kill process on close")
log.Error(context.TODO()).Err(err).Str("service", "envoy").Msg("envoy: failed to kill process on close")
}
srv.cmd = nil
}
@ -138,17 +138,17 @@ func (srv *Server) Close() error {
return err
}
func (srv *Server) onConfigChange(cfg *config.Config) {
srv.update(cfg)
func (srv *Server) onConfigChange(ctx context.Context, cfg *config.Config) {
srv.update(ctx, cfg)
}
func (srv *Server) update(cfg *config.Config) {
func (srv *Server) update(ctx context.Context, cfg *config.Config) {
srv.mu.Lock()
defer srv.mu.Unlock()
tracingOptions, err := config.NewTracingOptions(cfg.Options)
if err != nil {
log.Error().Err(err).Str("service", "envoy").Msg("invalid tracing config")
log.Error(ctx).Err(err).Str("service", "envoy").Msg("invalid tracing config")
return
}
@ -159,24 +159,24 @@ func (srv *Server) update(cfg *config.Config) {
}
if cmp.Equal(srv.options, options, cmp.AllowUnexported(serverOptions{})) {
log.Debug().Str("service", "envoy").Msg("envoy: no config changes detected")
log.Debug(ctx).Str("service", "envoy").Msg("envoy: no config changes detected")
return
}
srv.options = options
if err := srv.writeConfig(cfg); err != nil {
log.Error().Err(err).Str("service", "envoy").Msg("envoy: failed to write envoy config")
if err := srv.writeConfig(ctx, cfg); err != nil {
log.Error(ctx).Err(err).Str("service", "envoy").Msg("envoy: failed to write envoy config")
return
}
log.Info().Msg("envoy: starting envoy process")
if err := srv.run(); err != nil {
log.Error().Err(err).Str("service", "envoy").Msg("envoy: failed to run envoy process")
log.Info(ctx).Msg("envoy: starting envoy process")
if err := srv.run(ctx); err != nil {
log.Error(ctx).Err(err).Str("service", "envoy").Msg("envoy: failed to run envoy process")
return
}
}
func (srv *Server) run() error {
func (srv *Server) run(ctx context.Context) error {
args := []string{
"-c", configFileName,
"--log-level", srv.options.logLevel,
@ -198,13 +198,13 @@ func (srv *Server) run() error {
if err != nil {
return fmt.Errorf("error creating stderr pipe for envoy: %w", err)
}
go srv.handleLogs(stderr)
go srv.handleLogs(ctx, stderr)
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("error creating stderr pipe for envoy: %w", err)
}
go srv.handleLogs(stdout)
go srv.handleLogs(ctx, stdout)
// make sure envoy is killed if we're killed
cmd.SysProcAttr = sysProcAttr
@ -216,10 +216,10 @@ func (srv *Server) run() error {
// release the previous process so we can hot-reload
if srv.cmd != nil && srv.cmd.Process != nil {
log.Info().Msg("envoy: releasing envoy process for hot-reload")
log.Info(ctx).Msg("envoy: releasing envoy process for hot-reload")
err := srv.cmd.Process.Release()
if err != nil {
log.Warn().Err(err).Str("service", "envoy").Msg("envoy: failed to release envoy process for hot-reload")
log.Warn(ctx).Err(err).Str("service", "envoy").Msg("envoy: failed to release envoy process for hot-reload")
}
}
srv.cmd = cmd
@ -227,14 +227,14 @@ func (srv *Server) run() error {
return nil
}
func (srv *Server) writeConfig(cfg *config.Config) error {
func (srv *Server) writeConfig(ctx context.Context, cfg *config.Config) error {
confBytes, err := srv.buildBootstrapConfig(cfg)
if err != nil {
return err
}
cfgPath := filepath.Join(srv.wd, configFileName)
log.Debug().Str("service", "envoy").Str("location", cfgPath).Msg("wrote config file to location")
log.Debug(ctx).Str("service", "envoy").Str("location", cfgPath).Msg("wrote config file to location")
return atomic.WriteFile(cfgPath, bytes.NewReader(confBytes))
}
@ -313,9 +313,10 @@ func (srv *Server) parseLog(line string) (name string, logLevel string, msg stri
return
}
func (srv *Server) handleLogs(rc io.ReadCloser) {
func (srv *Server) handleLogs(ctx context.Context, rc io.ReadCloser) {
defer rc.Close()
l := log.With().Str("service", "envoy").Logger()
bo := backoff.NewExponentialBackOff()
s := bufio.NewReader(rc)
@ -325,7 +326,7 @@ func (srv *Server) handleLogs(rc io.ReadCloser) {
if errors.Is(err, io.EOF) || errors.Is(err, os.ErrClosed) {
break
}
log.Error().Err(err).Msg("failed to read log")
log.Error(ctx).Err(err).Msg("failed to read log")
time.Sleep(bo.NextBackOff())
continue
}
@ -336,7 +337,8 @@ func (srv *Server) handleLogs(rc io.ReadCloser) {
if name == "" {
name = "envoy"
}
lvl := zerolog.DebugLevel
lvl := zerolog.ErrorLevel
if x, err := zerolog.ParseLevel(logLevel); err == nil {
lvl = x
}
@ -354,14 +356,13 @@ func (srv *Server) handleLogs(rc io.ReadCloser) {
continue
}
log.WithLevel(lvl).
Str("service", "envoy").
l.WithLevel(lvl).
Str("name", name).
Msg(msg)
}
}
func (srv *Server) runProcessCollector() {
func (srv *Server) runProcessCollector(ctx context.Context) {
// macos is not supported
if runtime.GOOS != "linux" {
return
@ -369,7 +370,7 @@ func (srv *Server) runProcessCollector() {
pc := metrics.NewProcessCollector("envoy")
if err := view.Register(pc.Views()...); err != nil {
log.Error().Err(err).Msg("failed to register envoy process metric views")
log.Error(ctx).Err(err).Msg("failed to register envoy process metric views")
}
const collectInterval = time.Second * 10
@ -387,7 +388,7 @@ func (srv *Server) runProcessCollector() {
if pid > 0 {
err := pc.Measure(context.Background(), pid)
if err != nil {
log.Error().Err(err).Msg("failed to measure envoy process metrics")
log.Error(ctx).Err(err).Msg("failed to measure envoy process metrics")
}
}
}

View file

@ -1,6 +1,7 @@
package envoy
import (
"context"
"io/ioutil"
"regexp"
"strings"
@ -43,6 +44,6 @@ func Benchmark_handleLogs(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
srv.handleLogs(rc)
srv.handleLogs(context.Background(), rc)
}
}

View file

@ -1,9 +1,11 @@
package fileutil
import (
"context"
"sync"
"github.com/rjeczalik/notify"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/signal"
@ -29,6 +31,10 @@ func (watcher *Watcher) Add(filePath string) {
watcher.mu.Lock()
defer watcher.mu.Unlock()
ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context {
return c.Str("watch_file", filePath)
})
// already watching
if _, ok := watcher.filePaths[filePath]; ok {
return
@ -37,18 +43,18 @@ func (watcher *Watcher) Add(filePath string) {
ch := make(chan notify.EventInfo, 1)
go func() {
for evt := range ch {
log.Info().Str("path", evt.Path()).Str("event", evt.Event().String()).Msg("filemgr: detected file change")
watcher.Signal.Broadcast()
log.Info(ctx).Str("event", evt.Event().String()).Msg("filemgr: detected file change")
watcher.Signal.Broadcast(ctx)
}
}()
err := notify.Watch(filePath, ch, notify.All)
if err != nil {
log.Error().Err(err).Str("path", filePath).Msg("filemgr: error watching file path")
log.Error(ctx).Err(err).Msg("filemgr: error watching file path")
notify.Stop(ch)
close(ch)
return
}
log.Debug().Str("path", filePath).Msg("filemgr: watching file for changes")
log.Debug(ctx).Msg("filemgr: watching file for changes")
watcher.filePaths[filePath] = ch
}

View file

@ -2,6 +2,7 @@
package reproxy
import (
"context"
"encoding/base64"
"errors"
"math/rand"
@ -120,7 +121,7 @@ func (h *Handler) Middleware(next http.Handler) http.Handler {
}
// Update updates the handler with new configuration.
func (h *Handler) Update(cfg *config.Config) {
func (h *Handler) Update(ctx context.Context, cfg *config.Config) {
h.mu.Lock()
defer h.mu.Unlock()
@ -130,7 +131,7 @@ func (h *Handler) Update(cfg *config.Config) {
for i, p := range cfg.Options.Policies {
id, err := p.RouteID()
if err != nil {
log.Warn().Err(err).Msg("reproxy: error getting route id")
log.Warn(ctx).Err(err).Msg("reproxy: error getting route id")
continue
}
h.policies[id] = &cfg.Options.Policies[i]

View file

@ -1,6 +1,7 @@
package reproxy
import (
"context"
"io"
"io/ioutil"
"net/http"
@ -58,7 +59,7 @@ func TestMiddleware(t *testing.T) {
}},
},
}
h.Update(cfg)
h.Update(context.Background(), cfg)
policyID, _ := cfg.Options.Policies[0].RouteID()

View file

@ -97,8 +97,8 @@ func Shutdown(srv *http.Server) {
rec := <-sigint
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
log.Info().Str("signal", rec.String()).Msg("internal/httputil: shutting down servers")
log.Info(context.TODO()).Str("signal", rec.String()).Msg("internal/httputil: shutting down servers")
if err := srv.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("internal/httputil: shutdown failed")
log.Error(context.TODO()).Err(err).Msg("internal/httputil: shutdown failed")
}
}

View file

@ -47,7 +47,6 @@ type (
// A Manager refreshes identity information using session and user data.
type Manager struct {
cfg *atomicConfig
log zerolog.Logger
sessionScheduler *scheduler.Scheduler
userScheduler *scheduler.Scheduler
@ -68,7 +67,6 @@ func New(
) *Manager {
mgr := &Manager{
cfg: newAtomicConfig(newConfig()),
log: log.With().Str("service", "identity_manager").Logger(),
sessionScheduler: scheduler.New(),
userScheduler: scheduler.New(),
@ -79,6 +77,12 @@ func New(
return mgr
}
func withLog(ctx context.Context) context.Context {
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "identity_manager")
})
}
// UpdateConfig updates the manager with the new options.
func (mgr *Manager) UpdateConfig(options ...Option) {
mgr.cfg.Store(newConfig(options...))
@ -86,10 +90,11 @@ func (mgr *Manager) UpdateConfig(options ...Option) {
// Run runs the manager. This method blocks until an error occurs or the given context is canceled.
func (mgr *Manager) Run(ctx context.Context) error {
ctx = withLog(ctx)
update := make(chan updateRecordsMessage, 1)
clear := make(chan struct{}, 1)
syncer := newDataBrokerSyncer(mgr.cfg, mgr.log, update, clear)
syncer := newDataBrokerSyncer(ctx, mgr.cfg, update, clear)
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
@ -119,7 +124,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
mgr.onUpdateRecords(ctx, msg)
}
mgr.log.Info().
log.Info(ctx).
Int("directory_groups", len(mgr.directoryGroups)).
Int("directory_users", len(mgr.directoryUsers)).
Int("sessions", mgr.sessions.Len()).
@ -196,7 +201,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
}
func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) {
mgr.log.Info().Msg("refreshing directory users")
log.Info(ctx).Msg("refreshing directory users")
ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.Load().groupRefreshTimeout)
defer clearTimeout()
@ -208,7 +213,7 @@ func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) {
msg += ". You may need to increase the identity provider directory timeout setting"
msg += "(https://www.pomerium.io/reference/#identity-provider-refresh-directory-settings)"
}
mgr.log.Warn().Err(err).Msg(msg)
log.Warn(ctx).Err(err).Msg(msg)
return
}
@ -232,7 +237,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
id := newDG.GetId()
any, err := anypb.New(newDG)
if err != nil {
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
log.Warn(ctx).Err(err).Msg("failed to marshal directory group")
return
}
eg.Go(func() error {
@ -262,7 +267,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
id := curDG.GetId()
any, err := anypb.New(curDG)
if err != nil {
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
log.Warn(ctx).Err(err).Msg("failed to marshal directory group")
return
}
eg.Go(func() error {
@ -287,7 +292,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
}
if err := eg.Wait(); err != nil {
mgr.log.Warn().Err(err).Msg("manager: failed to merge groups")
log.Warn(ctx).Err(err).Msg("manager: failed to merge groups")
}
}
@ -305,7 +310,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
id := newDU.GetId()
any, err := anypb.New(newDU)
if err != nil {
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
log.Warn(ctx).Err(err).Msg("failed to marshal directory user")
return
}
eg.Go(func() error {
@ -335,7 +340,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
id := curDU.GetId()
any, err := anypb.New(curDU)
if err != nil {
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
log.Warn(ctx).Err(err).Msg("failed to marshal directory user")
return
}
eg.Go(func() error {
@ -361,19 +366,19 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
}
if err := eg.Wait(); err != nil {
mgr.log.Warn().Err(err).Msg("manager: failed to merge users")
log.Warn(ctx).Err(err).Msg("manager: failed to merge users")
}
}
func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) {
mgr.log.Info().
log.Info(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Msg("refreshing session")
s, ok := mgr.sessions.Get(userID, sessionID)
if !ok {
mgr.log.Warn().
log.Warn(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Msg("no session found for refresh")
@ -382,7 +387,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
expiry := s.GetExpiresAt().AsTime()
if !expiry.After(time.Now()) {
mgr.log.Info().
log.Info(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Msg("deleting expired session")
@ -391,7 +396,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
}
if s.Session == nil || s.Session.OauthToken == nil {
mgr.log.Warn().
log.Warn(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Msg("no session oauth2 token found for refresh")
@ -400,13 +405,13 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
newToken, err := mgr.cfg.Load().authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s)
if isTemporaryError(err) {
mgr.log.Error().Err(err).
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to refresh oauth2 token")
return
} else if err != nil {
mgr.log.Error().Err(err).
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to refresh oauth2 token, deleting session")
@ -417,13 +422,13 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
err = mgr.cfg.Load().authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &s)
if isTemporaryError(err) {
mgr.log.Error().Err(err).
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user info")
return
} else if err != nil {
mgr.log.Error().Err(err).
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user info, deleting session")
@ -433,7 +438,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
res, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session)
if err != nil {
mgr.log.Error().Err(err).
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update session")
@ -444,13 +449,13 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
}
func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
mgr.log.Info().
log.Info(ctx).
Str("user_id", userID).
Msg("refreshing user")
u, ok := mgr.users.Get(userID)
if !ok {
mgr.log.Warn().
log.Warn(ctx).
Str("user_id", userID).
Msg("no user found for refresh")
return
@ -460,7 +465,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
for _, s := range mgr.sessions.GetSessionsForUser(userID) {
if s.Session == nil || s.Session.OauthToken == nil {
mgr.log.Warn().
log.Warn(ctx).
Str("user_id", userID).
Msg("no session oauth2 token found for refresh")
continue
@ -468,13 +473,13 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
err := mgr.cfg.Load().authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u)
if isTemporaryError(err) {
mgr.log.Error().Err(err).
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user info")
return
} else if err != nil {
mgr.log.Error().Err(err).
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user info, deleting session")
@ -484,7 +489,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
record, err := user.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
if err != nil {
mgr.log.Error().Err(err).
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user")
@ -502,7 +507,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
var pbDirectoryGroup directory.Group
err := record.GetData().UnmarshalTo(&pbDirectoryGroup)
if err != nil {
mgr.log.Warn().Msgf("error unmarshaling directory group: %s", err)
log.Warn(ctx).Msgf("error unmarshaling directory group: %s", err)
continue
}
mgr.onUpdateDirectoryGroup(ctx, &pbDirectoryGroup)
@ -510,7 +515,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
var pbDirectoryUser directory.User
err := record.GetData().UnmarshalTo(&pbDirectoryUser)
if err != nil {
mgr.log.Warn().Msgf("error unmarshaling directory user: %s", err)
log.Warn(ctx).Msgf("error unmarshaling directory user: %s", err)
continue
}
mgr.onUpdateDirectoryUser(ctx, &pbDirectoryUser)
@ -518,7 +523,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
var pbSession session.Session
err := record.GetData().UnmarshalTo(&pbSession)
if err != nil {
mgr.log.Warn().Msgf("error unmarshaling session: %s", err)
log.Warn(ctx).Msgf("error unmarshaling session: %s", err)
continue
}
mgr.onUpdateSession(ctx, record, &pbSession)
@ -526,7 +531,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
var pbUser user.User
err := record.GetData().UnmarshalTo(&pbUser)
if err != nil {
mgr.log.Warn().Msgf("error unmarshaling user: %s", err)
log.Warn(ctx).Msgf("error unmarshaling user: %s", err)
continue
}
}
@ -578,7 +583,7 @@ func (mgr *Manager) onUpdateDirectoryGroup(_ context.Context, pbDirectoryGroup *
func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Session) {
err := session.Delete(ctx, mgr.cfg.Load().dataBrokerClient, pbSession.GetId())
if err != nil {
mgr.log.Error().Err(err).
log.Error(ctx).Err(err).
Str("session_id", pbSession.GetId()).
Msg("failed to delete session")
}

View file

@ -3,14 +3,11 @@ package manager
import (
"context"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
type dataBrokerSyncer struct {
cfg *atomicConfig
log zerolog.Logger
update chan<- updateRecordsMessage
clear chan<- struct{}
@ -19,14 +16,13 @@ type dataBrokerSyncer struct {
}
func newDataBrokerSyncer(
ctx context.Context,
cfg *atomicConfig,
log zerolog.Logger,
update chan<- updateRecordsMessage,
clear chan<- struct{},
) *dataBrokerSyncer {
syncer := &dataBrokerSyncer{
cfg: cfg,
log: log,
update: update,
clear: clear,

View file

@ -142,7 +142,7 @@ func (p *Provider) userEmail(ctx context.Context, t *oauth2.Token, v interface{}
Email string `json:"email"`
Verified bool `json:"email_verified"`
}
log.Debug().Interface("emails", response).Msg("github: user emails")
log.Debug(ctx).Interface("emails", response).Msg("github: user emails")
for _, email := range response {
if email.Primary && email.Verified {
out.Email = email.Email

View file

@ -88,35 +88,55 @@ func With() zerolog.Context {
}
// Level creates a child logger with the minimum accepted level set to level.
func Level(level zerolog.Level) zerolog.Logger {
return Logger().Level(level)
func Level(ctx context.Context, level zerolog.Level) *zerolog.Logger {
l := contextLogger(ctx).Level(level)
return &l
}
// Debug starts a new message with debug level.
//
// You must call Msg on the returned event in order to send the event.
func Debug() *zerolog.Event {
return Logger().Debug()
func Debug(ctx context.Context) *zerolog.Event {
return contextLogger(ctx).Debug()
}
// Info starts a new message with info level.
//
// You must call Msg on the returned event in order to send the event.
func Info() *zerolog.Event {
return Logger().Info()
func Info(ctx context.Context) *zerolog.Event {
return contextLogger(ctx).Info()
}
// Warn starts a new message with warn level.
//
// You must call Msg on the returned event in order to send the event.
func Warn() *zerolog.Event {
return Logger().Warn()
func Warn(ctx context.Context) *zerolog.Event {
return contextLogger(ctx).Warn()
}
func contextLogger(ctx context.Context) *zerolog.Logger {
global := Logger()
if global.GetLevel() == zerolog.Disabled {
return global
}
l := zerolog.Ctx(ctx)
if l.GetLevel() == zerolog.Disabled { // no logger associated with context
return global
}
return l
}
// WithContext returns a context that has an associated logger and extra fields set via update
func WithContext(ctx context.Context, update func(c zerolog.Context) zerolog.Context) context.Context {
l := contextLogger(ctx).With().Logger()
l.UpdateContext(update)
return l.WithContext(ctx)
}
// Error starts a new message with error level.
//
// You must call Msg on the returned event in order to send the event.
func Error() *zerolog.Event {
func Error(ctx context.Context) *zerolog.Event {
return Logger().Error()
}
@ -136,18 +156,11 @@ func Panic() *zerolog.Event {
return Logger().Panic()
}
// WithLevel starts a new message with level.
//
// You must call Msg on the returned event in order to send the event.
func WithLevel(level zerolog.Level) *zerolog.Event {
return Logger().WithLevel(level)
}
// Log starts a new message with no level. Setting zerolog.GlobalLevel to
// zerolog.Disabled will still disable events produced by this method.
//
// You must call Msg on the returned event in order to send the event.
func Log() *zerolog.Event {
func Log(ctx context.Context) *zerolog.Event {
return Logger().Log()
}

View file

@ -1,6 +1,7 @@
package log_test
import (
"context"
"errors"
"flag"
"time"
@ -54,7 +55,7 @@ func ExamplePrintf() {
// Example of a log with no particular "level"
func ExampleLog() {
setup()
log.Log().Msg("hello world")
log.Log(context.Background()).Msg("hello world")
// Output: {"time":1199811905,"message":"hello world"}
}
@ -62,7 +63,7 @@ func ExampleLog() {
// Example of a log at a particular "level" (in this case, "debug")
func ExampleDebug() {
setup()
log.Debug().Msg("hello world")
log.Debug(context.Background()).Msg("hello world")
// Output: {"level":"debug","time":1199811905,"message":"hello world"}
}
@ -70,7 +71,7 @@ func ExampleDebug() {
// Example of a log at a particular "level" (in this case, "info")
func ExampleInfo() {
setup()
log.Info().Msg("hello world")
log.Info(context.Background()).Msg("hello world")
// Output: {"level":"info","time":1199811905,"message":"hello world"}
}
@ -78,7 +79,7 @@ func ExampleInfo() {
// Example of a log at a particular "level" (in this case, "warn")
func ExampleWarn() {
setup()
log.Warn().Msg("hello world")
log.Warn(context.Background()).Msg("hello world")
// Output: {"level":"warn","time":1199811905,"message":"hello world"}
}
@ -86,7 +87,7 @@ func ExampleWarn() {
// Example of a log at a particular "level" (in this case, "error")
func ExampleError() {
setup()
log.Error().Msg("hello world")
log.Error(context.Background()).Msg("hello world")
// Output: {"level":"error","time":1199811905,"message":"hello world"}
}
@ -119,10 +120,10 @@ func Example() {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}
log.Debug().Msg("This message appears only when log level set to Debug")
log.Info().Msg("This message appears when log level set to Debug or Info")
log.Debug(context.Background()).Msg("This message appears only when log level set to Debug")
log.Info(context.Background()).Msg("This message appears when log level set to Debug or Info")
if e := log.Debug(); e.Enabled() {
if e := log.Debug(context.Background()); e.Enabled() {
// Compute log output only if enabled.
value := "bar"
e.Str("foo", value).Msg("some debug message")
@ -134,19 +135,19 @@ func Example() {
func ExampleSetLevel() {
setup()
log.SetLevel("info")
log.Debug().Msg("Debug")
log.Info().Msg("Debug or Info")
log.Debug(context.Background()).Msg("Debug")
log.Info(context.Background()).Msg("Debug or Info")
log.SetLevel("warn")
log.Debug().Msg("Debug")
log.Info().Msg("Debug or Info")
log.Warn().Msg("Debug or Info or Warn")
log.Debug(context.Background()).Msg("Debug")
log.Info(context.Background()).Msg("Debug or Info")
log.Warn(context.Background()).Msg("Debug or Info or Warn")
log.SetLevel("error")
log.Debug().Msg("Debug")
log.Info().Msg("Debug or Info")
log.Warn().Msg("Debug or Info or Warn")
log.Error().Msg("Debug or Info or Warn or Error")
log.Debug(context.Background()).Msg("Debug")
log.Info(context.Background()).Msg("Debug or Info")
log.Warn(context.Background()).Msg("Debug or Info or Warn")
log.Error(context.Background()).Msg("Debug or Info or Warn or Error")
log.SetLevel("default-fall-through")
log.Debug().Msg("Debug")
log.Debug(context.Background()).Msg("Debug")
// Output:
// {"level":"info","time":1199811905,"message":"Debug or Info"}
@ -154,3 +155,32 @@ func ExampleSetLevel() {
// {"level":"error","time":1199811905,"message":"Debug or Info or Warn or Error"}
// {"level":"debug","time":1199811905,"message":"Debug"}
}
func ExampleContext() {
setup()
bg := context.Background()
ctx1 := log.WithContext(bg, func(c zerolog.Context) zerolog.Context {
return c.Str("param_one", "one")
})
ctx2 := log.WithContext(ctx1, func(c zerolog.Context) zerolog.Context {
return c.Str("param_two", "two")
})
log.Warn(bg).Str("non_context_param", "value").Msg("background")
log.Warn(ctx1).Str("non_context_param", "value").Msg("first")
log.Warn(ctx2).Str("non_context_param", "value").Msg("second")
for i := 0; i < 10; i++ {
ctx1 = log.WithContext(ctx1, func(c zerolog.Context) zerolog.Context {
return c.Int("counter", i)
})
}
log.Info(ctx1).Str("non_ctx_param", "value").Msg("after counter")
/*
{"level":"warn","ctx":"one","param":"first","time":1199811905,"message":"first"}
{"level":"warn","ctx":"two","param":"second","time":1199811905,"message":"second"}
{"level":"warn","param":"third","time":1199811905,"message":"third"}
*/
}

View file

@ -51,7 +51,7 @@ func (s *inMemoryServer) periodicCheck(ctx context.Context) {
return
case <-time.After(after):
if s.lockAndRmExpired() {
s.onchange.Broadcast()
s.onchange.Broadcast(ctx)
}
}
}
@ -70,7 +70,7 @@ func (s *inMemoryServer) Report(ctx context.Context, req *pb.RegisterRequest) (*
}
if updated {
s.onchange.Broadcast()
s.onchange.Broadcast(ctx)
}
return &pb.RegisterResponse{
@ -171,7 +171,7 @@ func (s *inMemoryServer) Watch(req *pb.ListRequest, srv pb.Registry_WatchServer)
}
}
func (s *inMemoryServer) getServiceUpdates(ctx context.Context, kinds map[pb.ServiceKind]bool, updates chan struct{}) ([]*pb.Service, error) {
func (s *inMemoryServer) getServiceUpdates(ctx context.Context, kinds map[pb.ServiceKind]bool, updates chan context.Context) ([]*pb.Service, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()

View file

@ -23,25 +23,25 @@ type Reporter struct {
}
// OnConfigChange applies configuration changes to the reporter
func (r *Reporter) OnConfigChange(cfg *config.Config) {
func (r *Reporter) OnConfigChange(ctx context.Context, cfg *config.Config) {
if r.cancel != nil {
r.cancel()
}
services, err := getReportedServices(cfg)
if err != nil {
log.Warn().Err(err).Msg("metrics announce to service registry is disabled")
log.Warn(ctx).Err(err).Msg("metrics announce to service registry is disabled")
}
sharedKey, err := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
if err != nil {
log.Error().Err(err).Msg("decoding shared key")
log.Error(ctx).Err(err).Msg("decoding shared key")
return
}
urls, err := cfg.Options.GetDataBrokerURLs()
if err != nil {
log.Error().Err(err).Msg("invalid databroker urls")
log.Error(ctx).Err(err).Msg("invalid databroker urls")
return
}
@ -58,7 +58,7 @@ func (r *Reporter) OnConfigChange(cfg *config.Config) {
SignedJWTKey: sharedKey,
})
if err != nil {
log.Error().Err(err).Msg("connecting to registry")
log.Error(ctx).Err(err).Msg("connecting to registry")
return
}
@ -145,7 +145,7 @@ func runReporter(
after = resp.CallBackAfter.AsDuration()
backoff.Reset()
case <-ctx.Done():
log.Info().Msg("service registry reporter stopping")
log.Info(ctx).Msg("service registry reporter stopping")
return
}
}

View file

@ -2,28 +2,29 @@
package signal
import (
"context"
"sync"
)
// A Signal is used to let multiple listeners know when something happened.
type Signal struct {
mu sync.Mutex
chs map[chan struct{}]struct{}
chs map[chan context.Context]struct{}
}
// New creates a new Signal.
func New() *Signal {
return &Signal{
chs: make(map[chan struct{}]struct{}),
chs: make(map[chan context.Context]struct{}),
}
}
// Broadcast signals all the listeners. Broadcast never blocks.
func (s *Signal) Broadcast() {
func (s *Signal) Broadcast(ctx context.Context) {
s.mu.Lock()
for ch := range s.chs {
select {
case ch <- struct{}{}:
case ch <- ctx:
default:
}
}
@ -32,8 +33,8 @@ func (s *Signal) Broadcast() {
// Bind creates a new listening channel bound to the signal. The channel used has a size of 1
// and any given broadcast will signal at least one event, but may signal more than one.
func (s *Signal) Bind() chan struct{} {
ch := make(chan struct{}, 1)
func (s *Signal) Bind() chan context.Context {
ch := make(chan context.Context, 1)
s.mu.Lock()
s.chs[ch] = struct{}{}
s.mu.Unlock()
@ -41,7 +42,7 @@ func (s *Signal) Bind() chan struct{} {
}
// Unbind stops the listening channel bound to the signal.
func (s *Signal) Unbind(ch chan struct{}) {
func (s *Signal) Unbind(ch chan context.Context) {
s.mu.Lock()
delete(s.chs, ch)
s.mu.Unlock()

View file

@ -1,6 +1,7 @@
package tcptunnel
import (
"context"
"crypto/tls"
"github.com/pomerium/pomerium/internal/cliutil"
@ -19,7 +20,7 @@ func getConfig(options ...Option) *config {
if jwtCache, err := cliutil.NewLocalJWTCache(); err == nil {
WithJWTCache(jwtCache)(cfg)
} else {
log.Error().Err(err).Msg("tcptunnel: error creating local JWT cache, using in-memory JWT cache")
log.Error(context.TODO()).Err(err).Msg("tcptunnel: error creating local JWT cache, using in-memory JWT cache")
WithJWTCache(cliutil.NewMemoryJWTCache())(cfg)
}
for _, o := range options {

View file

@ -43,7 +43,7 @@ func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) erro
return err
}
defer func() { _ = li.Close() }()
log.Info().Msg("tcptunnel: listening on " + li.Addr().String())
log.Info(ctx).Msg("tcptunnel: listening on " + li.Addr().String())
go func() {
<-ctx.Done()
@ -62,7 +62,7 @@ func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) erro
}
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
log.Warn().Err(err).Msg("tcptunnel: temporarily failed to accept local connection")
log.Warn(ctx).Err(err).Msg("tcptunnel: temporarily failed to accept local connection")
select {
case <-time.After(bo.NextBackOff()):
case <-ctx.Done():
@ -79,7 +79,7 @@ func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) erro
err := tun.Run(ctx, conn)
if err != nil {
log.Error().Err(err).Msg("tcptunnel: error serving local connection")
log.Error(ctx).Err(err).Msg("tcptunnel: error serving local connection")
}
}()
}
@ -102,7 +102,7 @@ func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter) error {
}
func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string, retryCount int) error {
log.Info().
log.Info(ctx).
Str("dst", tun.cfg.dstHost).
Str("proxy", tun.cfg.proxyHost).
Bool("secure", tun.cfg.tlsConfig != nil).
@ -132,7 +132,7 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string,
}
defer func() {
_ = remote.Close()
log.Info().Msg("tcptunnel: connection closed")
log.Info(ctx).Msg("tcptunnel: connection closed")
}()
if done := ctx.Done(); done != nil {
go func() {
@ -189,7 +189,7 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string,
return fmt.Errorf("tcptunnel: invalid http response code: %d", res.StatusCode)
}
log.Info().Msg("tcptunnel: connection established")
log.Info(ctx).Msg("tcptunnel: connection established")
errc := make(chan error, 2)
go func() {

View file

@ -142,7 +142,7 @@ func GRPCClientInterceptor(service string) grpc.UnaryClientInterceptor {
tag.Upsert(TagKeyGRPCService, rpcService),
)
if tagErr != nil {
log.Warn().Err(tagErr).Str("context", "GRPCClientInterceptor").Msg("telemetry/metrics: failed to create context")
log.Warn(ctx).Err(tagErr).Str("context", "GRPCClientInterceptor").Msg("telemetry/metrics: failed to create context")
return invoker(ctx, method, req, reply, cc, opts...)
}
@ -181,7 +181,7 @@ func (h *GRPCServerMetricsHandler) TagRPC(ctx context.Context, tagInfo *grpcstat
tag.Upsert(TagKeyGRPCService, rpcService),
)
if tagErr != nil {
log.Warn().Err(tagErr).Str("context", "GRPCServerStatsHandler").Msg("telemetry/metrics: failed to create context")
log.Warn(ctx).Err(tagErr).Str("context", "GRPCServerStatsHandler").Msg("telemetry/metrics: failed to create context")
return ctx
}

View file

@ -120,7 +120,7 @@ func HTTPMetricsHandler(getInstallationID func() string, service string) func(ne
tag.Upsert(TagKeyHTTPMethod, r.Method),
)
if tagErr != nil {
log.Warn().Err(tagErr).Str("context", "HTTPMetricsHandler").Msg("telemetry/metrics: failed to create metrics tag")
log.Warn(ctx).Err(tagErr).Str("context", "HTTPMetricsHandler").Msg("telemetry/metrics: failed to create metrics tag")
next.ServeHTTP(w, r)
return
}
@ -148,7 +148,7 @@ func HTTPMetricsRoundTripper(getInstallationID func() string, service string, de
tag.Upsert(TagKeyDestination, destination),
)
if tagErr != nil {
log.Warn().Err(tagErr).Str("context", "HTTPMetricsRoundTripper").Msg("telemetry/metrics: failed to create metrics tag")
log.Warn(ctx).Err(tagErr).Str("context", "HTTPMetricsRoundTripper").Msg("telemetry/metrics: failed to create metrics tag")
return next.RoundTrip(r)
}

View file

@ -104,8 +104,8 @@ func RecordIdentityManagerLastRefresh() {
// SetDBConfigInfo records status, databroker version and error count while parsing
// the configuration from a databroker
func SetDBConfigInfo(service, configID string, version uint64, errCount int64) {
log.Info().
func SetDBConfigInfo(ctx context.Context, service, configID string, version uint64, errCount int64) {
log.Info(ctx).
Str("service", service).
Str("config_id", configID).
Uint64("version", version).
@ -113,14 +113,14 @@ func SetDBConfigInfo(service, configID string, version uint64, errCount int64) {
Msg("set db config info")
if err := stats.RecordWithTags(
context.Background(),
ctx,
[]tag.Mutator{
tag.Insert(TagKeyService, service),
tag.Insert(TagConfigID, configID),
},
configDBVersion.M(int64(version)),
); err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to record config version number")
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to record config version number")
}
if err := stats.RecordWithTags(
@ -131,20 +131,20 @@ func SetDBConfigInfo(service, configID string, version uint64, errCount int64) {
},
configDBErrors.M(errCount),
); err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to record config error count")
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to record config error count")
}
}
// SetDBConfigRejected records that a certain databroker config version has been rejected
func SetDBConfigRejected(service, configID string, version uint64, err error) {
log.Warn().Err(err).Msg("databroker: invalid config detected, ignoring")
SetDBConfigInfo(service, configID, version, -1)
func SetDBConfigRejected(ctx context.Context, service, configID string, version uint64, err error) {
log.Warn(ctx).Err(err).Msg("databroker: invalid config detected, ignoring")
SetDBConfigInfo(ctx, service, configID, version, -1)
}
// SetConfigInfo records the status, checksum and timestamp of a configuration
// reload. You must register InfoViews or the related config views before calling
func SetConfigInfo(service, configName string, checksum uint64, success bool) {
func SetConfigInfo(ctx context.Context, service, configName string, checksum uint64, success bool) {
if success {
registry.setConfigChecksum(service, configName, checksum)
@ -154,7 +154,7 @@ func SetConfigInfo(service, configName string, checksum uint64, success bool) {
[]tag.Mutator{serviceTag},
configLastReload.M(time.Now().Unix()),
); err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to record config checksum timestamp")
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to record config checksum timestamp")
}
if err := stats.RecordWithTags(
@ -162,12 +162,12 @@ func SetConfigInfo(service, configName string, checksum uint64, success bool) {
[]tag.Mutator{serviceTag},
configLastReloadSuccess.M(1),
); err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to record config reload")
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to record config reload")
}
} else {
stats.Record(context.Background(), configLastReloadSuccess.M(0))
}
log.Info().
log.Info(ctx).
Str("service", service).
Str("config", configName).
Str("checksum", fmt.Sprintf("%x", checksum)).

View file

@ -1,6 +1,7 @@
package metrics
import (
"context"
"fmt"
"runtime"
"testing"
@ -28,7 +29,7 @@ func Test_SetConfigInfo(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
view.Unregister(InfoViews...)
view.Register(InfoViews...)
SetConfigInfo("test_service", "test config", 0, tt.success)
SetConfigInfo(context.Background(), "test_service", "test config", 0, tt.success)
testDataRetrieval(ConfigLastReloadView, t, tt.wantLastReload)
testDataRetrieval(ConfigLastReloadSuccessView, t, tt.wantLastReloadSuccess)
@ -55,7 +56,7 @@ func Test_SetDBConfigInfo(t *testing.T) {
t.Run(fmt.Sprintf("version=%d errors=%d", tt.version, tt.errCount), func(t *testing.T) {
view.Unregister(InfoViews...)
view.Register(InfoViews...)
SetDBConfigInfo("test_service", "test_config", tt.version, tt.errCount)
SetDBConfigInfo(context.TODO(), "test_service", "test_config", tt.version, tt.errCount)
testDataRetrieval(ConfigDBVersionView, t, tt.wantVersion)
testDataRetrieval(ConfigDBErrorsView, t, tt.wantErrors)

View file

@ -89,26 +89,26 @@ func newProxyMetricsHandler(exporter *ocprom.Exporter, envoyURL url.URL, install
err := writeMetricsWithInstallationID(w, rec.Body, installationID)
if err != nil {
log.Error().Err(err).Send()
log.Error(r.Context()).Err(err).Send()
return
}
req, err := http.NewRequestWithContext(r.Context(), "GET", envoyURL.String(), nil)
if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to create request for envoy")
log.Error(r.Context()).Err(err).Msg("telemetry/metrics: failed to create request for envoy")
return
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: fail to fetch proxy metrics")
log.Error(r.Context()).Err(err).Msg("telemetry/metrics: fail to fetch proxy metrics")
return
}
defer resp.Body.Close()
err = writeMetricsWithInstallationID(w, resp.Body, installationID)
if err != nil {
log.Error().Err(err).Send()
log.Error(r.Context()).Err(err).Send()
return
}
}

View file

@ -1,6 +1,7 @@
package metrics
import (
"context"
"runtime"
"sync"
@ -34,6 +35,7 @@ func newMetricRegistry() *metricRegistry {
}
func (r *metricRegistry) init() {
ctx := context.TODO()
r.Do(
func() {
r.registry = metric.NewRegistry()
@ -49,7 +51,7 @@ func (r *metricRegistry) init() {
),
)
if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to register build info metric")
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to register build info metric")
}
r.configChecksum, err = r.registry.AddFloat64Gauge(metrics.ConfigChecksumDecimal,
@ -57,7 +59,7 @@ func (r *metricRegistry) init() {
metric.WithLabelKeys(metrics.ServiceLabel, metrics.ConfigLabel),
)
if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to register config checksum metric")
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to register config checksum metric")
}
r.policyCount, err = r.registry.AddInt64DerivedGauge(metrics.PolicyCountTotal,
@ -65,12 +67,12 @@ func (r *metricRegistry) init() {
metric.WithLabelKeys(metrics.ServiceLabel),
)
if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to register policy count metric")
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to register policy count metric")
}
err = registerAutocertMetrics(r.registry)
if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to register autocert metrics")
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to register autocert metrics")
}
})
}
@ -89,7 +91,7 @@ func (r *metricRegistry) setBuildInfo(service, hostname string) {
metricdata.NewLabelValue(hostname),
)
if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to get build info metric")
log.Error(context.TODO()).Err(err).Msg("telemetry/metrics: failed to get build info metric")
}
// This sets our build_info metric to a constant 1 per
@ -103,7 +105,7 @@ func (r *metricRegistry) addPolicyCountCallback(service string, f func() int64)
}
err := r.policyCount.UpsertEntry(f, metricdata.NewLabelValue(service))
if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to get policy count metric")
log.Error(context.TODO()).Err(err).Msg("telemetry/metrics: failed to get policy count metric")
}
}
@ -113,7 +115,7 @@ func (r *metricRegistry) setConfigChecksum(service string, configName string, ch
}
m, err := r.configChecksum.GetEntry(metricdata.NewLabelValue(service), metricdata.NewLabelValue(configName))
if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to get config checksum metric")
log.Error(context.TODO()).Err(err).Msg("telemetry/metrics: failed to get config checksum metric")
}
m.Set(float64(checksum))
}
@ -122,13 +124,13 @@ func (r *metricRegistry) addInt64DerivedGaugeMetric(name, desc, service string,
m, err := r.registry.AddInt64DerivedGauge(name, metric.WithDescription(desc),
metric.WithLabelKeys(metrics.ServiceLabel))
if err != nil {
log.Error().Err(err).Str("service", service).Msg("telemetry/metrics: failed to register metric")
log.Error(context.TODO()).Err(err).Str("service", service).Msg("telemetry/metrics: failed to register metric")
return
}
err = m.UpsertEntry(f, metricdata.NewLabelValue(service))
if err != nil {
log.Error().Err(err).Str("service", service).Msg("telemetry/metrics: failed to update metric")
log.Error(context.TODO()).Err(err).Str("service", service).Msg("telemetry/metrics: failed to update metric")
return
}
}
@ -137,13 +139,13 @@ func (r *metricRegistry) addInt64DerivedCumulativeMetric(name, desc, service str
m, err := r.registry.AddInt64DerivedCumulative(name, metric.WithDescription(desc),
metric.WithLabelKeys(metrics.ServiceLabel))
if err != nil {
log.Error().Err(err).Str("service", service).Msg("telemetry/metrics: failed to register metric")
log.Error(context.TODO()).Err(err).Str("service", service).Msg("telemetry/metrics: failed to register metric")
return
}
err = m.UpsertEntry(f, metricdata.NewLabelValue(service))
if err != nil {
log.Error().Err(err).Str("service", service).Msg("telemetry/metrics: failed to update metric")
log.Error(context.TODO()).Err(err).Str("service", service).Msg("telemetry/metrics: failed to update metric")
return
}
}

View file

@ -57,6 +57,6 @@ func RecordStorageOperation(ctx context.Context, tags *StorageOperationTags, dur
storageOperationDuration.M(duration.Milliseconds()),
)
if err != nil {
log.Warn().Err(err).Msg("internal/telemetry/metrics: failed to record")
log.Warn(ctx).Err(err).Msg("internal/telemetry/metrics: failed to record")
}
}

View file

@ -77,7 +77,7 @@ func RegisterTracing(opts *TracingOptions) (trace.Exporter, error) {
}
trace.ApplyConfig(trace.Config{DefaultSampler: trace.ProbabilitySampler(opts.SampleRate)})
log.Debug().Interface("Opts", opts).Msg("telemetry/trace: exporter created")
log.Debug(context.TODO()).Interface("Opts", opts).Msg("telemetry/trace: exporter created")
return exporter, nil
}

View file

@ -1,6 +1,7 @@
package cryptutil
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
@ -14,9 +15,10 @@ import (
// GetCertPool gets a cert pool for the given CA or CAFile.
func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
ctx := context.TODO()
rootCAs, err := x509.SystemCertPool()
if err != nil {
log.Error().Err(err).Msg("pkg/cryptutil: failed getting system cert pool making new one")
log.Error(ctx).Err(err).Msg("pkg/cryptutil: failed getting system cert pool making new one")
rootCAs = x509.NewCertPool()
}
if ca == "" && caFile == "" {
@ -38,7 +40,7 @@ func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
if ok := rootCAs.AppendCertsFromPEM(data); !ok {
return nil, fmt.Errorf("failed to append any PEM-encoded certificates")
}
log.Debug().Msg("pkg/cryptutil: added custom certificate authority")
log.Debug(ctx).Msg("pkg/cryptutil: added custom certificate authority")
return rootCAs, nil
}

View file

@ -60,6 +60,7 @@ type Options struct {
// NewGRPCClientConn returns a new gRPC pomerium service client connection.
func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
ctx := context.TODO()
if len(opts.Addrs) == 0 {
return nil, errors.New("internal/grpc: connection address required")
}
@ -105,7 +106,7 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
}
if opts.WithInsecure {
log.Info().Str("addr", connAddr).Msg("internal/grpc: grpc with insecure")
log.Info(ctx).Str("addr", connAddr).Msg("internal/grpc: grpc with insecure")
dialOptions = append(dialOptions, grpc.WithInsecure())
} else {
rootCAs, err := cryptutil.GetCertPool(opts.CA, opts.CAFile)
@ -117,7 +118,7 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
// override allowed certificate name string, typically used when doing behind ingress connection
if opts.OverrideCertificateName != "" {
log.Debug().Str("cert-override-name", opts.OverrideCertificateName).Msg("internal/grpc: grpc")
log.Debug(ctx).Str("cert-override-name", opts.OverrideCertificateName).Msg("internal/grpc: grpc")
err := cert.OverrideServerName(opts.OverrideCertificateName)
if err != nil {
return nil, err
@ -169,7 +170,7 @@ func GetGRPCClientConn(name string, opts *Options) (*grpc.ClientConn, error) {
err := current.conn.Close()
if err != nil {
log.Error().Err(err).Msg("grpc: failed to close existing connection")
log.Error(context.TODO()).Err(err).Msg("grpc: failed to close existing connection")
}
}

View file

@ -101,7 +101,7 @@ func (syncer *Syncer) Run(ctx context.Context) error {
}
if err != nil {
syncer.log().Error().Err(err).Msg("sync")
log.Error(syncer.logCtx(ctx)).Err(err).Msg("sync")
select {
case <-ctx.Done():
return ctx.Err()
@ -112,42 +112,42 @@ func (syncer *Syncer) Run(ctx context.Context) error {
}
func (syncer *Syncer) init(ctx context.Context) error {
syncer.log().Info().Msg("initial sync")
records, recordVersion, serverVersion, err := InitialSync(ctx, syncer.handler.GetDataBrokerServiceClient(), &SyncLatestRequest{
log.Info(syncer.logCtx(ctx)).Msg("initial sync")
records, recordVersion, serverVersion, err := InitialSync(syncer.logCtx(ctx), syncer.handler.GetDataBrokerServiceClient(), &SyncLatestRequest{
Type: syncer.cfg.typeURL,
})
if err != nil {
syncer.log().Error().Err(err).Msg("error during initial sync")
log.Error(syncer.logCtx(ctx)).Err(err).Msg("error during initial sync")
return err
}
syncer.backoff.Reset()
// reset the records as we have to sync latest
syncer.handler.ClearRecords(ctx)
syncer.handler.ClearRecords(syncer.logCtx(ctx))
syncer.recordVersion = recordVersion
syncer.serverVersion = serverVersion
syncer.handler.UpdateRecords(ctx, serverVersion, records)
syncer.handler.UpdateRecords(syncer.logCtx(ctx), serverVersion, records)
return nil
}
func (syncer *Syncer) sync(ctx context.Context) error {
stream, err := syncer.handler.GetDataBrokerServiceClient().Sync(ctx, &SyncRequest{
stream, err := syncer.handler.GetDataBrokerServiceClient().Sync(syncer.logCtx(ctx), &SyncRequest{
ServerVersion: syncer.serverVersion,
RecordVersion: syncer.recordVersion,
})
if err != nil {
syncer.log().Error().Err(err).Msg("error during sync")
log.Error(syncer.logCtx(ctx)).Err(err).Msg("error during sync")
return err
}
syncer.log().Info().Msg("listening for updates")
log.Info(syncer.logCtx(ctx)).Msg("listening for updates")
for {
res, err := stream.Recv()
if status.Code(err) == codes.Aborted {
syncer.log().Error().Err(err).Msg("aborted sync due to mismatched server version")
log.Error(syncer.logCtx(ctx)).Err(err).Msg("aborted sync due to mismatched server version")
// server version changed, so re-init
syncer.serverVersion = 0
return nil
@ -155,14 +155,13 @@ func (syncer *Syncer) sync(ctx context.Context) error {
return err
}
syncer.log().Debug().
log.Debug(syncer.logCtx(ctx)).
Uint("version", uint(res.Record.GetVersion())).
Str("type", res.Record.Type).
Str("id", res.Record.Id).
Msg("syncer got record")
if syncer.recordVersion != res.GetRecord().GetVersion()-1 {
syncer.log().Error().Err(err).
log.Error(syncer.logCtx(ctx)).Err(err).
Uint64("received", res.GetRecord().GetVersion()).
Msg("aborted sync due to missing record")
syncer.serverVersion = 0
@ -170,15 +169,17 @@ func (syncer *Syncer) sync(ctx context.Context) error {
}
syncer.recordVersion = res.GetRecord().GetVersion()
if syncer.cfg.typeURL == "" || syncer.cfg.typeURL == res.GetRecord().GetType() {
syncer.handler.UpdateRecords(ctx, syncer.serverVersion, []*Record{res.GetRecord()})
syncer.handler.UpdateRecords(syncer.logCtx(ctx), syncer.serverVersion, []*Record{res.GetRecord()})
}
}
}
func (syncer *Syncer) log() *zerolog.Logger {
l := log.With().Str("syncer_id", syncer.id).
Str("type", syncer.cfg.typeURL).
Uint64("server_version", syncer.serverVersion).
Uint64("record_version", syncer.recordVersion).Logger()
return &l
// logCtx returns a copy of ctx whose embedded zerolog context is annotated
// with this syncer's identifying fields (syncer_id, type URL, server version,
// record version), so that every log line emitted via log.XXX(ctx) during a
// sync run is traceable back to this syncer instance.
func (syncer *Syncer) logCtx(ctx context.Context) context.Context {
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("syncer_id", syncer.id).
Str("type", syncer.cfg.typeURL).
Uint64("server_version", syncer.serverVersion).
Uint64("record_version", syncer.recordVersion)
})
}

View file

@ -9,9 +9,11 @@ import (
"time"
"github.com/google/btree"
"github.com/rs/zerolog"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/signal"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
@ -141,14 +143,20 @@ func (backend *Backend) GetAll(_ context.Context) ([]*databroker.Record, uint64,
}
// Put puts a record into the in-memory store.
func (backend *Backend) Put(_ context.Context, record *databroker.Record) error {
func (backend *Backend) Put(ctx context.Context, record *databroker.Record) error {
if record == nil {
return fmt.Errorf("records cannot be nil")
}
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("db_op", "put").
Str("db_id", record.Id).
Str("db_type", record.Type)
})
backend.mu.Lock()
defer backend.mu.Unlock()
defer backend.onChange.Broadcast()
defer backend.onChange.Broadcast(ctx)
record.ModifiedAt = timestamppb.Now()
record.Version = backend.nextVersion()

View file

@ -12,7 +12,7 @@ type recordStream struct {
ctx context.Context
backend *Backend
changed chan struct{}
changed chan context.Context
ready []*databroker.Record
version uint64

View file

@ -15,7 +15,7 @@ type logger struct {
}
func (l logger) Printf(ctx context.Context, format string, v ...interface{}) {
log.Info().Str("service", "redis").Msgf(format, v...)
log.Info(ctx).Str("service", "redis").Msgf(format, v...)
}
func init() {

View file

@ -51,6 +51,7 @@ type Backend struct {
// New creates a new redis storage backend.
func New(rawURL string, options ...Option) (*Backend, error) {
ctx := context.TODO()
cfg := getConfig(options...)
backend := &Backend{
cfg: cfg,
@ -63,7 +64,7 @@ func New(rawURL string, options ...Option) (*Backend, error) {
return nil, err
}
metrics.AddRedisMetrics(backend.client.PoolStats)
go backend.listenForVersionChanges()
go backend.listenForVersionChanges(ctx)
if cfg.expiry != 0 {
go func() {
ticker := time.NewTicker(time.Minute)
@ -75,7 +76,7 @@ func New(rawURL string, options ...Option) (*Backend, error) {
case <-ticker.C:
}
backend.removeChangesBefore(time.Now().Add(-cfg.expiry))
backend.removeChangesBefore(ctx, time.Now().Add(-cfg.expiry))
}
}()
}
@ -146,7 +147,7 @@ func (backend *Backend) GetAll(ctx context.Context) (records []*databroker.Recor
var record databroker.Record
err := proto.Unmarshal([]byte(result), &record)
if err != nil {
log.Warn().Err(err).Msg("redis: invalid record detected")
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
continue
}
records = append(records, &record)
@ -246,8 +247,8 @@ func (backend *Backend) incrementVersion(ctx context.Context,
return ErrExceededMaxRetries
}
func (backend *Backend) listenForVersionChanges() {
ctx, cancel := context.WithCancel(context.Background())
func (backend *Backend) listenForVersionChanges(ctx context.Context) {
ctx, cancel := context.WithCancel(ctx)
go func() {
<-backend.closed
cancel()
@ -274,14 +275,13 @@ outer:
switch msg.(type) {
case *redis.Message:
backend.onChange.Broadcast()
backend.onChange.Broadcast(ctx)
}
}
}
}
func (backend *Backend) removeChangesBefore(cutoff time.Time) {
ctx := context.Background()
func (backend *Backend) removeChangesBefore(ctx context.Context, cutoff time.Time) {
for {
cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{
Min: "-inf",
@ -291,7 +291,7 @@ func (backend *Backend) removeChangesBefore(cutoff time.Time) {
})
results, err := cmd.Result()
if err != nil {
log.Error().Err(err).Msg("redis: error retrieving changes for expiration")
log.Error(ctx).Err(err).Msg("redis: error retrieving changes for expiration")
return
}
@ -303,7 +303,7 @@ func (backend *Backend) removeChangesBefore(cutoff time.Time) {
var record databroker.Record
err = proto.Unmarshal([]byte(results[0]), &record)
if err != nil {
log.Warn().Err(err).Msg("redis: invalid record detected")
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
record.ModifiedAt = timestamppb.New(cutoff.Add(-time.Second)) // set the modified so will delete it
}
@ -315,7 +315,7 @@ func (backend *Backend) removeChangesBefore(cutoff time.Time) {
// remove the record
err = backend.client.ZRem(ctx, changesSetKey, results[0]).Err()
if err != nil {
log.Error().Err(err).Msg("redis: error removing member")
log.Error(ctx).Err(err).Msg("redis: error removing member")
return
}
}

View file

@ -181,7 +181,7 @@ func TestExpiry(t *testing.T) {
_ = stream.Close()
require.Len(t, records, 1000)
backend.removeChangesBefore(time.Now().Add(time.Second))
backend.removeChangesBefore(ctx, time.Now().Add(time.Second))
stream, err = backend.Sync(ctx, 0)
require.NoError(t, err)

View file

@ -17,7 +17,7 @@ type recordStream struct {
ctx context.Context
backend *Backend
changed chan struct{}
changed chan context.Context
version uint64
record *databroker.Record
err error
@ -63,6 +63,7 @@ func (stream *recordStream) Next(block bool) bool {
ticker := time.NewTicker(watchPollInterval)
defer ticker.Stop()
changeCtx := context.Background()
for {
cmd := stream.backend.client.ZRangeByScore(stream.ctx, changesSetKey, &redis.ZRangeBy{
Min: fmt.Sprintf("(%d", stream.version),
@ -81,7 +82,7 @@ func (stream *recordStream) Next(block bool) bool {
var record databroker.Record
err = proto.Unmarshal([]byte(result), &record)
if err != nil {
log.Warn().Err(err).Msg("redis: invalid record detected")
log.Warn(changeCtx).Err(err).Msg("redis: invalid record detected")
} else {
stream.record = &record
}
@ -97,7 +98,7 @@ func (stream *recordStream) Next(block bool) bool {
case <-stream.closed:
return false
case <-ticker.C: // check again
case <-stream.changed: // check again
case changeCtx = <-stream.changed: // check again
}
} else {
return false

View file

@ -57,7 +57,7 @@ func MatchAny(any *anypb.Any, query string) bool {
msg, err := any.UnmarshalNew()
if err != nil {
// ignore invalid any types
log.Error().Err(err).Msg("storage: invalid any type")
log.Error(context.TODO()).Err(err).Msg("storage: invalid any type")
return false
}

View file

@ -35,6 +35,7 @@ func (m *mockCheckClient) Check(ctx context.Context, in *envoy_service_auth_v2.C
}
func TestProxy_ForwardAuth(t *testing.T) {
ctx := context.Background()
t.Parallel()
allowClient := &mockCheckClient{
@ -85,7 +86,7 @@ func TestProxy_ForwardAuth(t *testing.T) {
if err != nil {
t.Fatal(err)
}
p.OnConfigChange(&config.Config{Options: tt.options})
p.OnConfigChange(ctx, &config.Config{Options: tt.options})
state := p.state.Load()
state.sessionStore = tt.sessionStore
signer, err := jws.NewHS256Signer(nil)

View file

@ -2,6 +2,7 @@ package proxy
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
@ -240,7 +241,7 @@ func TestProxy_Callback(t *testing.T) {
if err != nil {
t.Fatal(err)
}
p.OnConfigChange(&config.Config{Options: tt.options})
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
state := p.state.Load()
state.encoder = tt.cipher
state.sessionStore = tt.sessionStore
@ -486,7 +487,7 @@ func TestProxy_ProgrammaticCallback(t *testing.T) {
if err != nil {
t.Fatal(err)
}
p.OnConfigChange(&config.Config{Options: tt.options})
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
state := p.state.Load()
state.encoder = tt.cipher
state.sessionStore = tt.sessionStore

View file

@ -5,6 +5,7 @@
package proxy
import (
"context"
"fmt"
"html/template"
"net/http"
@ -76,17 +77,17 @@ func New(cfg *config.Config) (*Proxy, error) {
}
// OnConfigChange updates internal structures based on config.Options
func (p *Proxy) OnConfigChange(cfg *config.Config) {
func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
if p == nil {
return
}
p.currentOptions.Store(cfg.Options)
if err := p.setHandlers(cfg.Options); err != nil {
log.Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
}
if state, err := newProxyStateFromConfig(cfg); err != nil {
log.Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings")
log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy state from configuration settings")
} else {
p.state.Store(state)
}
@ -94,7 +95,7 @@ func (p *Proxy) OnConfigChange(cfg *config.Config) {
func (p *Proxy) setHandlers(opts *config.Options) error {
if len(opts.GetAllPolicies()) == 0 {
log.Warn().Msg("proxy: configuration has no policies")
log.Warn(context.TODO()).Msg("proxy: configuration has no policies")
}
r := httputil.NewRouter()
r.NotFoundHandler = httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {

View file

@ -1,6 +1,7 @@
package proxy
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
@ -197,7 +198,7 @@ func Test_UpdateOptions(t *testing.T) {
t.Fatal(err)
}
p.OnConfigChange(&config.Config{Options: tt.updatedOptions})
p.OnConfigChange(context.Background(), &config.Config{Options: tt.updatedOptions})
r := httptest.NewRequest("GET", tt.host, nil)
w := httptest.NewRecorder()
p.ServeHTTP(w, r)
@ -210,5 +211,5 @@ func Test_UpdateOptions(t *testing.T) {
// Test nil
var p *Proxy
p.OnConfigChange(&config.Config{})
p.OnConfigChange(context.Background(), &config.Config{})
}