log context (#2107)

This commit is contained in:
wasaga 2021-04-22 10:58:13 -04:00 committed by GitHub
parent e7995954ff
commit e0c09a0998
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
87 changed files with 714 additions and 524 deletions

View file

@ -3,6 +3,7 @@
package authenticate package authenticate
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"html/template" "html/template"
@ -76,19 +77,19 @@ func New(cfg *config.Config) (*Authenticate, error) {
} }
// OnConfigChange updates internal structures based on config.Options // OnConfigChange updates internal structures based on config.Options
func (a *Authenticate) OnConfigChange(cfg *config.Config) { func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) {
if a == nil { if a == nil {
return return
} }
a.options.Store(cfg.Options) a.options.Store(cfg.Options)
if state, err := newAuthenticateStateFromConfig(cfg); err != nil { if state, err := newAuthenticateStateFromConfig(cfg); err != nil {
log.Error().Err(err).Msg("authenticate: failed to update state") log.Error(ctx).Err(err).Msg("authenticate: failed to update state")
} else { } else {
a.state.Store(state) a.state.Store(state)
} }
if err := a.updateProvider(cfg); err != nil { if err := a.updateProvider(cfg); err != nil {
log.Error().Err(err).Msg("authenticate: failed to update identity provider") log.Error(ctx).Err(err).Msg("authenticate: failed to update identity provider")
} }
} }

View file

@ -275,7 +275,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
endSessionURL.RawQuery = params.Encode() endSessionURL.RawQuery = params.Encode()
redirectString = endSessionURL.String() redirectString = endSessionURL.String()
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) { } else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
log.Warn().Err(err).Msg("authenticate.SignOut: failed getting session") log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session")
} }
if redirectString != "" { if redirectString != "" {
httputil.Redirect(w, r, redirectString, http.StatusFound) httputil.Redirect(w, r, redirectString, http.StatusFound)
@ -558,7 +558,7 @@ func (a *Authenticate) saveSessionToDataBroker(
AccessToken: accessToken.AccessToken, AccessToken: accessToken.AccessToken,
}) })
if err != nil { if err != nil {
log.Error().Err(err).Msg("directory: failed to refresh user data") log.Error(ctx).Err(err).Msg("directory: failed to refresh user data")
} }
return nil return nil

View file

@ -56,7 +56,7 @@ func (a *Authorize) WaitForInitialSync(ctx context.Context) error {
return ctx.Err() return ctx.Err()
case <-a.dataBrokerInitialSync: case <-a.dataBrokerInitialSync:
} }
log.Info().Msg("initial sync from databroker complete") log.Info(ctx).Msg("initial sync from databroker complete")
return nil return nil
} }
@ -82,10 +82,10 @@ func newPolicyEvaluator(opts *config.Options, store *evaluator.Store) (*evaluato
} }
// OnConfigChange updates internal structures based on config.Options // OnConfigChange updates internal structures based on config.Options
func (a *Authorize) OnConfigChange(cfg *config.Config) { func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
a.currentOptions.Store(cfg.Options) a.currentOptions.Store(cfg.Options)
if state, err := newAuthorizeStateFromConfig(cfg, a.store); err != nil { if state, err := newAuthorizeStateFromConfig(cfg, a.store); err != nil {
log.Error().Err(err).Msg("authorize: error updating state") log.Error(ctx).Err(err).Msg("authorize: error updating state")
} else { } else {
a.state.Store(state) a.state.Store(state)
} }

View file

@ -1,6 +1,7 @@
package authorize package authorize
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -116,7 +117,7 @@ func TestAuthorize_OnConfigChange(t *testing.T) {
o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUhHNHZDWlJxUFgwNGtmSFQxeVVDM1pUQkF6MFRYWkNtZ043clpDcFE3cHJvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFbzQzdjAwQlR4c3pKZWpmdHhBOWNtVGVUSmtQQXVtOGt1b0UwVlRUZnlId2k3SHJlN2FRUgpHQVJ6Nm0wMjVRdGFiRGxqeDd5MjIyY1gxblhCQXo3MlF3PT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo=" o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUhHNHZDWlJxUFgwNGtmSFQxeVVDM1pUQkF6MFRYWkNtZ043clpDcFE3cHJvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFbzQzdjAwQlR4c3pKZWpmdHhBOWNtVGVUSmtQQXVtOGt1b0UwVlRUZnlId2k3SHJlN2FRUgpHQVJ6Nm0wMjVRdGFiRGxqeDd5MjIyY1gxblhCQXo3MlF3PT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
assertFunc = assert.False assertFunc = assert.False
} }
a.OnConfigChange(cfg) a.OnConfigChange(context.Background(), cfg)
assertFunc(t, oldPe == a.state.Load().evaluator) assertFunc(t, oldPe == a.state.Load().evaluator)
}) })
} }

View file

@ -2,6 +2,7 @@ package authorize
import ( import (
"bytes" "bytes"
"context"
"net/http" "net/http"
"net/url" "net/url"
"sort" "sort"
@ -103,7 +104,7 @@ func (a *Authorize) htmlDeniedResponse(
}) })
if err != nil { if err != nil {
buf.WriteString(reason) buf.WriteString(reason)
log.Error().Err(err).Msg("error executing error template") log.Error(context.TODO()).Err(err).Msg("error executing error template")
} }
envoyHeaders := []*envoy_config_core_v3.HeaderValueOption{ envoyHeaders := []*envoy_config_core_v3.HeaderValueOption{

View file

@ -162,7 +162,7 @@ func getJWK(options *config.Options) (*jose.JSONWebKey, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("couldn't generate signing key: %w", err) return nil, fmt.Errorf("couldn't generate signing key: %w", err)
} }
log.Info().Str("Algorithm", jwk.Algorithm). log.Info(context.TODO()).Str("Algorithm", jwk.Algorithm).
Str("KeyID", jwk.KeyID). Str("KeyID", jwk.KeyID).
Interface("Public Key", jwk.Public()). Interface("Public Key", jwk.Public()).
Msg("authorize: signing key") Msg("authorize: signing key")

View file

@ -1,6 +1,7 @@
package evaluator package evaluator
import ( import (
"context"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
@ -46,7 +47,7 @@ func isValidClientCertificate(ca, cert string) (bool, error) {
valid := verifyErr == nil valid := verifyErr == nil
if verifyErr != nil { if verifyErr != nil {
log.Debug().Err(verifyErr).Msg("client certificate failed verification: %w") log.Debug(context.Background()).Err(verifyErr).Msg("client certificate failed verification: %w")
} }
isValidClientCertificateCache.Add(cacheKey, valid) isValidClientCertificateCache.Add(cacheKey, valid)

View file

@ -1,6 +1,7 @@
package evaluator package evaluator
import ( import (
"context"
"fmt" "fmt"
"strconv" "strconv"
@ -71,7 +72,7 @@ func getDenyVar(vars rego.Vars) []Result {
status, err := strconv.Atoi(fmt.Sprint(denial[0])) status, err := strconv.Atoi(fmt.Sprint(denial[0]))
if err != nil { if err != nil {
log.Error().Err(err).Msg("invalid type in deny") log.Error(context.TODO()).Err(err).Msg("invalid type in deny")
continue continue
} }
msg := fmt.Sprint(denial[1]) msg := fmt.Sprint(denial[1])

View file

@ -156,11 +156,12 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
} }
func (s *Store) write(rawPath string, value interface{}) { func (s *Store) write(rawPath string, value interface{}) {
err := storage.Txn(context.Background(), s.Store, storage.WriteParams, func(txn storage.Transaction) error { ctx := context.TODO()
err := storage.Txn(ctx, s.Store, storage.WriteParams, func(txn storage.Transaction) error {
return s.writeTxn(txn, rawPath, value) return s.writeTxn(txn, rawPath, value)
}) })
if err != nil { if err != nil {
log.Error().Err(err).Msg("opa-store: error writing data") log.Error(ctx).Err(err).Msg("opa-store: error writing data")
return return
} }
} }

View file

@ -46,19 +46,19 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
u, err := a.forceSync(ctx, sessionState) u, err := a.forceSync(ctx, sessionState)
if err != nil { if err != nil {
log.Warn().Err(err).Msg("clearing session due to force sync failed") log.Warn(ctx).Err(err).Msg("clearing session due to force sync failed")
sessionState = nil sessionState = nil
} }
req, err := a.getEvaluatorRequestFromCheckRequest(in, sessionState) req, err := a.getEvaluatorRequestFromCheckRequest(in, sessionState)
if err != nil { if err != nil {
log.Warn().Err(err).Msg("error building evaluator request") log.Warn(ctx).Err(err).Msg("error building evaluator request")
return nil, err return nil, err
} }
reply, err := state.evaluator.Evaluate(ctx, req) reply, err := state.evaluator.Evaluate(ctx, req)
if err != nil { if err != nil {
log.Error().Err(err).Msg("error during OPA evaluation") log.Error(ctx).Err(err).Msg("error during OPA evaluation")
return nil, err return nil, err
} }
defer func() { defer func() {

View file

@ -26,7 +26,7 @@ func (a *Authorize) logAuthorizeCheck(
hdrs := getCheckRequestHeaders(in) hdrs := getCheckRequestHeaders(in)
hattrs := in.GetAttributes().GetRequest().GetHttp() hattrs := in.GetAttributes().GetRequest().GetHttp()
evt := log.Info().Str("service", "authorize") evt := log.Info(ctx).Str("service", "authorize")
// request // request
evt = evt.Str("request-id", requestid.FromContext(ctx)) evt = evt.Str("request-id", requestid.FromContext(ctx))
evt = evt.Str("check-request-id", hdrs["X-Request-Id"]) evt = evt.Str("check-request-id", hdrs["X-Request-Id"])
@ -66,10 +66,10 @@ func (a *Authorize) logAuthorizeCheck(
} }
sealed, err := enc.Encrypt(record) sealed, err := enc.Encrypt(record)
if err != nil { if err != nil {
log.Warn().Err(err).Msg("authorize: error encrypting audit record") log.Warn(ctx).Err(err).Msg("authorize: error encrypting audit record")
return return
} }
log.Info(). log.Info(ctx).
Str("request-id", requestid.FromContext(ctx)). Str("request-id", requestid.FromContext(ctx)).
EmbedObject(sealed). EmbedObject(sealed).
Msg("audit log") Msg("audit log")

View file

@ -150,7 +150,7 @@ func (a *Authorize) waitForRecordSync(ctx context.Context, recordTypeURL, record
// record not found, so no need to wait // record not found, so no need to wait
return nil, nil return nil, nil
} else if err != nil { } else if err != nil {
log.Error(). log.Error(ctx).
Err(err). Err(err).
Str("type", recordTypeURL). Str("type", recordTypeURL).
Str("id", recordID). Str("id", recordID).
@ -160,7 +160,7 @@ func (a *Authorize) waitForRecordSync(ctx context.Context, recordTypeURL, record
select { select {
case <-ctx.Done(): case <-ctx.Done():
log.Warn(). log.Warn(ctx).
Str("type", recordTypeURL). Str("type", recordTypeURL).
Str("id", recordID). Str("id", recordID).
Msg("authorize: first sync of record did not complete") Msg("authorize: first sync of record did not complete")

View file

@ -17,10 +17,11 @@ var (
) )
func main() { func main() {
if err := run(context.Background()); !errors.Is(err, context.Canceled) { ctx := context.Background()
if err := run(ctx); !errors.Is(err, context.Canceled) {
log.Fatal().Err(err).Msg("cmd/pomerium") log.Fatal().Err(err).Msg("cmd/pomerium")
} }
log.Info().Msg("cmd/pomerium: exiting") log.Info(ctx).Msg("cmd/pomerium: exiting")
} }
func run(ctx context.Context) error { func run(ctx context.Context) error {

View file

@ -1,11 +1,14 @@
package config package config
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"io/ioutil" "io/ioutil"
"sync" "sync"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/internal/fileutil" "github.com/pomerium/pomerium/internal/fileutil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
@ -13,7 +16,7 @@ import (
) )
// A ChangeListener is called when configuration changes. // A ChangeListener is called when configuration changes.
type ChangeListener = func(*Config) type ChangeListener = func(context.Context, *Config)
// A ChangeDispatcher manages listeners on config changes. // A ChangeDispatcher manages listeners on config changes.
type ChangeDispatcher struct { type ChangeDispatcher struct {
@ -22,17 +25,17 @@ type ChangeDispatcher struct {
} }
// Trigger triggers a change. // Trigger triggers a change.
func (dispatcher *ChangeDispatcher) Trigger(cfg *Config) { func (dispatcher *ChangeDispatcher) Trigger(ctx context.Context, cfg *Config) {
dispatcher.Lock() dispatcher.Lock()
defer dispatcher.Unlock() defer dispatcher.Unlock()
for _, li := range dispatcher.onConfigChangeListeners { for _, li := range dispatcher.onConfigChangeListeners {
li(cfg) li(ctx, cfg)
} }
} }
// OnConfigChange adds a listener. // OnConfigChange adds a listener.
func (dispatcher *ChangeDispatcher) OnConfigChange(li ChangeListener) { func (dispatcher *ChangeDispatcher) OnConfigChange(ctx context.Context, li ChangeListener) {
dispatcher.Lock() dispatcher.Lock()
defer dispatcher.Unlock() defer dispatcher.Unlock()
dispatcher.onConfigChangeListeners = append(dispatcher.onConfigChangeListeners, li) dispatcher.onConfigChangeListeners = append(dispatcher.onConfigChangeListeners, li)
@ -41,7 +44,7 @@ func (dispatcher *ChangeDispatcher) OnConfigChange(li ChangeListener) {
// A Source gets configuration. // A Source gets configuration.
type Source interface { type Source interface {
GetConfig() *Config GetConfig() *Config
OnConfigChange(ChangeListener) OnConfigChange(context.Context, ChangeListener)
} }
// A StaticSource always returns the same config. Useful for testing. // A StaticSource always returns the same config. Useful for testing.
@ -65,18 +68,18 @@ func (src *StaticSource) GetConfig() *Config {
} }
// SetConfig sets the config. // SetConfig sets the config.
func (src *StaticSource) SetConfig(cfg *Config) { func (src *StaticSource) SetConfig(ctx context.Context, cfg *Config) {
src.mu.Lock() src.mu.Lock()
defer src.mu.Unlock() defer src.mu.Unlock()
src.cfg = cfg src.cfg = cfg
for _, li := range src.lis { for _, li := range src.lis {
li(cfg) li(ctx, cfg)
} }
} }
// OnConfigChange is ignored for the StaticSource. // OnConfigChange is ignored for the StaticSource.
func (src *StaticSource) OnConfigChange(li ChangeListener) { func (src *StaticSource) OnConfigChange(ctx context.Context, li ChangeListener) {
src.mu.Lock() src.mu.Lock()
defer src.mu.Unlock() defer src.mu.Unlock()
@ -95,38 +98,48 @@ type FileOrEnvironmentSource struct {
// NewFileOrEnvironmentSource creates a new FileOrEnvironmentSource. // NewFileOrEnvironmentSource creates a new FileOrEnvironmentSource.
func NewFileOrEnvironmentSource(configFile string) (*FileOrEnvironmentSource, error) { func NewFileOrEnvironmentSource(configFile string) (*FileOrEnvironmentSource, error) {
ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context {
return c.Str("config_file_source", configFile)
})
options, err := newOptionsFromConfig(configFile) options, err := newOptionsFromConfig(configFile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cfg := &Config{Options: options} cfg := &Config{Options: options}
metrics.SetConfigInfo(cfg.Options.Services, "local", cfg.Checksum(), true) metrics.SetConfigInfo(ctx, cfg.Options.Services, "local", cfg.Checksum(), true)
src := &FileOrEnvironmentSource{ src := &FileOrEnvironmentSource{
configFile: configFile, configFile: configFile,
config: cfg, config: cfg,
} }
options.viper.OnConfigChange(src.onConfigChange) options.viper.OnConfigChange(src.onConfigChange(ctx))
go options.viper.WatchConfig() go options.viper.WatchConfig()
return src, nil return src, nil
} }
func (src *FileOrEnvironmentSource) onConfigChange(evt fsnotify.Event) { func (src *FileOrEnvironmentSource) onConfigChange(ctx context.Context) func(fsnotify.Event) {
src.mu.Lock() return func(evt fsnotify.Event) {
cfg := src.config ctx := log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
options, err := newOptionsFromConfig(src.configFile) return c.Str("config_change_id", uuid.New().String())
if err == nil { })
cfg = &Config{Options: options} log.Info(ctx).Msg("config: file updated, reconfiguring...")
metrics.SetConfigInfo(cfg.Options.Services, "local", cfg.Checksum(), true) src.mu.Lock()
} else { cfg := src.config
log.Error().Err(err).Msg("config: error updating config") options, err := newOptionsFromConfig(src.configFile)
metrics.SetConfigInfo(cfg.Options.Services, "local", cfg.Checksum(), false) if err == nil {
} cfg = &Config{Options: options}
src.mu.Unlock() metrics.SetConfigInfo(ctx, cfg.Options.Services, "local", cfg.Checksum(), true)
} else {
log.Error(ctx).Err(err).Msg("config: error updating config")
metrics.SetConfigInfo(ctx, cfg.Options.Services, "local", cfg.Checksum(), false)
}
src.mu.Unlock()
src.Trigger(cfg) src.Trigger(ctx, cfg)
}
} }
// GetConfig gets the config. // GetConfig gets the config.
@ -148,7 +161,7 @@ type FileWatcherSource struct {
ChangeDispatcher ChangeDispatcher
} }
// NewFileWatcherSource creates a new FileWatcherSource. // NewFileWatcherSource creates a new FileWatcherSource
func NewFileWatcherSource(underlying Source) *FileWatcherSource { func NewFileWatcherSource(underlying Source) *FileWatcherSource {
src := &FileWatcherSource{ src := &FileWatcherSource{
underlying: underlying, underlying: underlying,
@ -158,13 +171,13 @@ func NewFileWatcherSource(underlying Source) *FileWatcherSource {
ch := src.watcher.Bind() ch := src.watcher.Bind()
go func() { go func() {
for range ch { for range ch {
src.check(underlying.GetConfig()) src.check(context.TODO(), underlying.GetConfig())
} }
}() }()
underlying.OnConfigChange(func(cfg *Config) { underlying.OnConfigChange(context.TODO(), func(ctx context.Context, cfg *Config) {
src.check(cfg) src.check(ctx, cfg)
}) })
src.check(underlying.GetConfig()) src.check(context.TODO(), underlying.GetConfig())
return src return src
} }
@ -176,7 +189,7 @@ func (src *FileWatcherSource) GetConfig() *Config {
return src.computedConfig return src.computedConfig
} }
func (src *FileWatcherSource) check(cfg *Config) { func (src *FileWatcherSource) check(ctx context.Context, cfg *Config) {
if cfg == nil || cfg.Options == nil { if cfg == nil || cfg.Options == nil {
return return
} }
@ -218,5 +231,5 @@ func (src *FileWatcherSource) check(cfg *Config) {
src.computedConfig = cfg.Clone() src.computedConfig = cfg.Clone()
// trigger a change // trigger a change
src.Trigger(src.computedConfig) src.Trigger(ctx, src.computedConfig)
} }

View file

@ -1,6 +1,7 @@
package config package config
import ( import (
"context"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
@ -13,6 +14,8 @@ import (
) )
func TestFileWatcherSource(t *testing.T) { func TestFileWatcherSource(t *testing.T) {
ctx := context.Background()
tmpdir := filepath.Join(os.TempDir(), uuid.New().String()) tmpdir := filepath.Join(os.TempDir(), uuid.New().String())
err := os.MkdirAll(tmpdir, 0o755) err := os.MkdirAll(tmpdir, 0o755)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
@ -33,7 +36,7 @@ func TestFileWatcherSource(t *testing.T) {
src := NewFileWatcherSource(ssrc) src := NewFileWatcherSource(ssrc)
var closeOnce sync.Once var closeOnce sync.Once
ch := make(chan struct{}) ch := make(chan struct{})
src.OnConfigChange(func(cfg *Config) { src.OnConfigChange(context.Background(), func(ctx context.Context, cfg *Config) {
closeOnce.Do(func() { closeOnce.Do(func() {
close(ch) close(ch)
}) })
@ -50,7 +53,7 @@ func TestFileWatcherSource(t *testing.T) {
t.Error("expected OnConfigChange to be fired after modifying a file") t.Error("expected OnConfigChange to be fired after modifying a file")
} }
ssrc.SetConfig(&Config{ ssrc.SetConfig(ctx, &Config{
Options: &Options{ Options: &Options{
CAFile: filepath.Join(tmpdir, "example.txt"), CAFile: filepath.Join(tmpdir, "example.txt"),
}, },

View file

@ -1,6 +1,7 @@
package envoyconfig package envoyconfig
import ( import (
"context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net" "net"
@ -22,7 +23,7 @@ import (
) )
// BuildClusters builds envoy clusters from the given config. // BuildClusters builds envoy clusters from the given config.
func (b *Builder) BuildClusters(cfg *config.Config) ([]*envoy_config_cluster_v3.Cluster, error) { func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*envoy_config_cluster_v3.Cluster, error) {
grpcURL := &url.URL{ grpcURL := &url.URL{
Scheme: "http", Scheme: "http",
Host: b.localGRPCAddress, Host: b.localGRPCAddress,
@ -36,15 +37,15 @@ func (b *Builder) BuildClusters(cfg *config.Config) ([]*envoy_config_cluster_v3.
return nil, err return nil, err
} }
controlGRPC, err := b.buildInternalCluster(cfg.Options, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, true) controlGRPC, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
controlHTTP, err := b.buildInternalCluster(cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, false) controlHTTP, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
authZ, err := b.buildInternalCluster(cfg.Options, "pomerium-authorize", authzURLs, true) authZ, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-authorize", authzURLs, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -62,7 +63,7 @@ func (b *Builder) BuildClusters(cfg *config.Config) ([]*envoy_config_cluster_v3.
policy.EnvoyOpts = newDefaultEnvoyClusterConfig() policy.EnvoyOpts = newDefaultEnvoyClusterConfig()
} }
if len(policy.To) > 0 { if len(policy.To) > 0 {
cluster, err := b.buildPolicyCluster(cfg.Options, &policy) cluster, err := b.buildPolicyCluster(ctx, cfg.Options, &policy)
if err != nil { if err != nil {
return nil, fmt.Errorf("policy #%d: %w", i, err) return nil, fmt.Errorf("policy #%d: %w", i, err)
} }
@ -78,12 +79,18 @@ func (b *Builder) BuildClusters(cfg *config.Config) ([]*envoy_config_cluster_v3.
return clusters, nil return clusters, nil
} }
func (b *Builder) buildInternalCluster(options *config.Options, name string, dsts []*url.URL, forceHTTP2 bool) (*envoy_config_cluster_v3.Cluster, error) { func (b *Builder) buildInternalCluster(
ctx context.Context,
options *config.Options,
name string,
dsts []*url.URL,
forceHTTP2 bool,
) (*envoy_config_cluster_v3.Cluster, error) {
cluster := newDefaultEnvoyClusterConfig() cluster := newDefaultEnvoyClusterConfig()
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily) cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
var endpoints []Endpoint var endpoints []Endpoint
for _, dst := range dsts { for _, dst := range dsts {
ts, err := b.buildInternalTransportSocket(options, dst) ts, err := b.buildInternalTransportSocket(ctx, options, dst)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -95,14 +102,14 @@ func (b *Builder) buildInternalCluster(options *config.Options, name string, dst
return cluster, nil return cluster, nil
} }
func (b *Builder) buildPolicyCluster(options *config.Options, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) { func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Options, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) {
cluster := new(envoy_config_cluster_v3.Cluster) cluster := new(envoy_config_cluster_v3.Cluster)
proto.Merge(cluster, policy.EnvoyOpts) proto.Merge(cluster, policy.EnvoyOpts)
cluster.AltStatName = getClusterStatsName(policy) cluster.AltStatName = getClusterStatsName(policy)
name := getClusterID(policy) name := getClusterID(policy)
endpoints, err := b.buildPolicyEndpoints(policy) endpoints, err := b.buildPolicyEndpoints(ctx, policy)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -122,10 +129,10 @@ func (b *Builder) buildPolicyCluster(options *config.Options, policy *config.Pol
return cluster, nil return cluster, nil
} }
func (b *Builder) buildPolicyEndpoints(policy *config.Policy) ([]Endpoint, error) { func (b *Builder) buildPolicyEndpoints(ctx context.Context, policy *config.Policy) ([]Endpoint, error) {
var endpoints []Endpoint var endpoints []Endpoint
for _, dst := range policy.To { for _, dst := range policy.To {
ts, err := b.buildPolicyTransportSocket(policy, dst.URL) ts, err := b.buildPolicyTransportSocket(ctx, policy, dst.URL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -134,7 +141,7 @@ func (b *Builder) buildPolicyEndpoints(policy *config.Policy) ([]Endpoint, error
return endpoints, nil return endpoints, nil
} }
func (b *Builder) buildInternalTransportSocket(options *config.Options, endpoint *url.URL) (*envoy_config_core_v3.TransportSocket, error) { func (b *Builder) buildInternalTransportSocket(ctx context.Context, options *config.Options, endpoint *url.URL) (*envoy_config_core_v3.TransportSocket, error) {
if endpoint.Scheme != "https" { if endpoint.Scheme != "https" {
return nil, nil return nil, nil
} }
@ -154,13 +161,13 @@ func (b *Builder) buildInternalTransportSocket(options *config.Options, endpoint
} else if options.CA != "" { } else if options.CA != "" {
bs, err := base64.StdEncoding.DecodeString(options.CA) bs, err := base64.StdEncoding.DecodeString(options.CA)
if err != nil { if err != nil {
log.Error().Err(err).Msg("invalid custom CA certificate") log.Error(ctx).Err(err).Msg("invalid custom CA certificate")
} }
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs) validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
} else { } else {
rootCA, err := getRootCertificateAuthority() rootCA, err := getRootCertificateAuthority()
if err != nil { if err != nil {
log.Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found") log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else { } else {
validationContext.TrustedCa = b.filemgr.FileDataSource(rootCA) validationContext.TrustedCa = b.filemgr.FileDataSource(rootCA)
} }
@ -183,12 +190,12 @@ func (b *Builder) buildInternalTransportSocket(options *config.Options, endpoint
}, nil }, nil
} }
func (b *Builder) buildPolicyTransportSocket(policy *config.Policy, dst url.URL) (*envoy_config_core_v3.TransportSocket, error) { func (b *Builder) buildPolicyTransportSocket(ctx context.Context, policy *config.Policy, dst url.URL) (*envoy_config_core_v3.TransportSocket, error) {
if dst.Scheme != "https" { if dst.Scheme != "https" {
return nil, nil return nil, nil
} }
vc, err := b.buildPolicyValidationContext(policy, dst) vc, err := b.buildPolicyValidationContext(ctx, policy, dst)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -216,7 +223,7 @@ func (b *Builder) buildPolicyTransportSocket(policy *config.Policy, dst url.URL)
} }
if policy.ClientCertificate != nil { if policy.ClientCertificate != nil {
tlsContext.CommonTlsContext.TlsCertificates = append(tlsContext.CommonTlsContext.TlsCertificates, tlsContext.CommonTlsContext.TlsCertificates = append(tlsContext.CommonTlsContext.TlsCertificates,
b.envoyTLSCertificateFromGoTLSCertificate(policy.ClientCertificate)) b.envoyTLSCertificateFromGoTLSCertificate(ctx, policy.ClientCertificate))
} }
tlsConfig := marshalAny(tlsContext) tlsConfig := marshalAny(tlsContext)
@ -229,6 +236,7 @@ func (b *Builder) buildPolicyTransportSocket(policy *config.Policy, dst url.URL)
} }
func (b *Builder) buildPolicyValidationContext( func (b *Builder) buildPolicyValidationContext(
ctx context.Context,
policy *config.Policy, dst url.URL, policy *config.Policy, dst url.URL,
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) { ) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
sni := dst.Hostname() sni := dst.Hostname()
@ -247,13 +255,13 @@ func (b *Builder) buildPolicyValidationContext(
} else if policy.TLSCustomCA != "" { } else if policy.TLSCustomCA != "" {
bs, err := base64.StdEncoding.DecodeString(policy.TLSCustomCA) bs, err := base64.StdEncoding.DecodeString(policy.TLSCustomCA)
if err != nil { if err != nil {
log.Error().Err(err).Msg("invalid custom CA certificate") log.Error(ctx).Err(err).Msg("invalid custom CA certificate")
} }
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs) validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
} else { } else {
rootCA, err := getRootCertificateAuthority() rootCA, err := getRootCertificateAuthority()
if err != nil { if err != nil {
log.Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found") log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else { } else {
validationContext.TrustedCa = b.filemgr.FileDataSource(rootCA) validationContext.TrustedCa = b.filemgr.FileDataSource(rootCA)
} }

View file

@ -1,6 +1,7 @@
package envoyconfig package envoyconfig
import ( import (
"context"
"encoding/base64" "encoding/base64"
"os" "os"
"path/filepath" "path/filepath"
@ -18,6 +19,7 @@ import (
) )
func Test_buildPolicyTransportSocket(t *testing.T) { func Test_buildPolicyTransportSocket(t *testing.T) {
ctx := context.Background()
cacheDir, _ := os.UserCacheDir() cacheDir, _ := os.UserCacheDir()
customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-32484c314b584447463735303142374c31414145374650305a525539554938594d524855353757313942494d473847535231.pem") customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-32484c314b584447463735303142374c31414145374650305a525539554938594d524855353757313942494d473847535231.pem")
@ -26,14 +28,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
rootCA := b.filemgr.FileDataSource(rootCAPath).GetFilename() rootCA := b.filemgr.FileDataSource(rootCAPath).GetFilename()
t.Run("insecure", func(t *testing.T) { t.Run("insecure", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(&config.Policy{ ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "http://example.com"), To: mustParseWeightedURLs(t, "http://example.com"),
}, *mustParseURL(t, "http://example.com")) }, *mustParseURL(t, "http://example.com"))
require.NoError(t, err) require.NoError(t, err)
assert.Nil(t, ts) assert.Nil(t, ts)
}) })
t.Run("host as sni", func(t *testing.T) { t.Run("host as sni", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(&config.Policy{ ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"), To: mustParseWeightedURLs(t, "https://example.com"),
}, *mustParseURL(t, "https://example.com")) }, *mustParseURL(t, "https://example.com"))
require.NoError(t, err) require.NoError(t, err)
@ -67,7 +69,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts) `, ts)
}) })
t.Run("tls_server_name as sni", func(t *testing.T) { t.Run("tls_server_name as sni", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(&config.Policy{ ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"), To: mustParseWeightedURLs(t, "https://example.com"),
TLSServerName: "use-this-name.example.com", TLSServerName: "use-this-name.example.com",
}, *mustParseURL(t, "https://example.com")) }, *mustParseURL(t, "https://example.com"))
@ -102,7 +104,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts) `, ts)
}) })
t.Run("tls_skip_verify", func(t *testing.T) { t.Run("tls_skip_verify", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(&config.Policy{ ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"), To: mustParseWeightedURLs(t, "https://example.com"),
TLSSkipVerify: true, TLSSkipVerify: true,
}, *mustParseURL(t, "https://example.com")) }, *mustParseURL(t, "https://example.com"))
@ -138,7 +140,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts) `, ts)
}) })
t.Run("custom ca", func(t *testing.T) { t.Run("custom ca", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(&config.Policy{ ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"), To: mustParseWeightedURLs(t, "https://example.com"),
TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}), TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}),
}, *mustParseURL(t, "https://example.com")) }, *mustParseURL(t, "https://example.com"))
@ -174,7 +176,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
}) })
t.Run("client certificate", func(t *testing.T) { t.Run("client certificate", func(t *testing.T) {
clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey) clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey)
ts, err := b.buildPolicyTransportSocket(&config.Policy{ ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"), To: mustParseWeightedURLs(t, "https://example.com"),
ClientCertificate: clientCert, ClientCertificate: clientCert,
}, *mustParseURL(t, "https://example.com")) }, *mustParseURL(t, "https://example.com"))
@ -219,11 +221,12 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
} }
func Test_buildCluster(t *testing.T) { func Test_buildCluster(t *testing.T) {
ctx := context.Background()
b := New("local-grpc", "local-http", filemgr.NewManager(), nil) b := New("local-grpc", "local-http", filemgr.NewManager(), nil)
rootCAPath, _ := getRootCertificateAuthority() rootCAPath, _ := getRootCertificateAuthority()
rootCA := b.filemgr.FileDataSource(rootCAPath).GetFilename() rootCA := b.filemgr.FileDataSource(rootCAPath).GetFilename()
t.Run("insecure", func(t *testing.T) { t.Run("insecure", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(&config.Policy{ endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"), To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -278,7 +281,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster) `, cluster)
}) })
t.Run("secure", func(t *testing.T) { t.Run("secure", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(&config.Policy{ endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
To: mustParseWeightedURLs(t, To: mustParseWeightedURLs(t,
"https://example.com", "https://example.com",
"https://example.com", "https://example.com",
@ -406,7 +409,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster) `, cluster)
}) })
t.Run("ip addresses", func(t *testing.T) { t.Run("ip addresses", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(&config.Policy{ endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"), To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -459,7 +462,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster) `, cluster)
}) })
t.Run("weights", func(t *testing.T) { t.Run("weights", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(&config.Policy{ endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"), To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -514,7 +517,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster) `, cluster)
}) })
t.Run("localhost", func(t *testing.T) { t.Run("localhost", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(&config.Policy{ endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "http://localhost"), To: mustParseWeightedURLs(t, "http://localhost"),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -557,7 +560,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster) `, cluster)
}) })
t.Run("outlier", func(t *testing.T) { t.Run("outlier", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(&config.Policy{ endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
To: mustParseWeightedURLs(t, "http://example.com"), To: mustParseWeightedURLs(t, "http://example.com"),
}) })
require.NoError(t, err) require.NoError(t, err)

View file

@ -3,6 +3,7 @@ package envoyconfig
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
@ -132,7 +133,10 @@ func buildAddress(hostport string, defaultPort int) *envoy_config_core_v3.Addres
} }
} }
func (b *Builder) envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate) *envoy_extensions_transport_sockets_tls_v3.TlsCertificate { func (b *Builder) envoyTLSCertificateFromGoTLSCertificate(
ctx context.Context,
cert *tls.Certificate,
) *envoy_extensions_transport_sockets_tls_v3.TlsCertificate {
envoyCert := &envoy_extensions_transport_sockets_tls_v3.TlsCertificate{} envoyCert := &envoy_extensions_transport_sockets_tls_v3.TlsCertificate{}
var chain bytes.Buffer var chain bytes.Buffer
for _, cbs := range cert.Certificate { for _, cbs := range cert.Certificate {
@ -153,7 +157,7 @@ func (b *Builder) envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate)
}, },
)) ))
} else { } else {
log.Warn().Err(err).Msg("failed to marshal private key for tls config") log.Warn(ctx).Err(err).Msg("failed to marshal private key for tls config")
} }
for _, scts := range cert.SignedCertificateTimestamps { for _, scts := range cert.SignedCertificateTimestamps {
envoyCert.SignedCertificateTimestamp = append(envoyCert.SignedCertificateTimestamp, envoyCert.SignedCertificateTimestamp = append(envoyCert.SignedCertificateTimestamp,
@ -185,10 +189,10 @@ func getRootCertificateAuthority() (string, error) {
} }
} }
if rootCABundle.value == "" { if rootCABundle.value == "" {
log.Error().Strs("known-locations", knownRootLocations). log.Error(context.TODO()).Strs("known-locations", knownRootLocations).
Msgf("no root certificates were found in any of the known locations") Msgf("no root certificates were found in any of the known locations")
} else { } else {
log.Info().Msgf("using %s as the system root certificate authority bundle", rootCABundle.value) log.Info(context.TODO()).Msgf("using %s as the system root certificate authority bundle", rootCABundle.value)
} }
}) })
if rootCABundle.value == "" { if rootCABundle.value == "" {

View file

@ -2,6 +2,7 @@
package filemgr package filemgr
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@ -34,7 +35,7 @@ func (mgr *Manager) BytesDataSource(fileName string, data []byte) *envoy_config_
fileName = fmt.Sprintf("%s-%x%s", fileName[:len(fileName)-len(ext)], h, ext) fileName = fmt.Sprintf("%s-%x%s", fileName[:len(fileName)-len(ext)], h, ext)
if err := os.MkdirAll(mgr.cfg.cacheDir, 0o700); err != nil { if err := os.MkdirAll(mgr.cfg.cacheDir, 0o700); err != nil {
log.Error().Err(err).Msg("filemgr: error creating cache directory, falling back to inline bytes") log.Error(context.TODO()).Err(err).Msg("filemgr: error creating cache directory, falling back to inline bytes")
return inlineBytes(data) return inlineBytes(data)
} }
@ -42,11 +43,11 @@ func (mgr *Manager) BytesDataSource(fileName string, data []byte) *envoy_config_
if _, err := os.Stat(filePath); os.IsNotExist(err) { if _, err := os.Stat(filePath); os.IsNotExist(err) {
err = ioutil.WriteFile(filePath, data, 0o600) err = ioutil.WriteFile(filePath, data, 0o600)
if err != nil { if err != nil {
log.Error().Err(err).Msg("filemgr: error writing cache file, falling back to inline bytes") log.Error(context.TODO()).Err(err).Msg("filemgr: error writing cache file, falling back to inline bytes")
return inlineBytes(data) return inlineBytes(data)
} }
} else if err != nil { } else if err != nil {
log.Error().Err(err).Msg("filemgr: error reading cache file, falling back to inline bytes") log.Error(context.TODO()).Err(err).Msg("filemgr: error reading cache file, falling back to inline bytes")
return inlineBytes(data) return inlineBytes(data)
} }
@ -62,7 +63,7 @@ func (mgr *Manager) ClearCache() {
return os.Remove(p) return os.Remove(p)
}) })
if err != nil { if err != nil {
log.Error().Err(err).Msg("failed to clear envoy file cache") log.Error(context.TODO()).Err(err).Msg("failed to clear envoy file cache")
} }
} }

View file

@ -1,6 +1,7 @@
package envoyconfig package envoyconfig
import ( import (
"context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net" "net"
@ -187,7 +188,7 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen
CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{ CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{
TlsParams: tlsParams, TlsParams: tlsParams,
TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{ TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{
b.envoyTLSCertificateFromGoTLSCertificate(cert), b.envoyTLSCertificateFromGoTLSCertificate(context.TODO(), cert),
}, },
AlpnProtocols: []string{"h2", "http/1.1"}, AlpnProtocols: []string{"h2", "http/1.1"},
}, },
@ -645,19 +646,21 @@ func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_con
} }
func (b *Builder) buildDownstreamTLSContext(cfg *config.Config, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { func (b *Builder) buildDownstreamTLSContext(cfg *config.Config, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
ctx := context.TODO()
certs, err := cfg.AllCertificates() certs, err := cfg.AllCertificates()
if err != nil { if err != nil {
log.Warn().Str("domain", domain).Err(err).Msg("failed to get all certificates from config") log.Warn(ctx).Str("domain", domain).Err(err).Msg("failed to get all certificates from config")
return nil return nil
} }
cert, err := cryptutil.GetCertificateForDomain(certs, domain) cert, err := cryptutil.GetCertificateForDomain(certs, domain)
if err != nil { if err != nil {
log.Warn().Str("domain", domain).Err(err).Msg("failed to get certificate for domain") log.Warn(ctx).Str("domain", domain).Err(err).Msg("failed to get certificate for domain")
return nil return nil
} }
envoyCert := b.envoyTLSCertificateFromGoTLSCertificate(cert) envoyCert := b.envoyTLSCertificateFromGoTLSCertificate(context.TODO(), cert)
return &envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext{ return &envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext{
CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{ CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{
TlsParams: tlsParams, TlsParams: tlsParams,

View file

@ -1,6 +1,7 @@
package config package config
import ( import (
"context"
"crypto/tls" "crypto/tls"
"net/http" "net/http"
"sync/atomic" "sync/atomic"
@ -8,6 +9,8 @@ import (
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/tripper" "github.com/pomerium/pomerium/internal/tripper"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/rs/zerolog"
) )
type httpTransport struct { type httpTransport struct {
@ -18,16 +21,19 @@ type httpTransport struct {
// NewHTTPTransport creates a new http transport. If CA or CAFile is set, the transport will // NewHTTPTransport creates a new http transport. If CA or CAFile is set, the transport will
// add the CA to system cert pool. // add the CA to system cert pool.
func NewHTTPTransport(src Source) http.RoundTripper { func NewHTTPTransport(src Source) http.RoundTripper {
ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context {
return c.Caller()
})
t := new(httpTransport) t := new(httpTransport)
t.underlying, _ = http.DefaultTransport.(*http.Transport) t.underlying, _ = http.DefaultTransport.(*http.Transport)
src.OnConfigChange(func(cfg *Config) { src.OnConfigChange(ctx, func(ctx context.Context, cfg *Config) {
t.update(cfg.Options) t.update(ctx, cfg.Options)
}) })
t.update(src.GetConfig().Options) t.update(ctx, src.GetConfig().Options)
return t return t
} }
func (t *httpTransport) update(options *Options) { func (t *httpTransport) update(ctx context.Context, options *Options) {
nt := new(http.Transport) nt := new(http.Transport)
if t.underlying != nil { if t.underlying != nil {
nt = t.underlying.Clone() nt = t.underlying.Clone()
@ -40,7 +46,7 @@ func (t *httpTransport) update(options *Options) {
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
} }
} else { } else {
log.Error().Err(err).Msg("config: error getting cert pool") log.Error(ctx).Err(err).Msg("config: error getting cert pool")
} }
} }
t.transport.Store(nt) t.transport.Store(nt)
@ -78,7 +84,7 @@ func NewPolicyHTTPTransport(options *Options, policy *Policy) http.RoundTripper
tlsClientConfig.MinVersion = tls.VersionTLS12 tlsClientConfig.MinVersion = tls.VersionTLS12
isCustomClientConfig = true isCustomClientConfig = true
} else { } else {
log.Error().Err(err).Msg("config: error getting ca cert pool") log.Error(context.TODO()).Err(err).Msg("config: error getting ca cert pool")
} }
} }
@ -89,7 +95,7 @@ func NewPolicyHTTPTransport(options *Options, policy *Policy) http.RoundTripper
tlsClientConfig.MinVersion = tls.VersionTLS12 tlsClientConfig.MinVersion = tls.VersionTLS12
isCustomClientConfig = true isCustomClientConfig = true
} else { } else {
log.Error().Err(err).Msg("config: error getting custom ca cert pool") log.Error(context.TODO()).Err(err).Msg("config: error getting custom ca cert pool")
} }
} }

View file

@ -1,6 +1,7 @@
package config package config
import ( import (
"context"
"sync" "sync"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
@ -12,10 +13,10 @@ type LogManager struct {
} }
// NewLogManager creates a new LogManager. // NewLogManager creates a new LogManager.
func NewLogManager(src Source) *LogManager { func NewLogManager(ctx context.Context, src Source) *LogManager {
mgr := &LogManager{} mgr := &LogManager{}
src.OnConfigChange(mgr.OnConfigChange) src.OnConfigChange(ctx, mgr.OnConfigChange)
mgr.OnConfigChange(src.GetConfig()) mgr.OnConfigChange(ctx, src.GetConfig())
return mgr return mgr
} }
@ -25,7 +26,7 @@ func (mgr *LogManager) Close() error {
} }
// OnConfigChange is called whenever configuration changes. // OnConfigChange is called whenever configuration changes.
func (mgr *LogManager) OnConfigChange(cfg *Config) { func (mgr *LogManager) OnConfigChange(ctx context.Context, cfg *Config) {
if cfg == nil || cfg.Options == nil { if cfg == nil || cfg.Options == nil {
return return
} }

View file

@ -1,6 +1,7 @@
package config package config
import ( import (
"context"
"net/http" "net/http"
"os" "os"
"sync" "sync"
@ -9,6 +10,8 @@ import (
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/telemetry" "github.com/pomerium/pomerium/internal/telemetry"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/rs/zerolog"
) )
// A MetricsManager manages metrics for a given configuration. // A MetricsManager manages metrics for a given configuration.
@ -22,11 +25,14 @@ type MetricsManager struct {
} }
// NewMetricsManager creates a new MetricsManager. // NewMetricsManager creates a new MetricsManager.
func NewMetricsManager(src Source) *MetricsManager { func NewMetricsManager(ctx context.Context, src Source) *MetricsManager {
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "metrics_manager")
})
mgr := &MetricsManager{} mgr := &MetricsManager{}
metrics.RegisterInfoMetrics() metrics.RegisterInfoMetrics()
src.OnConfigChange(mgr.OnConfigChange) src.OnConfigChange(ctx, mgr.OnConfigChange)
mgr.OnConfigChange(src.GetConfig()) mgr.OnConfigChange(ctx, src.GetConfig())
return mgr return mgr
} }
@ -36,7 +42,7 @@ func (mgr *MetricsManager) Close() error {
} }
// OnConfigChange updates the metrics manager when configuration is changed. // OnConfigChange updates the metrics manager when configuration is changed.
func (mgr *MetricsManager) OnConfigChange(cfg *Config) { func (mgr *MetricsManager) OnConfigChange(ctx context.Context, cfg *Config) {
mgr.mu.Lock() mgr.mu.Lock()
defer mgr.mu.Unlock() defer mgr.mu.Unlock()
@ -63,7 +69,7 @@ func (mgr *MetricsManager) updateInfo(cfg *Config) {
hostname, err := os.Hostname() hostname, err := os.Hostname()
if err != nil { if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to get OS hostname") log.Error(context.TODO()).Err(err).Msg("telemetry/metrics: failed to get OS hostname")
hostname = "__unknown__" hostname = "__unknown__"
} }
@ -84,13 +90,13 @@ func (mgr *MetricsManager) updateServer(cfg *Config) {
mgr.handler = nil mgr.handler = nil
if mgr.addr == "" { if mgr.addr == "" {
log.Info().Msg("metrics: http server disabled") log.Info(context.TODO()).Msg("metrics: http server disabled")
return return
} }
handler, err := metrics.PrometheusHandler(EnvoyAdminURL, mgr.installationID) handler, err := metrics.PrometheusHandler(EnvoyAdminURL, mgr.installationID)
if err != nil { if err != nil {
log.Error().Err(err).Msg("metrics: failed to create prometheus handler") log.Error(context.TODO()).Err(err).Msg("metrics: failed to create prometheus handler")
return return
} }

View file

@ -1,6 +1,7 @@
package config package config
import ( import (
"context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
@ -12,12 +13,13 @@ import (
) )
func TestMetricsManager(t *testing.T) { func TestMetricsManager(t *testing.T) {
ctx := context.Background()
src := NewStaticSource(&Config{ src := NewStaticSource(&Config{
Options: &Options{ Options: &Options{
MetricsAddr: "ADDRESS", MetricsAddr: "ADDRESS",
}, },
}) })
mgr := NewMetricsManager(src) mgr := NewMetricsManager(ctx, src)
srv1 := httptest.NewServer(mgr) srv1 := httptest.NewServer(mgr)
defer srv1.Close() defer srv1.Close()
srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -42,7 +44,7 @@ func TestMetricsManagerBasicAuth(t *testing.T) {
MetricsBasicAuth: base64.StdEncoding.EncodeToString([]byte("x:y")), MetricsBasicAuth: base64.StdEncoding.EncodeToString([]byte("x:y")),
}, },
}) })
mgr := NewMetricsManager(src) mgr := NewMetricsManager(context.Background(), src)
srv1 := httptest.NewServer(mgr) srv1 := httptest.NewServer(mgr)
defer srv1.Close() defer srv1.Close()

View file

@ -2,6 +2,7 @@ package config
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"errors" "errors"
@ -427,7 +428,7 @@ func (o *Options) viperIsSet(key string) bool {
// parseHeaders handles unmarshalling any custom headers correctly from the // parseHeaders handles unmarshalling any custom headers correctly from the
// environment or viper's parsed keys // environment or viper's parsed keys
func (o *Options) parseHeaders() error { func (o *Options) parseHeaders(ctx context.Context) error {
var headers map[string]string var headers map[string]string
if o.HeadersEnv != "" { if o.HeadersEnv != "" {
// Handle JSON by default via viper // Handle JSON by default via viper
@ -450,7 +451,7 @@ func (o *Options) parseHeaders() error {
} }
if o.viperIsSet("headers") { if o.viperIsSet("headers") {
log.Warn().Msg("config: headers has been renamed to set_response_headers, it will be removed in v0.16") log.Warn(ctx).Msg("config: headers has been renamed to set_response_headers, it will be removed in v0.16")
} }
// option was renamed from `headers` to `set_response_headers`. Both config settings are supported. // option was renamed from `headers` to `set_response_headers`. Both config settings are supported.
@ -507,6 +508,7 @@ func bindEnvs(o *Options, v *viper.Viper) error {
// Validate ensures the Options fields are valid, and hydrated. // Validate ensures the Options fields are valid, and hydrated.
func (o *Options) Validate() error { func (o *Options) Validate() error {
ctx := context.TODO()
if !IsValidService(o.Services) { if !IsValidService(o.Services) {
return fmt.Errorf("config: %s is an invalid service type", o.Services) return fmt.Errorf("config: %s is an invalid service type", o.Services)
} }
@ -591,7 +593,7 @@ func (o *Options) Validate() error {
return fmt.Errorf("config: failed to parse policy: %w", err) return fmt.Errorf("config: failed to parse policy: %w", err)
} }
if err := o.parseHeaders(); err != nil { if err := o.parseHeaders(ctx); err != nil {
return fmt.Errorf("config: failed to parse headers: %w", err) return fmt.Errorf("config: failed to parse headers: %w", err)
} }
@ -669,7 +671,7 @@ func (o *Options) Validate() error {
// GoogleCloudServerlessAuthenticationServiceAccount // GoogleCloudServerlessAuthenticationServiceAccount
if o.Provider == "google" && o.GoogleCloudServerlessAuthenticationServiceAccount == "" { if o.Provider == "google" && o.GoogleCloudServerlessAuthenticationServiceAccount == "" {
o.GoogleCloudServerlessAuthenticationServiceAccount = o.ServiceAccount o.GoogleCloudServerlessAuthenticationServiceAccount = o.ServiceAccount
log.Info().Msg("defaulting to idp_service_account for google_cloud_serverless_authentication_service_account") log.Info(ctx).Msg("defaulting to idp_service_account for google_cloud_serverless_authentication_service_account")
} }
// strip quotes from redirect address (#811) // strip quotes from redirect address (#811)
@ -683,7 +685,7 @@ func (o *Options) Validate() error {
switch o.Provider { switch o.Provider {
case azure.Name, github.Name, gitlab.Name, google.Name, okta.Name, onelogin.Name: case azure.Name, github.Name, gitlab.Name, google.Name, okta.Name, onelogin.Name:
if len(o.Scopes) > 0 { if len(o.Scopes) > 0 {
log.Warn().Msg(idpCustomScopesWarnMsg) log.Warn(ctx).Msg(idpCustomScopesWarnMsg)
} }
default: default:
} }

View file

@ -1,6 +1,7 @@
package config package config
import ( import (
"context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -193,7 +194,7 @@ func Test_parseHeaders(t *testing.T) {
o.viperSet("headers", tt.viperHeaders) o.viperSet("headers", tt.viperHeaders)
o.viperSet("HeadersEnv", tt.envHeaders) o.viperSet("HeadersEnv", tt.envHeaders)
o.HeadersEnv = tt.envHeaders o.HeadersEnv = tt.envHeaders
err := o.parseHeaders() err := o.parseHeaders(context.Background())
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Error condition unexpected: err=%s", err) t.Errorf("Error condition unexpected: err=%s", err)

View file

@ -1,16 +1,18 @@
package config package config
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"sync" "sync"
octrace "go.opencensus.io/trace"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry" "github.com/pomerium/pomerium/internal/telemetry"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/rs/zerolog"
octrace "go.opencensus.io/trace"
) )
// TracingOptions are the options for tracing. // TracingOptions are the options for tracing.
@ -60,10 +62,13 @@ type TraceManager struct {
} }
// NewTraceManager creates a new TraceManager. // NewTraceManager creates a new TraceManager.
func NewTraceManager(src Source) *TraceManager { func NewTraceManager(ctx context.Context, src Source) *TraceManager {
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "trace_manager")
})
mgr := &TraceManager{} mgr := &TraceManager{}
src.OnConfigChange(mgr.OnConfigChange) src.OnConfigChange(ctx, mgr.OnConfigChange)
mgr.OnConfigChange(src.GetConfig()) mgr.OnConfigChange(ctx, src.GetConfig())
return mgr return mgr
} }
@ -79,18 +84,18 @@ func (mgr *TraceManager) Close() error {
} }
// OnConfigChange updates the manager whenever the configuration is changed. // OnConfigChange updates the manager whenever the configuration is changed.
func (mgr *TraceManager) OnConfigChange(cfg *Config) { func (mgr *TraceManager) OnConfigChange(ctx context.Context, cfg *Config) {
mgr.mu.Lock() mgr.mu.Lock()
defer mgr.mu.Unlock() defer mgr.mu.Unlock()
traceOpts, err := NewTracingOptions(cfg.Options) traceOpts, err := NewTracingOptions(cfg.Options)
if err != nil { if err != nil {
log.Error().Err(err).Msg("trace: failed to build tracing options") log.Error(ctx).Err(err).Msg("trace: failed to build tracing options")
return return
} }
if reflect.DeepEqual(traceOpts, mgr.traceOpts) { if reflect.DeepEqual(traceOpts, mgr.traceOpts) {
log.Debug().Msg("no change detected in trace options") log.Debug(ctx).Msg("no change detected in trace options")
return return
} }
mgr.traceOpts = traceOpts mgr.traceOpts = traceOpts
@ -104,11 +109,11 @@ func (mgr *TraceManager) OnConfigChange(cfg *Config) {
return return
} }
log.Info().Interface("options", traceOpts).Msg("trace: starting exporter") log.Info(ctx).Interface("options", traceOpts).Msg("trace: starting exporter")
mgr.exporter, err = trace.RegisterTracing(traceOpts) mgr.exporter, err = trace.RegisterTracing(traceOpts)
if err != nil { if err != nil {
log.Error().Err(err).Msg("trace: failed to register exporter") log.Error(ctx).Err(err).Msg("trace: failed to register exporter")
return return
} }
} }

View file

@ -123,12 +123,12 @@ func TestTraceManager(t *testing.T) {
TracingSampleRate: 1, TracingSampleRate: 1,
}}) }})
_ = NewTraceManager(src) _ = NewTraceManager(ctx, src)
_, span := trace.StartSpan(ctx, "Example") _, span := trace.StartSpan(ctx, "Example")
span.End() span.End()
src.SetConfig(&Config{Options: &Options{ src.SetConfig(ctx, &Config{Options: &Options{
TracingProvider: "zipkin", TracingProvider: "zipkin",
ZipkinEndpoint: srv2.URL, ZipkinEndpoint: srv2.URL,
TracingSampleRate: 1, TracingSampleRate: 1,

View file

@ -104,13 +104,13 @@ func New(cfg *config.Config) (*DataBroker, error) {
} }
// OnConfigChange is called whenever configuration is changed. // OnConfigChange is called whenever configuration is changed.
func (c *DataBroker) OnConfigChange(cfg *config.Config) { func (c *DataBroker) OnConfigChange(ctx context.Context, cfg *config.Config) {
err := c.update(cfg) err := c.update(cfg)
if err != nil { if err != nil {
log.Error().Err(err).Msg("databroker: error updating configuration") log.Error(ctx).Err(err).Msg("databroker: error updating configuration")
} }
c.dataBrokerServer.OnConfigChange(cfg) c.dataBrokerServer.OnConfigChange(ctx, cfg)
} }
// Register registers all the gRPC services with the given server. // Register registers all the gRPC services with the given server.

View file

@ -27,7 +27,7 @@ func newDataBrokerServer(cfg *config.Config) *dataBrokerServer {
} }
// OnConfigChange updates the underlying databroker server whenever configuration is changed. // OnConfigChange updates the underlying databroker server whenever configuration is changed.
func (srv *dataBrokerServer) OnConfigChange(cfg *config.Config) { func (srv *dataBrokerServer) OnConfigChange(ctx context.Context, cfg *config.Config) {
srv.server.UpdateConfig(srv.getOptions(cfg)...) srv.server.UpdateConfig(srv.getOptions(cfg)...)
srv.setKey(cfg) srv.setKey(cfg)
} }

View file

@ -5,7 +5,8 @@ import (
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
"github.com/rs/zerolog/log" "github.com/pomerium/pomerium/internal/log"
"golang.org/x/net/publicsuffix" "golang.org/x/net/publicsuffix"
) )
@ -51,6 +52,6 @@ type loggingRoundTripper struct {
func (rt *loggingRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err error) { func (rt *loggingRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err error) {
res, err = rt.RoundTripper.RoundTrip(req) res, err = rt.RoundTripper.RoundTrip(req)
log.Debug().Str("method", req.Method).Str("url", req.URL.String()).Msg("http request") log.Debug(req.Context()).Str("method", req.Method).Str("url", req.URL.String()).Msg("http request")
return res, err return res, err
} }

View file

@ -8,7 +8,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"github.com/rs/zerolog/log" "github.com/pomerium/pomerium/internal/log"
) )
type cmdOption func(*exec.Cmd) type cmdOption func(*exec.Cmd)
@ -53,7 +53,7 @@ func run(ctx context.Context, name string, options ...cmdOption) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to create stderr pipe for %s: %w", name, err) return fmt.Errorf("failed to create stderr pipe for %s: %w", name, err)
} }
go cmdLogger(stderr) go cmdLogger(ctx, stderr)
defer stderr.Close() defer stderr.Close()
} }
if cmd.Stdout == nil { if cmd.Stdout == nil {
@ -61,17 +61,17 @@ func run(ctx context.Context, name string, options ...cmdOption) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to create stdout pipe for %s: %w", name, err) return fmt.Errorf("failed to create stdout pipe for %s: %w", name, err)
} }
go cmdLogger(stdout) go cmdLogger(ctx, stdout)
defer stdout.Close() defer stdout.Close()
} }
log.Debug().Strs("args", cmd.Args).Msgf("running %s", name) log.Debug(ctx).Strs("args", cmd.Args).Msgf("running %s", name)
return cmd.Run() return cmd.Run()
} }
func cmdLogger(rdr io.Reader) { func cmdLogger(ctx context.Context, rdr io.Reader) {
s := bufio.NewScanner(rdr) s := bufio.NewScanner(rdr)
for s.Scan() { for s.Scan() {
log.Debug().Msg(s.Text()) log.Debug(ctx).Msg(s.Text())
} }
} }

View file

@ -15,9 +15,9 @@ import (
"time" "time"
"github.com/google/go-jsonnet" "github.com/google/go-jsonnet"
"github.com/rs/zerolog/log"
"github.com/pomerium/pomerium/integration/internal/netutil" "github.com/pomerium/pomerium/integration/internal/netutil"
"github.com/pomerium/pomerium/internal/log"
) )
var requiredDeployments = []string{ var requiredDeployments = []string{
@ -180,7 +180,7 @@ func applyManifests(ctx context.Context, jsonsrc string) error {
return fmt.Errorf("error applying manifests: %w", err) return fmt.Errorf("error applying manifests: %w", err)
} }
log.Info().Msg("waiting for deployments to come up") log.Info(ctx).Msg("waiting for deployments to come up")
ctx, clearTimeout := context.WithTimeout(ctx, 15*time.Minute) ctx, clearTimeout := context.WithTimeout(ctx, 15*time.Minute)
defer clearTimeout() defer clearTimeout()
ticker := time.NewTicker(time.Second * 5) ticker := time.NewTicker(time.Second * 5)
@ -218,7 +218,7 @@ func applyManifests(ctx context.Context, jsonsrc string) error {
for _, dep := range requiredDeployments { for _, dep := range requiredDeployments {
if byName[dep] < 1 { if byName[dep] < 1 {
done = false done = false
log.Warn().Str("deployment", dep).Msg("deployment is not ready yet") log.Warn(ctx).Str("deployment", dep).Msg("deployment is not ready yet")
} }
} }
if done { if done {
@ -233,7 +233,7 @@ func applyManifests(ctx context.Context, jsonsrc string) error {
<-ticker.C <-ticker.C
} }
time.Sleep(time.Minute) time.Sleep(time.Minute)
log.Info().Msg("all deployments are ready") log.Info(ctx).Msg("all deployments are ready")
return nil return nil
} }

View file

@ -69,15 +69,15 @@ func newManager(ctx context.Context,
if err != nil { if err != nil {
return nil, err return nil, err
} }
mgr.src.OnConfigChange(func(cfg *config.Config) { mgr.src.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) {
err := mgr.update(cfg) err := mgr.update(cfg)
if err != nil { if err != nil {
log.Error().Err(err).Msg("autocert: error updating config") log.Error(ctx).Err(err).Msg("autocert: error updating config")
return return
} }
cfg = mgr.GetConfig() cfg = mgr.GetConfig()
mgr.Trigger(cfg) mgr.Trigger(ctx, cfg)
}) })
go func() { go func() {
ticker := time.NewTicker(checkInterval) ticker := time.NewTicker(checkInterval)
@ -90,7 +90,7 @@ func newManager(ctx context.Context,
case <-ticker.C: case <-ticker.C:
err := mgr.renewConfigCerts() err := mgr.renewConfigCerts()
if err != nil { if err != nil {
log.Error().Err(err).Msg("autocert: error updating config") log.Error(context.TODO()).Err(err).Msg("autocert: error updating config")
return return
} }
} }
@ -153,7 +153,7 @@ func (mgr *Manager) renewConfigCerts() error {
} }
mgr.config = cfg mgr.config = cfg
mgr.Trigger(cfg) mgr.Trigger(context.TODO(), cfg)
return nil return nil
} }
@ -172,10 +172,10 @@ func (mgr *Manager) update(cfg *config.Config) error {
func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.Certificate, error) { func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.Certificate, error) {
cert, err := cm.CacheManagedCertificate(domain) cert, err := cm.CacheManagedCertificate(domain)
if err != nil { if err != nil {
log.Info().Str("domain", domain).Msg("obtaining certificate") log.Info(context.TODO()).Str("domain", domain).Msg("obtaining certificate")
err = cm.ObtainCert(context.Background(), domain, false) err = cm.ObtainCert(context.Background(), domain, false)
if err != nil { if err != nil {
log.Error().Err(err).Msg("autocert failed to obtain client certificate") log.Error(context.TODO()).Err(err).Msg("autocert failed to obtain client certificate")
return certmagic.Certificate{}, errObtainCertFailed return certmagic.Certificate{}, errObtainCertFailed
} }
metrics.RecordAutocertRenewal() metrics.RecordAutocertRenewal()
@ -187,13 +187,13 @@ func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.C
// renewCert attempts to renew given certificate. // renewCert attempts to renew given certificate.
func (mgr *Manager) renewCert(domain string, cert certmagic.Certificate, cm *certmagic.Config) (certmagic.Certificate, error) { func (mgr *Manager) renewCert(domain string, cert certmagic.Certificate, cm *certmagic.Config) (certmagic.Certificate, error) {
expired := time.Now().After(cert.Leaf.NotAfter) expired := time.Now().After(cert.Leaf.NotAfter)
log.Info().Str("domain", domain).Msg("renewing certificate") log.Info(context.TODO()).Str("domain", domain).Msg("renewing certificate")
err := cm.RenewCert(context.Background(), domain, false) err := cm.RenewCert(context.Background(), domain, false)
if err != nil { if err != nil {
if expired { if expired {
return certmagic.Certificate{}, errRenewCertFailed return certmagic.Certificate{}, errRenewCertFailed
} }
log.Warn().Err(err).Msg("renew client certificated failed, use existing cert") log.Warn(context.TODO()).Err(err).Msg("renew client certificated failed, use existing cert")
} }
return cm.CacheManagedCertificate(domain) return cm.CacheManagedCertificate(domain)
} }
@ -220,11 +220,11 @@ func (mgr *Manager) updateAutocert(cfg *config.Config) error {
return fmt.Errorf("autocert: failed to renew client certificate: %w", err) return fmt.Errorf("autocert: failed to renew client certificate: %w", err)
} }
if err != nil { if err != nil {
log.Error().Err(err).Msg("autocert: failed to obtain client certificate") log.Error(context.TODO()).Err(err).Msg("autocert: failed to obtain client certificate")
continue continue
} }
log.Info().Strs("names", cert.Names).Msg("autocert: added certificate") log.Info(context.TODO()).Strs("names", cert.Names).Msg("autocert: added certificate")
cfg.AutoCertificates = append(cfg.AutoCertificates, cert.Certificate) cfg.AutoCertificates = append(cfg.AutoCertificates, cert.Certificate)
} }
@ -260,10 +260,10 @@ func (mgr *Manager) updateServer(cfg *config.Config) {
}), }),
} }
go func() { go func() {
log.Info().Str("addr", hsrv.Addr).Msg("starting http redirect server") log.Info(context.TODO()).Str("addr", hsrv.Addr).Msg("starting http redirect server")
err := hsrv.ListenAndServe() err := hsrv.ListenAndServe()
if err != nil { if err != nil {
log.Error().Err(err).Msg("failed to run http redirect server") log.Error(context.TODO()).Err(err).Msg("failed to run http redirect server")
} }
}() }()
mgr.srv = hsrv mgr.srv = hsrv

View file

@ -12,6 +12,7 @@ import (
"syscall" "syscall"
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3" envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/pomerium/pomerium/authenticate" "github.com/pomerium/pomerium/authenticate"
@ -32,7 +33,7 @@ import (
// Run runs the main pomerium application. // Run runs the main pomerium application.
func Run(ctx context.Context, configFile string) error { func Run(ctx context.Context, configFile string) error {
log.Info().Str("version", version.FullVersion()).Msg("cmd/pomerium") log.Info(ctx).Str("version", version.FullVersion()).Msg("cmd/pomerium")
var src config.Source var src config.Source
@ -41,8 +42,8 @@ func Run(ctx context.Context, configFile string) error {
return err return err
} }
src = databroker.NewConfigSource(src) src = databroker.NewConfigSource(ctx, src)
logMgr := config.NewLogManager(src) logMgr := config.NewLogManager(ctx, src)
defer logMgr.Close() defer logMgr.Close()
// trigger changes when underlying files are changed // trigger changes when underlying files are changed
@ -56,9 +57,9 @@ func Run(ctx context.Context, configFile string) error {
// override the default http transport so we can use the custom CA in the TLS client config (#1570) // override the default http transport so we can use the custom CA in the TLS client config (#1570)
http.DefaultTransport = config.NewHTTPTransport(src) http.DefaultTransport = config.NewHTTPTransport(src)
metricsMgr := config.NewMetricsManager(src) metricsMgr := config.NewMetricsManager(ctx, src)
defer metricsMgr.Close() defer metricsMgr.Close()
traceMgr := config.NewTraceManager(src) traceMgr := config.NewTraceManager(ctx, src)
defer traceMgr.Close() defer traceMgr.Close()
// setup the control plane // setup the control plane
@ -66,43 +67,46 @@ func Run(ctx context.Context, configFile string) error {
if err != nil { if err != nil {
return fmt.Errorf("error creating control plane: %w", err) return fmt.Errorf("error creating control plane: %w", err)
} }
src.OnConfigChange(func(cfg *config.Config) { src.OnConfigChange(ctx,
if err := controlPlane.OnConfigChange(cfg); err != nil { func(ctx context.Context, cfg *config.Config) {
log.Error().Err(err).Msg("config change") if err := controlPlane.OnConfigChange(ctx, cfg); err != nil {
} log.Error(ctx).Err(err).Msg("config change")
}) }
})
if err = controlPlane.OnConfigChange(src.GetConfig()); err != nil { if err = controlPlane.OnConfigChange(log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("config_file_source", configFile).Bool("bootstrap", true)
}), src.GetConfig()); err != nil {
return fmt.Errorf("applying config: %w", err) return fmt.Errorf("applying config: %w", err)
} }
_, grpcPort, _ := net.SplitHostPort(controlPlane.GRPCListener.Addr().String()) _, grpcPort, _ := net.SplitHostPort(controlPlane.GRPCListener.Addr().String())
_, httpPort, _ := net.SplitHostPort(controlPlane.HTTPListener.Addr().String()) _, httpPort, _ := net.SplitHostPort(controlPlane.HTTPListener.Addr().String())
log.Info().Str("port", grpcPort).Msg("gRPC server started") log.Info(ctx).Str("port", grpcPort).Msg("gRPC server started")
log.Info().Str("port", httpPort).Msg("HTTP server started") log.Info(ctx).Str("port", httpPort).Msg("HTTP server started")
// create envoy server // create envoy server
envoyServer, err := envoy.NewServer(src, grpcPort, httpPort, controlPlane.Builder) envoyServer, err := envoy.NewServer(ctx, src, grpcPort, httpPort, controlPlane.Builder)
if err != nil { if err != nil {
return fmt.Errorf("error creating envoy server: %w", err) return fmt.Errorf("error creating envoy server: %w", err)
} }
defer envoyServer.Close() defer envoyServer.Close()
// add services // add services
if err := setupAuthenticate(src, controlPlane); err != nil { if err := setupAuthenticate(ctx, src, controlPlane); err != nil {
return err return err
} }
var authorizeServer *authorize.Authorize var authorizeServer *authorize.Authorize
if config.IsAuthorize(src.GetConfig().Options.Services) { if config.IsAuthorize(src.GetConfig().Options.Services) {
authorizeServer, err = setupAuthorize(src, controlPlane) authorizeServer, err = setupAuthorize(ctx, src, controlPlane)
if err != nil { if err != nil {
return err return err
} }
} }
var dataBrokerServer *databroker_service.DataBroker var dataBrokerServer *databroker_service.DataBroker
if config.IsDataBroker(src.GetConfig().Options.Services) { if config.IsDataBroker(src.GetConfig().Options.Services) {
dataBrokerServer, err = setupDataBroker(src, controlPlane) dataBrokerServer, err = setupDataBroker(ctx, src, controlPlane)
if err != nil { if err != nil {
return fmt.Errorf("setting up databroker: %w", err) return fmt.Errorf("setting up databroker: %w", err)
} }
@ -112,10 +116,10 @@ func Run(ctx context.Context, configFile string) error {
} }
} }
if err = setupRegistryReporter(src); err != nil { if err = setupRegistryReporter(ctx, src); err != nil {
return fmt.Errorf("setting up registry reporter: %w", err) return fmt.Errorf("setting up registry reporter: %w", err)
} }
if err := setupProxy(src, controlPlane); err != nil { if err := setupProxy(ctx, src, controlPlane); err != nil {
return err return err
} }
@ -159,7 +163,7 @@ func Run(ctx context.Context, configFile string) error {
return eg.Wait() return eg.Wait()
} }
func setupAuthenticate(src config.Source, controlPlane *controlplane.Server) error { func setupAuthenticate(ctx context.Context, src config.Source, controlPlane *controlplane.Server) error {
if !config.IsAuthenticate(src.GetConfig().Options.Services) { if !config.IsAuthenticate(src.GetConfig().Options.Services) {
return nil return nil
} }
@ -174,56 +178,56 @@ func setupAuthenticate(src config.Source, controlPlane *controlplane.Server) err
return fmt.Errorf("error getting authenticate URL: %w", err) return fmt.Errorf("error getting authenticate URL: %w", err)
} }
src.OnConfigChange(svc.OnConfigChange) src.OnConfigChange(ctx, svc.OnConfigChange)
svc.OnConfigChange(src.GetConfig()) svc.OnConfigChange(ctx, src.GetConfig())
host := urlutil.StripPort(authenticateURL.Host) host := urlutil.StripPort(authenticateURL.Host)
sr := controlPlane.HTTPRouter.Host(host).Subrouter() sr := controlPlane.HTTPRouter.Host(host).Subrouter()
svc.Mount(sr) svc.Mount(sr)
log.Info().Str("host", host).Msg("enabled authenticate service") log.Info(context.TODO()).Str("host", host).Msg("enabled authenticate service")
return nil return nil
} }
func setupAuthorize(src config.Source, controlPlane *controlplane.Server) (*authorize.Authorize, error) { func setupAuthorize(ctx context.Context, src config.Source, controlPlane *controlplane.Server) (*authorize.Authorize, error) {
svc, err := authorize.New(src.GetConfig()) svc, err := authorize.New(src.GetConfig())
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating authorize service: %w", err) return nil, fmt.Errorf("error creating authorize service: %w", err)
} }
envoy_service_auth_v3.RegisterAuthorizationServer(controlPlane.GRPCServer, svc) envoy_service_auth_v3.RegisterAuthorizationServer(controlPlane.GRPCServer, svc)
log.Info().Msg("enabled authorize service") log.Info(context.TODO()).Msg("enabled authorize service")
src.OnConfigChange(svc.OnConfigChange) src.OnConfigChange(ctx, svc.OnConfigChange)
svc.OnConfigChange(src.GetConfig()) svc.OnConfigChange(ctx, src.GetConfig())
return svc, nil return svc, nil
} }
func setupDataBroker(src config.Source, controlPlane *controlplane.Server) (*databroker_service.DataBroker, error) { func setupDataBroker(ctx context.Context, src config.Source, controlPlane *controlplane.Server) (*databroker_service.DataBroker, error) {
svc, err := databroker_service.New(src.GetConfig()) svc, err := databroker_service.New(src.GetConfig())
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating databroker service: %w", err) return nil, fmt.Errorf("error creating databroker service: %w", err)
} }
svc.Register(controlPlane.GRPCServer) svc.Register(controlPlane.GRPCServer)
log.Info().Msg("enabled databroker service") log.Info(context.TODO()).Msg("enabled databroker service")
src.OnConfigChange(svc.OnConfigChange) src.OnConfigChange(ctx, svc.OnConfigChange)
svc.OnConfigChange(src.GetConfig()) svc.OnConfigChange(ctx, src.GetConfig())
return svc, nil return svc, nil
} }
func setupRegistryServer(src config.Source, controlPlane *controlplane.Server) error { func setupRegistryServer(src config.Source, controlPlane *controlplane.Server) error {
svc := registry.NewInMemoryServer(context.TODO(), registryTTL) svc := registry.NewInMemoryServer(context.TODO(), registryTTL)
registry_pb.RegisterRegistryServer(controlPlane.GRPCServer, svc) registry_pb.RegisterRegistryServer(controlPlane.GRPCServer, svc)
log.Info().Msg("enabled service discovery") log.Info(context.TODO()).Msg("enabled service discovery")
return nil return nil
} }
func setupRegistryReporter(src config.Source) error { func setupRegistryReporter(ctx context.Context, src config.Source) error {
reporter := new(registry.Reporter) reporter := new(registry.Reporter)
src.OnConfigChange(reporter.OnConfigChange) src.OnConfigChange(ctx, reporter.OnConfigChange)
reporter.OnConfigChange(src.GetConfig()) reporter.OnConfigChange(ctx, src.GetConfig())
return nil return nil
} }
func setupProxy(src config.Source, controlPlane *controlplane.Server) error { func setupProxy(ctx context.Context, src config.Source, controlPlane *controlplane.Server) error {
if !config.IsProxy(src.GetConfig().Options.Services) { if !config.IsProxy(src.GetConfig().Options.Services) {
return nil return nil
} }
@ -234,9 +238,9 @@ func setupProxy(src config.Source, controlPlane *controlplane.Server) error {
} }
controlPlane.HTTPRouter.PathPrefix("/").Handler(svc) controlPlane.HTTPRouter.PathPrefix("/").Handler(svc)
log.Info().Msg("enabled proxy service") log.Info(context.TODO()).Msg("enabled proxy service")
src.OnConfigChange(svc.OnConfigChange) src.OnConfigChange(ctx, svc.OnConfigChange)
svc.OnConfigChange(src.GetConfig()) svc.OnConfigChange(ctx, src.GetConfig())
return nil return nil
} }

View file

@ -19,7 +19,7 @@ func (srv *Server) StreamAccessLogs(stream envoy_service_accesslog_v3.AccessLogS
for { for {
msg, err := stream.Recv() msg, err := stream.Recv()
if err != nil { if err != nil {
log.Error().Err(err).Msg("access log stream error, disconnecting") log.Error(stream.Context()).Err(err).Msg("access log stream error, disconnecting")
return err return err
} }
@ -27,9 +27,9 @@ func (srv *Server) StreamAccessLogs(stream envoy_service_accesslog_v3.AccessLogS
reqPath := entry.GetRequest().GetPath() reqPath := entry.GetRequest().GetPath()
var evt *zerolog.Event var evt *zerolog.Event
if reqPath == "/ping" || reqPath == "/healthz" { if reqPath == "/ping" || reqPath == "/healthz" {
evt = log.Debug() evt = log.Debug(stream.Context())
} else { } else {
evt = log.Info() evt = log.Info(stream.Context())
} }
// common properties // common properties
evt = evt.Str("service", "envoy") evt = evt.Str("service", "envoy")

View file

@ -9,6 +9,7 @@ import (
envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@ -106,7 +107,11 @@ func NewServer(name string, metricsMgr *config.MetricsManager) (*Server, error)
srv.reproxy, srv.reproxy,
) )
res, err := srv.buildDiscoveryResources() ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context {
return c.Str("server_name", name)
})
res, err := srv.buildDiscoveryResources(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -123,7 +128,7 @@ func (srv *Server) Run(ctx context.Context) error {
// start the gRPC server // start the gRPC server
eg.Go(func() error { eg.Go(func() error {
log.Info().Str("addr", srv.GRPCListener.Addr().String()).Msg("starting control-plane gRPC server") log.Info(ctx).Str("addr", srv.GRPCListener.Addr().String()).Msg("starting control-plane gRPC server")
return srv.GRPCServer.Serve(srv.GRPCListener) return srv.GRPCServer.Serve(srv.GRPCListener)
}) })
@ -160,7 +165,7 @@ func (srv *Server) Run(ctx context.Context) error {
// start the HTTP server // start the HTTP server
eg.Go(func() error { eg.Go(func() error {
log.Info().Str("addr", srv.HTTPListener.Addr().String()).Msg("starting control-plane HTTP server") log.Info(ctx).Str("addr", srv.HTTPListener.Addr().String()).Msg("starting control-plane HTTP server")
return hsrv.Serve(srv.HTTPListener) return hsrv.Serve(srv.HTTPListener)
}) })
@ -178,17 +183,17 @@ func (srv *Server) Run(ctx context.Context) error {
} }
// OnConfigChange updates the pomerium config options. // OnConfigChange updates the pomerium config options.
func (srv *Server) OnConfigChange(cfg *config.Config) error { func (srv *Server) OnConfigChange(ctx context.Context, cfg *config.Config) error {
srv.reproxy.Update(cfg) srv.reproxy.Update(ctx, cfg)
prev := srv.currentConfig.Load() prev := srv.currentConfig.Load()
srv.currentConfig.Store(versionedConfig{ srv.currentConfig.Store(versionedConfig{
Config: cfg, Config: cfg,
version: prev.version + 1, version: prev.version + 1,
}) })
res, err := srv.buildDiscoveryResources() res, err := srv.buildDiscoveryResources(ctx)
if err != nil { if err != nil {
return err return err
} }
srv.xdsmgr.Update(res) srv.xdsmgr.Update(ctx, res)
return nil return nil
} }

View file

@ -1,6 +1,7 @@
package controlplane package controlplane
import ( import (
"context"
"encoding/hex" "encoding/hex"
envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
@ -14,11 +15,11 @@ const (
listenerTypeURL = "type.googleapis.com/envoy.config.listener.v3.Listener" listenerTypeURL = "type.googleapis.com/envoy.config.listener.v3.Listener"
) )
func (srv *Server) buildDiscoveryResources() (map[string][]*envoy_service_discovery_v3.Resource, error) { func (srv *Server) buildDiscoveryResources(ctx context.Context) (map[string][]*envoy_service_discovery_v3.Resource, error) {
resources := map[string][]*envoy_service_discovery_v3.Resource{} resources := map[string][]*envoy_service_discovery_v3.Resource{}
cfg := srv.currentConfig.Load() cfg := srv.currentConfig.Load()
clusters, err := srv.Builder.BuildClusters(cfg.Config) clusters, err := srv.Builder.BuildClusters(ctx, cfg.Config)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -2,6 +2,7 @@
package xdsmgr package xdsmgr
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"sync" "sync"
@ -51,7 +52,7 @@ func (mgr *Manager) DeltaAggregatedResources(
stateByTypeURL := map[string]*streamState{} stateByTypeURL := map[string]*streamState{}
getDeltaResponse := func(typeURL string) *envoy_service_discovery_v3.DeltaDiscoveryResponse { getDeltaResponse := func(ctx context.Context, typeURL string) *envoy_service_discovery_v3.DeltaDiscoveryResponse {
mgr.mu.Lock() mgr.mu.Lock()
defer mgr.mu.Unlock() defer mgr.mu.Unlock()
@ -85,7 +86,7 @@ func (mgr *Manager) DeltaAggregatedResources(
return res return res
} }
handleDeltaRequest := func(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) { handleDeltaRequest := func(ctx context.Context, req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
mgr.mu.Lock() mgr.mu.Lock()
defer mgr.mu.Unlock() defer mgr.mu.Unlock()
@ -109,8 +110,9 @@ func (mgr *Manager) DeltaAggregatedResources(
case req.GetErrorDetail() != nil: case req.GetErrorDetail() != nil:
// a NACK // a NACK
bs, _ := json.Marshal(req.ErrorDetail.Details) bs, _ := json.Marshal(req.ErrorDetail.Details)
log.Error(). log.Error(ctx).
Err(errors.New(req.ErrorDetail.Message)). Err(errors.New(req.ErrorDetail.Message)).
Str("nonce", req.ResponseNonce).
Int32("code", req.ErrorDetail.Code). Int32("code", req.ErrorDetail.Code).
RawJSON("details", bs).Msg("error applying configuration") RawJSON("details", bs).Msg("error applying configuration")
// - set the client resource versions to the current resource versions // - set the client resource versions to the current resource versions
@ -121,12 +123,18 @@ func (mgr *Manager) DeltaAggregatedResources(
case req.GetResponseNonce() == mgr.nonce: case req.GetResponseNonce() == mgr.nonce:
// an ACK for the last response // an ACK for the last response
// - set the client resource versions to the current resource versions // - set the client resource versions to the current resource versions
log.Debug(ctx).
Str("nonce", req.ResponseNonce).
Msg("ACK")
state.clientResourceVersions = make(map[string]string) state.clientResourceVersions = make(map[string]string)
for _, resource := range mgr.resources[req.GetTypeUrl()] { for _, resource := range mgr.resources[req.GetTypeUrl()] {
state.clientResourceVersions[resource.Name] = resource.Version state.clientResourceVersions[resource.Name] = resource.Version
} }
default: default:
// an ACK for a response that's not the last response // an ACK for a response that's not the last response
log.Debug(ctx).
Str("nonce", req.ResponseNonce).
Msg("stale ACK")
} }
// update subscriptions // update subscriptions
@ -168,15 +176,16 @@ func (mgr *Manager) DeltaAggregatedResources(
}) })
// 2. handle incoming requests or resource changes // 2. handle incoming requests or resource changes
eg.Go(func() error { eg.Go(func() error {
changeCtx := ctx
for { for {
var typeURLs []string var typeURLs []string
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case req := <-incoming: case req := <-incoming:
handleDeltaRequest(req) handleDeltaRequest(changeCtx, req)
typeURLs = []string{req.GetTypeUrl()} typeURLs = []string{req.GetTypeUrl()}
case <-ch: case changeCtx = <-ch:
mgr.mu.Lock() mgr.mu.Lock()
for typeURL := range mgr.resources { for typeURL := range mgr.resources {
typeURLs = append(typeURLs, typeURL) typeURLs = append(typeURLs, typeURL)
@ -185,7 +194,7 @@ func (mgr *Manager) DeltaAggregatedResources(
} }
for _, typeURL := range typeURLs { for _, typeURL := range typeURLs {
res := getDeltaResponse(typeURL) res := getDeltaResponse(changeCtx, typeURL)
if res == nil { if res == nil {
continue continue
} }
@ -194,6 +203,10 @@ func (mgr *Manager) DeltaAggregatedResources(
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case outgoing <- res: case outgoing <- res:
log.Info(changeCtx).
Str("nounce", res.Nonce).
Str("type", res.TypeUrl).
Msg("send update")
} }
} }
} }
@ -224,11 +237,11 @@ func (mgr *Manager) StreamAggregatedResources(
// Update updates the state of resources. If any changes are made they will be pushed to any listening // Update updates the state of resources. If any changes are made they will be pushed to any listening
// streams. For each TypeURL the list of resources should be the complete list of resources. // streams. For each TypeURL the list of resources should be the complete list of resources.
func (mgr *Manager) Update(resources map[string][]*envoy_service_discovery_v3.Resource) { func (mgr *Manager) Update(ctx context.Context, resources map[string][]*envoy_service_discovery_v3.Resource) {
mgr.mu.Lock() mgr.mu.Lock()
mgr.nonce = uuid.New().String() mgr.nonce = uuid.New().String()
mgr.resources = resources mgr.resources = resources
mgr.mu.Unlock() mgr.mu.Unlock()
mgr.signal.Broadcast() mgr.signal.Broadcast(ctx)
} }

View file

@ -28,7 +28,7 @@ func TestManager(t *testing.T) {
origOnHandleDeltaRequest := onHandleDeltaRequest origOnHandleDeltaRequest := onHandleDeltaRequest
defer func() { onHandleDeltaRequest = origOnHandleDeltaRequest }() defer func() { onHandleDeltaRequest = origOnHandleDeltaRequest }()
onHandleDeltaRequest = func(state *streamState) { onHandleDeltaRequest = func(state *streamState) {
stateChanged.Broadcast() stateChanged.Broadcast(ctx)
} }
srv := grpc.NewServer() srv := grpc.NewServer()
@ -94,7 +94,7 @@ func TestManager(t *testing.T) {
}, msg.GetResources()) }, msg.GetResources())
ack(msg.Nonce) ack(msg.Nonce)
mgr.Update(map[string][]*envoy_service_discovery_v3.Resource{ mgr.Update(ctx, map[string][]*envoy_service_discovery_v3.Resource{
typeURL: {{Name: "r1", Version: "2"}}, typeURL: {{Name: "r1", Version: "2"}},
}) })
@ -105,7 +105,7 @@ func TestManager(t *testing.T) {
}, msg.GetResources()) }, msg.GetResources())
ack(msg.Nonce) ack(msg.Nonce)
mgr.Update(map[string][]*envoy_service_discovery_v3.Resource{ mgr.Update(ctx, map[string][]*envoy_service_discovery_v3.Resource{
typeURL: nil, typeURL: nil,
}) })

View file

@ -1,6 +1,7 @@
package databroker package databroker
import ( import (
"context"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"time" "time"
@ -65,7 +66,7 @@ func WithSharedKey(sharedKey string) ServerOption {
return func(cfg *serverConfig) { return func(cfg *serverConfig) {
key, err := base64.StdEncoding.DecodeString(sharedKey) key, err := base64.StdEncoding.DecodeString(sharedKey)
if err != nil || len(key) != cryptutil.DefaultKeySize { if err != nil || len(key) != cryptutil.DefaultKeySize {
log.Error().Err(err).Msgf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize) log.Error(context.TODO()).Err(err).Msgf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize)
return return
} }
cfg.secret = key cfg.secret = key

View file

@ -35,22 +35,22 @@ type dbConfig struct {
} }
// NewConfigSource creates a new ConfigSource. // NewConfigSource creates a new ConfigSource.
func NewConfigSource(underlying config.Source, listeners ...config.ChangeListener) *ConfigSource { func NewConfigSource(ctx context.Context, underlying config.Source, listeners ...config.ChangeListener) *ConfigSource {
src := &ConfigSource{ src := &ConfigSource{
dbConfigs: map[string]dbConfig{}, dbConfigs: map[string]dbConfig{},
} }
for _, li := range listeners { for _, li := range listeners {
src.OnConfigChange(li) src.OnConfigChange(ctx, li)
} }
underlying.OnConfigChange(func(cfg *config.Config) { underlying.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) {
src.mu.Lock() src.mu.Lock()
src.underlyingConfig = cfg.Clone() src.underlyingConfig = cfg.Clone()
src.mu.Unlock() src.mu.Unlock()
src.rebuild(false) src.rebuild(ctx, firstTime(false))
}) })
src.underlyingConfig = underlying.GetConfig() src.underlyingConfig = underlying.GetConfig()
src.rebuild(true) src.rebuild(ctx, firstTime(true))
return src return src
} }
@ -62,8 +62,10 @@ func (src *ConfigSource) GetConfig() *config.Config {
return src.computedConfig return src.computedConfig
} }
func (src *ConfigSource) rebuild(firstTime bool) { type firstTime bool
_, span := trace.StartSpan(context.Background(), "databroker.config_source.rebuild")
func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
_, span := trace.StartSpan(ctx, "databroker.config_source.rebuild")
defer span.End() defer span.End()
src.mu.Lock() src.mu.Lock()
@ -78,7 +80,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
for _, policy := range cfg.Options.GetAllPolicies() { for _, policy := range cfg.Options.GetAllPolicies() {
id, err := policy.RouteID() id, err := policy.RouteID()
if err != nil { if err != nil {
log.Warn().Err(err). log.Warn(ctx).Err(err).
Str("policy", policy.String()). Str("policy", policy.String()).
Msg("databroker: invalid policy config, ignoring") Msg("databroker: invalid policy config, ignoring")
return return
@ -95,7 +97,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
err := cfg.Options.Validate() err := cfg.Options.Validate()
if err != nil { if err != nil {
metrics.SetDBConfigRejected(cfg.Options.Services, id, cfgpb.version, err) metrics.SetDBConfigRejected(ctx, cfg.Options.Services, id, cfgpb.version, err)
return return
} }
@ -103,7 +105,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
policy, err := config.NewPolicyFromProto(routepb) policy, err := config.NewPolicyFromProto(routepb)
if err != nil { if err != nil {
errCount++ errCount++
log.Warn().Err(err). log.Warn(ctx).Err(err).
Str("db_config_id", id). Str("db_config_id", id).
Msg("databroker: error converting protobuf into policy") Msg("databroker: error converting protobuf into policy")
continue continue
@ -112,7 +114,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
err = policy.Validate() err = policy.Validate()
if err != nil { if err != nil {
errCount++ errCount++
log.Warn().Err(err). log.Warn(ctx).Err(err).
Str("db_config_id", id). Str("db_config_id", id).
Str("policy", policy.String()). Str("policy", policy.String()).
Msg("databroker: invalid policy, ignoring") Msg("databroker: invalid policy, ignoring")
@ -122,7 +124,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
routeID, err := policy.RouteID() routeID, err := policy.RouteID()
if err != nil { if err != nil {
errCount++ errCount++
log.Warn().Err(err). log.Warn(ctx).Err(err).
Str("db_config_id", id). Str("db_config_id", id).
Str("policy", policy.String()). Str("policy", policy.String()).
Msg("databroker: cannot establish policy route ID, ignoring") Msg("databroker: cannot establish policy route ID, ignoring")
@ -131,7 +133,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
if _, ok := seen[routeID]; ok { if _, ok := seen[routeID]; ok {
errCount++ errCount++
log.Warn().Err(err). log.Warn(ctx).Err(err).
Str("db_config_id", id). Str("db_config_id", id).
Str("policy", policy.String()). Str("policy", policy.String()).
Msg("databroker: duplicate policy detected, ignoring") Msg("databroker: duplicate policy detected, ignoring")
@ -141,7 +143,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
additionalPolicies = append(additionalPolicies, *policy) additionalPolicies = append(additionalPolicies, *policy)
} }
metrics.SetDBConfigInfo(cfg.Options.Services, id, cfgpb.version, int64(errCount)) metrics.SetDBConfigInfo(ctx, cfg.Options.Services, id, cfgpb.version, int64(errCount))
} }
// add the additional policies here since calling `Validate` will reset them. // add the additional policies here since calling `Validate` will reset them.
@ -149,10 +151,10 @@ func (src *ConfigSource) rebuild(firstTime bool) {
src.computedConfig = cfg src.computedConfig = cfg
if !firstTime { if !firstTime {
src.Trigger(cfg) src.Trigger(ctx, cfg)
} }
metrics.SetConfigInfo(cfg.Options.Services, "databroker", cfg.Checksum(), true) metrics.SetConfigInfo(ctx, cfg.Options.Services, "databroker", cfg.Checksum(), true)
} }
func (src *ConfigSource) runUpdater(cfg *config.Config) { func (src *ConfigSource) runUpdater(cfg *config.Config) {
@ -191,7 +193,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
cc, err := grpc.NewGRPCClientConn(connectionOptions) cc, err := grpc.NewGRPCClientConn(connectionOptions)
if err != nil { if err != nil {
log.Error().Err(err).Msg("databroker: failed to create gRPC connection to data broker") log.Error(context.TODO()).Err(err).Msg("databroker: failed to create gRPC connection to data broker")
return return
} }
@ -237,7 +239,7 @@ func (s *syncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64,
var cfgpb configpb.Config var cfgpb configpb.Config
err := record.GetData().UnmarshalTo(&cfgpb) err := record.GetData().UnmarshalTo(&cfgpb)
if err != nil { if err != nil {
log.Warn().Err(err).Msg("databroker: error decoding config") log.Warn(ctx).Err(err).Msg("databroker: error decoding config")
delete(s.src.dbConfigs, record.GetId()) delete(s.src.dbConfigs, record.GetId())
continue continue
} }
@ -246,5 +248,5 @@ func (s *syncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64,
} }
s.src.mu.Unlock() s.src.mu.Unlock()
s.src.rebuild(false) s.src.rebuild(ctx, firstTime(false))
} }

View file

@ -37,9 +37,9 @@ func TestConfigSource(t *testing.T) {
base.InsecureServer = true base.InsecureServer = true
base.GRPCInsecure = true base.GRPCInsecure = true
src := NewConfigSource(config.NewStaticSource(&config.Config{ src := NewConfigSource(ctx, config.NewStaticSource(&config.Config{
Options: base, Options: base,
}), func(cfg *config.Config) { }), func(_ context.Context, cfg *config.Config) {
cfgs <- cfg cfgs <- cfg
}) })
cfgs <- src.GetConfig() cfgs <- src.GetConfig()

View file

@ -10,7 +10,6 @@ import (
"sync" "sync"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/rs/zerolog"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
@ -35,7 +34,6 @@ const (
// Server implements the databroker service using an in memory database. // Server implements the databroker service using an in memory database.
type Server struct { type Server struct {
cfg *serverConfig cfg *serverConfig
log zerolog.Logger
mu sync.RWMutex mu sync.RWMutex
version uint64 version uint64
@ -44,33 +42,31 @@ type Server struct {
// New creates a new server. // New creates a new server.
func New(options ...ServerOption) *Server { func New(options ...ServerOption) *Server {
srv := &Server{ srv := &Server{}
log: log.With().Str("service", "databroker").Logger(),
}
srv.UpdateConfig(options...) srv.UpdateConfig(options...)
return srv return srv
} }
func (srv *Server) initVersion() { func (srv *Server) initVersion(ctx context.Context) {
db, _, err := srv.getBackendLocked() db, _, err := srv.getBackendLocked()
if err != nil { if err != nil {
log.Error().Err(err).Msg("failed to init server version") log.Error(ctx).Err(err).Msg("failed to init server version")
return return
} }
// Get version from storage first. // Get version from storage first.
r, err := db.Get(context.Background(), recordTypeServerVersion, serverVersionKey) r, err := db.Get(ctx, recordTypeServerVersion, serverVersionKey)
switch { switch {
case err == nil: case err == nil:
var sv wrapperspb.UInt64Value var sv wrapperspb.UInt64Value
if err := r.GetData().UnmarshalTo(&sv); err == nil { if err := r.GetData().UnmarshalTo(&sv); err == nil {
srv.log.Debug().Uint64("server_version", sv.Value).Msg("got db version from Backend") log.Debug(ctx).Uint64("server_version", sv.Value).Msg("got db version from Backend")
srv.version = sv.Value srv.version = sv.Value
} }
return return
case errors.Is(err, storage.ErrNotFound): // no server version, so we'll create a new one case errors.Is(err, storage.ErrNotFound): // no server version, so we'll create a new one
case err != nil: case err != nil:
log.Error().Err(err).Msg("failed to retrieve server version") log.Error(ctx).Err(err).Msg("failed to retrieve server version")
return return
} }
@ -81,7 +77,7 @@ func (srv *Server) initVersion() {
Id: serverVersionKey, Id: serverVersionKey,
Data: data, Data: data,
}); err != nil { }); err != nil {
srv.log.Warn().Err(err).Msg("failed to save server version.") log.Warn(ctx).Err(err).Msg("failed to save server version.")
} }
} }
@ -90,9 +86,11 @@ func (srv *Server) UpdateConfig(options ...ServerOption) {
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
ctx := context.TODO()
cfg := newServerConfig(options...) cfg := newServerConfig(options...)
if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) { if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) {
log.Debug().Msg("databroker: no changes detected, re-using existing DBs") log.Debug(ctx).Msg("databroker: no changes detected, re-using existing DBs")
return return
} }
srv.cfg = cfg srv.cfg = cfg
@ -100,19 +98,19 @@ func (srv *Server) UpdateConfig(options ...ServerOption) {
if srv.backend != nil { if srv.backend != nil {
err := srv.backend.Close() err := srv.backend.Close()
if err != nil { if err != nil {
log.Error().Err(err).Msg("databroker: error closing backend") log.Error(ctx).Err(err).Msg("databroker: error closing backend")
} }
srv.backend = nil srv.backend = nil
} }
srv.initVersion() srv.initVersion(ctx)
} }
// Get gets a record from the in-memory list. // Get gets a record from the in-memory list.
func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databroker.GetResponse, error) { func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databroker.GetResponse, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.Get") _, span := trace.StartSpan(ctx, "databroker.grpc.Get")
defer span.End() defer span.End()
srv.log.Info(). log.Info(ctx).
Str("peer", grpcutil.GetPeerAddr(ctx)). Str("peer", grpcutil.GetPeerAddr(ctx)).
Str("type", req.GetType()). Str("type", req.GetType()).
Str("id", req.GetId()). Str("id", req.GetId()).
@ -141,7 +139,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*databroker.QueryResponse, error) { func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*databroker.QueryResponse, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.Query") _, span := trace.StartSpan(ctx, "databroker.grpc.Query")
defer span.End() defer span.End()
srv.log.Info(). log.Info(ctx).
Str("peer", grpcutil.GetPeerAddr(ctx)). Str("peer", grpcutil.GetPeerAddr(ctx)).
Str("type", req.GetType()). Str("type", req.GetType()).
Str("query", req.GetQuery()). Str("query", req.GetQuery()).
@ -185,7 +183,7 @@ func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databr
defer span.End() defer span.End()
record := req.GetRecord() record := req.GetRecord()
srv.log.Info(). log.Info(ctx).
Str("peer", grpcutil.GetPeerAddr(ctx)). Str("peer", grpcutil.GetPeerAddr(ctx)).
Str("type", record.GetType()). Str("type", record.GetType()).
Str("id", record.GetId()). Str("id", record.GetId()).
@ -208,7 +206,7 @@ func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databr
func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error { func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error {
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.Sync") _, span := trace.StartSpan(stream.Context(), "databroker.grpc.Sync")
defer span.End() defer span.End()
srv.log.Info(). log.Info(stream.Context()).
Str("peer", grpcutil.GetPeerAddr(stream.Context())). Str("peer", grpcutil.GetPeerAddr(stream.Context())).
Uint64("server_version", req.GetServerVersion()). Uint64("server_version", req.GetServerVersion()).
Uint64("record_version", req.GetRecordVersion()). Uint64("record_version", req.GetRecordVersion()).
@ -251,7 +249,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error { func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error {
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.SyncLatest") _, span := trace.StartSpan(stream.Context(), "databroker.grpc.SyncLatest")
defer span.End() defer span.End()
srv.log.Info(). log.Info(stream.Context()).
Str("peer", grpcutil.GetPeerAddr(stream.Context())). Str("peer", grpcutil.GetPeerAddr(stream.Context())).
Str("type", req.GetType()). Str("type", req.GetType()).
Msg("sync latest") Msg("sync latest")
@ -333,9 +331,10 @@ func (srv *Server) getBackendLocked() (backend storage.Backend, version uint64,
} }
func (srv *Server) newBackendLocked() (backend storage.Backend, err error) { func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
ctx := context.Background()
caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile) caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile)
if err != nil { if err != nil {
log.Warn().Err(err).Msg("failed to read databroker CA file") log.Warn(ctx).Err(err).Msg("failed to read databroker CA file")
} }
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
RootCAs: caCertPool, RootCAs: caCertPool,
@ -348,10 +347,10 @@ func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
switch srv.cfg.storageType { switch srv.cfg.storageType {
case config.StorageInMemoryName: case config.StorageInMemoryName:
srv.log.Info().Msg("using in-memory store") log.Info(ctx).Msg("using in-memory store")
return inmemory.New(), nil return inmemory.New(), nil
case config.StorageRedisName: case config.StorageRedisName:
srv.log.Info().Msg("using redis store") log.Info(ctx).Msg("using redis store")
backend, err = redis.New( backend, err = redis.New(
srv.cfg.storageConnectionString, srv.cfg.storageConnectionString,
redis.WithTLSConfig(tlsConfig), redis.WithTLSConfig(tlsConfig),

View file

@ -10,7 +10,6 @@ import (
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/session"
) )
@ -19,7 +18,6 @@ func newServer(cfg *serverConfig) *Server {
return &Server{ return &Server{
version: 11, version: 11,
cfg: cfg, cfg: cfg,
log: log.With().Str("service", "databroker").Logger(),
} }
} }

View file

@ -69,19 +69,25 @@ func getConfig(options ...Option) *config {
// The Provider retrieves users and groups from gitlab. // The Provider retrieves users and groups from gitlab.
type Provider struct { type Provider struct {
cfg *config cfg *config
log zerolog.Logger
} }
// New creates a new Provider. // New creates a new Provider.
func New(options ...Option) *Provider { func New(options ...Option) *Provider {
return &Provider{ return &Provider{
cfg: getConfig(options...), cfg: getConfig(options...),
log: log.With().Str("service", "directory").Str("provider", "gitlab").Logger(),
} }
} }
func withLog(ctx context.Context) context.Context {
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "directory").Str("provider", "gitlab")
})
}
// User returns the user record for the given id. // User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) { func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
ctx = withLog(ctx)
du := &directory.User{ du := &directory.User{
Id: userID, Id: userID,
} }
@ -107,11 +113,13 @@ func (p *Provider) User(ctx context.Context, userID, accessToken string) (*direc
// UserGroups gets the directory user groups for gitlab. // UserGroups gets the directory user groups for gitlab.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) { func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
ctx = withLog(ctx)
if p.cfg.serviceAccount == nil { if p.cfg.serviceAccount == nil {
return nil, nil, fmt.Errorf("gitlab: service account not defined") return nil, nil, fmt.Errorf("gitlab: service account not defined")
} }
p.log.Info().Msg("getting user groups") log.Info(ctx).Msg("getting user groups")
groups, err := p.listGroups(ctx, "") groups, err := p.listGroups(ctx, "")
if err != nil { if err != nil {

View file

@ -93,7 +93,6 @@ func getConfig(options ...Option) *config {
// A Provider is an Okta user group directory provider. // A Provider is an Okta user group directory provider.
type Provider struct { type Provider struct {
cfg *config cfg *config
log zerolog.Logger
lastUpdated *time.Time lastUpdated *time.Time
groups map[string]*directory.Group groups map[string]*directory.Group
} }
@ -102,13 +101,20 @@ type Provider struct {
func New(options ...Option) *Provider { func New(options ...Option) *Provider {
return &Provider{ return &Provider{
cfg: getConfig(options...), cfg: getConfig(options...),
log: log.With().Str("service", "directory").Str("provider", "okta").Logger(),
groups: make(map[string]*directory.Group), groups: make(map[string]*directory.Group),
} }
} }
func withLog(ctx context.Context) context.Context {
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "directory").Str("provider", "okta")
})
}
// User returns the user record for the given id. // User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) { func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
ctx = withLog(ctx)
if p.cfg.serviceAccount == nil { if p.cfg.serviceAccount == nil {
return nil, ErrServiceAccountNotDefined return nil, ErrServiceAccountNotDefined
} }
@ -139,11 +145,13 @@ func (p *Provider) User(ctx context.Context, userID, accessToken string) (*direc
// UserGroups fetches the groups of which the user is a member // UserGroups fetches the groups of which the user is a member
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups // https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) { func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
ctx = withLog(ctx)
if p.cfg.serviceAccount == nil { if p.cfg.serviceAccount == nil {
return nil, nil, ErrServiceAccountNotDefined return nil, nil, ErrServiceAccountNotDefined
} }
p.log.Info().Msg("getting user groups") log.Info(ctx).Msg("getting user groups")
if p.cfg.providerURL == nil { if p.cfg.providerURL == nil {
return nil, nil, ErrProviderURLNotDefined return nil, nil, ErrProviderURLNotDefined
@ -164,7 +172,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
// the cached lookup and the local groups list // the cached lookup and the local groups list
var apiErr *APIError var apiErr *APIError
if errors.As(err, &apiErr) && apiErr.HTTPStatusCode == http.StatusNotFound { if errors.As(err, &apiErr) && apiErr.HTTPStatusCode == http.StatusNotFound {
log.Debug().Str("group", group.Id).Msg("okta: group was removed") log.Debug(ctx).Str("group", group.Id).Msg("okta: group was removed")
delete(p.groups, group.Id) delete(p.groups, group.Id)
groups = append(groups[:i], groups[i+1:]...) groups = append(groups[:i], groups[i+1:]...)
i-- i--

View file

@ -78,7 +78,6 @@ func getConfig(options ...Option) *config {
// The Provider retrieves users and groups from onelogin. // The Provider retrieves users and groups from onelogin.
type Provider struct { type Provider struct {
cfg *config cfg *config
log zerolog.Logger
mu sync.RWMutex mu sync.RWMutex
token *oauth2.Token token *oauth2.Token
@ -89,10 +88,15 @@ func New(options ...Option) *Provider {
cfg := getConfig(options...) cfg := getConfig(options...)
return &Provider{ return &Provider{
cfg: cfg, cfg: cfg,
log: log.With().Str("service", "directory").Str("provider", "onelogin").Logger(),
} }
} }
func withLog(ctx context.Context) context.Context {
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "directory").Str("provider", "onelogin")
})
}
// User returns the user record for the given id. // User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) { func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
if p.cfg.serviceAccount == nil { if p.cfg.serviceAccount == nil {
@ -102,6 +106,8 @@ func (p *Provider) User(ctx context.Context, userID, accessToken string) (*direc
Id: userID, Id: userID,
} }
ctx = withLog(ctx)
token, err := p.getToken(ctx) token, err := p.getToken(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
@ -124,7 +130,9 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
return nil, nil, fmt.Errorf("onelogin: service account not defined") return nil, nil, fmt.Errorf("onelogin: service account not defined")
} }
p.log.Info().Msg("getting user groups") ctx = withLog(ctx)
log.Info(ctx).Msg("getting user groups")
token, err := p.getToken(ctx) token, err := p.getToken(ctx)
if err != nil { if err != nil {
@ -252,7 +260,7 @@ func (p *Provider) apiGet(ctx context.Context, accessToken string, uri string, o
return "", err return "", err
} }
p.log.Info(). log.Info(ctx).
Str("url", uri). Str("url", uri).
Interface("result", result). Interface("result", result).
Msg("api request") Msg("api request")

View file

@ -50,8 +50,9 @@ func GetProvider(options Options) (provider Provider) {
globalProvider.Lock() globalProvider.Lock()
defer globalProvider.Unlock() defer globalProvider.Unlock()
ctx := context.TODO()
if globalProvider.provider != nil && cmp.Equal(globalProvider.options, options) { if globalProvider.provider != nil && cmp.Equal(globalProvider.options, options) {
log.Debug().Str("provider", options.Provider).Msg("directory: no change detected, reusing existing directory provider") log.Debug(ctx).Str("provider", options.Provider).Msg("directory: no change detected, reusing existing directory provider")
return globalProvider.provider return globalProvider.provider
} }
defer func() { defer func() {
@ -72,7 +73,7 @@ func GetProvider(options Options) (provider Provider) {
auth0.WithDomain(options.ProviderURL), auth0.WithDomain(options.ProviderURL),
auth0.WithServiceAccount(serviceAccount)) auth0.WithServiceAccount(serviceAccount))
} }
log.Warn(). log.Warn(ctx).
Str("service", "directory"). Str("service", "directory").
Str("provider", options.Provider). Str("provider", options.Provider).
Err(err). Err(err).
@ -82,7 +83,7 @@ func GetProvider(options Options) (provider Provider) {
if err == nil { if err == nil {
return azure.New(azure.WithServiceAccount(serviceAccount)) return azure.New(azure.WithServiceAccount(serviceAccount))
} }
log.Warn(). log.Warn(ctx).
Str("service", "directory"). Str("service", "directory").
Str("provider", options.Provider). Str("provider", options.Provider).
Err(err). Err(err).
@ -92,7 +93,7 @@ func GetProvider(options Options) (provider Provider) {
if err == nil { if err == nil {
return github.New(github.WithServiceAccount(serviceAccount)) return github.New(github.WithServiceAccount(serviceAccount))
} }
log.Warn(). log.Warn(ctx).
Str("service", "directory"). Str("service", "directory").
Str("provider", options.Provider). Str("provider", options.Provider).
Err(err). Err(err).
@ -107,7 +108,7 @@ func GetProvider(options Options) (provider Provider) {
gitlab.WithURL(providerURL), gitlab.WithURL(providerURL),
gitlab.WithServiceAccount(serviceAccount)) gitlab.WithServiceAccount(serviceAccount))
} }
log.Warn(). log.Warn(ctx).
Str("service", "directory"). Str("service", "directory").
Str("provider", options.Provider). Str("provider", options.Provider).
Err(err). Err(err).
@ -117,7 +118,7 @@ func GetProvider(options Options) (provider Provider) {
if err == nil { if err == nil {
return google.New(google.WithServiceAccount(serviceAccount)) return google.New(google.WithServiceAccount(serviceAccount))
} }
log.Warn(). log.Warn(ctx).
Str("service", "directory"). Str("service", "directory").
Str("provider", options.Provider). Str("provider", options.Provider).
Err(err). Err(err).
@ -129,7 +130,7 @@ func GetProvider(options Options) (provider Provider) {
okta.WithProviderURL(providerURL), okta.WithProviderURL(providerURL),
okta.WithServiceAccount(serviceAccount)) okta.WithServiceAccount(serviceAccount))
} }
log.Warn(). log.Warn(ctx).
Str("service", "directory"). Str("service", "directory").
Str("provider", options.Provider). Str("provider", options.Provider).
Err(err). Err(err).
@ -139,7 +140,7 @@ func GetProvider(options Options) (provider Provider) {
if err == nil { if err == nil {
return onelogin.New(onelogin.WithServiceAccount(serviceAccount)) return onelogin.New(onelogin.WithServiceAccount(serviceAccount))
} }
log.Warn(). log.Warn(ctx).
Str("service", "directory"). Str("service", "directory").
Str("provider", options.Provider). Str("provider", options.Provider).
Err(err). Err(err).
@ -151,14 +152,14 @@ func GetProvider(options Options) (provider Provider) {
ping.WithProviderURL(providerURL), ping.WithProviderURL(providerURL),
ping.WithServiceAccount(serviceAccount)) ping.WithServiceAccount(serviceAccount))
} }
log.Warn(). log.Warn(ctx).
Str("service", "directory"). Str("service", "directory").
Str("provider", options.Provider). Str("provider", options.Provider).
Err(err). Err(err).
Msg("invalid service account for ping directory provider") Msg("invalid service account for ping directory provider")
} }
log.Warn(). log.Warn(ctx).
Str("provider", options.Provider). Str("provider", options.Provider).
Msg("no directory provider implementation found, disabling support for groups") Msg("no directory provider implementation found, disabling support for groups")
return nullProvider{} return nullProvider{}

View file

@ -67,7 +67,7 @@ type Server struct {
} }
// NewServer creates a new server with traffic routed by envoy. // NewServer creates a new server with traffic routed by envoy.
func NewServer(src config.Source, grpcPort, httpPort string, builder *envoyconfig.Builder) (*Server, error) { func NewServer(ctx context.Context, src config.Source, grpcPort, httpPort string, builder *envoyconfig.Builder) (*Server, error) {
wd := filepath.Join(os.TempDir(), workingDirectoryName) wd := filepath.Join(os.TempDir(), workingDirectoryName)
err := os.MkdirAll(wd, embeddedEnvoyPermissions) err := os.MkdirAll(wd, embeddedEnvoyPermissions)
if err != nil { if err != nil {
@ -76,7 +76,7 @@ func NewServer(src config.Source, grpcPort, httpPort string, builder *envoyconfi
envoyPath, err := extractEmbeddedEnvoy() envoyPath, err := extractEmbeddedEnvoy()
if err != nil { if err != nil {
log.Warn().Err(err).Send() log.Warn(ctx).Err(err).Send()
envoyPath = "envoy" envoyPath = "envoy"
} }
@ -98,7 +98,7 @@ func NewServer(src config.Source, grpcPort, httpPort string, builder *envoyconfi
return nil, fmt.Errorf("invalid envoy binary, expected %s but got %s", Checksum, s) return nil, fmt.Errorf("invalid envoy binary, expected %s but got %s", Checksum, s)
} }
} else { } else {
log.Info().Msg("no checksum defined, envoy binary will not be verified!") log.Info(ctx).Msg("no checksum defined, envoy binary will not be verified!")
} }
srv := &Server{ srv := &Server{
@ -108,12 +108,12 @@ func NewServer(src config.Source, grpcPort, httpPort string, builder *envoyconfi
httpPort: httpPort, httpPort: httpPort,
envoyPath: envoyPath, envoyPath: envoyPath,
} }
go srv.runProcessCollector() go srv.runProcessCollector(ctx)
src.OnConfigChange(srv.onConfigChange) src.OnConfigChange(ctx, srv.onConfigChange)
srv.onConfigChange(src.GetConfig()) srv.onConfigChange(ctx, src.GetConfig())
log.Info(). log.Info(ctx).
Str("path", envoyPath). Str("path", envoyPath).
Str("checksum", Checksum). Str("checksum", Checksum).
Msg("running envoy") Msg("running envoy")
@ -130,7 +130,7 @@ func (srv *Server) Close() error {
if srv.cmd != nil && srv.cmd.Process != nil { if srv.cmd != nil && srv.cmd.Process != nil {
err = srv.cmd.Process.Kill() err = srv.cmd.Process.Kill()
if err != nil { if err != nil {
log.Error().Err(err).Str("service", "envoy").Msg("envoy: failed to kill process on close") log.Error(context.TODO()).Err(err).Str("service", "envoy").Msg("envoy: failed to kill process on close")
} }
srv.cmd = nil srv.cmd = nil
} }
@ -138,17 +138,17 @@ func (srv *Server) Close() error {
return err return err
} }
func (srv *Server) onConfigChange(cfg *config.Config) { func (srv *Server) onConfigChange(ctx context.Context, cfg *config.Config) {
srv.update(cfg) srv.update(ctx, cfg)
} }
func (srv *Server) update(cfg *config.Config) { func (srv *Server) update(ctx context.Context, cfg *config.Config) {
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
tracingOptions, err := config.NewTracingOptions(cfg.Options) tracingOptions, err := config.NewTracingOptions(cfg.Options)
if err != nil { if err != nil {
log.Error().Err(err).Str("service", "envoy").Msg("invalid tracing config") log.Error(ctx).Err(err).Str("service", "envoy").Msg("invalid tracing config")
return return
} }
@ -159,24 +159,24 @@ func (srv *Server) update(cfg *config.Config) {
} }
if cmp.Equal(srv.options, options, cmp.AllowUnexported(serverOptions{})) { if cmp.Equal(srv.options, options, cmp.AllowUnexported(serverOptions{})) {
log.Debug().Str("service", "envoy").Msg("envoy: no config changes detected") log.Debug(ctx).Str("service", "envoy").Msg("envoy: no config changes detected")
return return
} }
srv.options = options srv.options = options
if err := srv.writeConfig(cfg); err != nil { if err := srv.writeConfig(ctx, cfg); err != nil {
log.Error().Err(err).Str("service", "envoy").Msg("envoy: failed to write envoy config") log.Error(ctx).Err(err).Str("service", "envoy").Msg("envoy: failed to write envoy config")
return return
} }
log.Info().Msg("envoy: starting envoy process") log.Info(ctx).Msg("envoy: starting envoy process")
if err := srv.run(); err != nil { if err := srv.run(ctx); err != nil {
log.Error().Err(err).Str("service", "envoy").Msg("envoy: failed to run envoy process") log.Error(ctx).Err(err).Str("service", "envoy").Msg("envoy: failed to run envoy process")
return return
} }
} }
func (srv *Server) run() error { func (srv *Server) run(ctx context.Context) error {
args := []string{ args := []string{
"-c", configFileName, "-c", configFileName,
"--log-level", srv.options.logLevel, "--log-level", srv.options.logLevel,
@ -198,13 +198,13 @@ func (srv *Server) run() error {
if err != nil { if err != nil {
return fmt.Errorf("error creating stderr pipe for envoy: %w", err) return fmt.Errorf("error creating stderr pipe for envoy: %w", err)
} }
go srv.handleLogs(stderr) go srv.handleLogs(ctx, stderr)
stdout, err := cmd.StdoutPipe() stdout, err := cmd.StdoutPipe()
if err != nil { if err != nil {
return fmt.Errorf("error creating stderr pipe for envoy: %w", err) return fmt.Errorf("error creating stderr pipe for envoy: %w", err)
} }
go srv.handleLogs(stdout) go srv.handleLogs(ctx, stdout)
// make sure envoy is killed if we're killed // make sure envoy is killed if we're killed
cmd.SysProcAttr = sysProcAttr cmd.SysProcAttr = sysProcAttr
@ -216,10 +216,10 @@ func (srv *Server) run() error {
// release the previous process so we can hot-reload // release the previous process so we can hot-reload
if srv.cmd != nil && srv.cmd.Process != nil { if srv.cmd != nil && srv.cmd.Process != nil {
log.Info().Msg("envoy: releasing envoy process for hot-reload") log.Info(ctx).Msg("envoy: releasing envoy process for hot-reload")
err := srv.cmd.Process.Release() err := srv.cmd.Process.Release()
if err != nil { if err != nil {
log.Warn().Err(err).Str("service", "envoy").Msg("envoy: failed to release envoy process for hot-reload") log.Warn(ctx).Err(err).Str("service", "envoy").Msg("envoy: failed to release envoy process for hot-reload")
} }
} }
srv.cmd = cmd srv.cmd = cmd
@ -227,14 +227,14 @@ func (srv *Server) run() error {
return nil return nil
} }
func (srv *Server) writeConfig(cfg *config.Config) error { func (srv *Server) writeConfig(ctx context.Context, cfg *config.Config) error {
confBytes, err := srv.buildBootstrapConfig(cfg) confBytes, err := srv.buildBootstrapConfig(cfg)
if err != nil { if err != nil {
return err return err
} }
cfgPath := filepath.Join(srv.wd, configFileName) cfgPath := filepath.Join(srv.wd, configFileName)
log.Debug().Str("service", "envoy").Str("location", cfgPath).Msg("wrote config file to location") log.Debug(ctx).Str("service", "envoy").Str("location", cfgPath).Msg("wrote config file to location")
return atomic.WriteFile(cfgPath, bytes.NewReader(confBytes)) return atomic.WriteFile(cfgPath, bytes.NewReader(confBytes))
} }
@ -313,9 +313,10 @@ func (srv *Server) parseLog(line string) (name string, logLevel string, msg stri
return return
} }
func (srv *Server) handleLogs(rc io.ReadCloser) { func (srv *Server) handleLogs(ctx context.Context, rc io.ReadCloser) {
defer rc.Close() defer rc.Close()
l := log.With().Str("service", "envoy").Logger()
bo := backoff.NewExponentialBackOff() bo := backoff.NewExponentialBackOff()
s := bufio.NewReader(rc) s := bufio.NewReader(rc)
@ -325,7 +326,7 @@ func (srv *Server) handleLogs(rc io.ReadCloser) {
if errors.Is(err, io.EOF) || errors.Is(err, os.ErrClosed) { if errors.Is(err, io.EOF) || errors.Is(err, os.ErrClosed) {
break break
} }
log.Error().Err(err).Msg("failed to read log") log.Error(ctx).Err(err).Msg("failed to read log")
time.Sleep(bo.NextBackOff()) time.Sleep(bo.NextBackOff())
continue continue
} }
@ -336,7 +337,8 @@ func (srv *Server) handleLogs(rc io.ReadCloser) {
if name == "" { if name == "" {
name = "envoy" name = "envoy"
} }
lvl := zerolog.DebugLevel
lvl := zerolog.ErrorLevel
if x, err := zerolog.ParseLevel(logLevel); err == nil { if x, err := zerolog.ParseLevel(logLevel); err == nil {
lvl = x lvl = x
} }
@ -354,14 +356,13 @@ func (srv *Server) handleLogs(rc io.ReadCloser) {
continue continue
} }
log.WithLevel(lvl). l.WithLevel(lvl).
Str("service", "envoy").
Str("name", name). Str("name", name).
Msg(msg) Msg(msg)
} }
} }
func (srv *Server) runProcessCollector() { func (srv *Server) runProcessCollector(ctx context.Context) {
// macos is not supported // macos is not supported
if runtime.GOOS != "linux" { if runtime.GOOS != "linux" {
return return
@ -369,7 +370,7 @@ func (srv *Server) runProcessCollector() {
pc := metrics.NewProcessCollector("envoy") pc := metrics.NewProcessCollector("envoy")
if err := view.Register(pc.Views()...); err != nil { if err := view.Register(pc.Views()...); err != nil {
log.Error().Err(err).Msg("failed to register envoy process metric views") log.Error(ctx).Err(err).Msg("failed to register envoy process metric views")
} }
const collectInterval = time.Second * 10 const collectInterval = time.Second * 10
@ -387,7 +388,7 @@ func (srv *Server) runProcessCollector() {
if pid > 0 { if pid > 0 {
err := pc.Measure(context.Background(), pid) err := pc.Measure(context.Background(), pid)
if err != nil { if err != nil {
log.Error().Err(err).Msg("failed to measure envoy process metrics") log.Error(ctx).Err(err).Msg("failed to measure envoy process metrics")
} }
} }
} }

View file

@ -1,6 +1,7 @@
package envoy package envoy
import ( import (
"context"
"io/ioutil" "io/ioutil"
"regexp" "regexp"
"strings" "strings"
@ -43,6 +44,6 @@ func Benchmark_handleLogs(b *testing.B) {
b.ResetTimer() b.ResetTimer()
b.ReportAllocs() b.ReportAllocs()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
srv.handleLogs(rc) srv.handleLogs(context.Background(), rc)
} }
} }

View file

@ -1,9 +1,11 @@
package fileutil package fileutil
import ( import (
"context"
"sync" "sync"
"github.com/rjeczalik/notify" "github.com/rjeczalik/notify"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/signal" "github.com/pomerium/pomerium/internal/signal"
@ -29,6 +31,10 @@ func (watcher *Watcher) Add(filePath string) {
watcher.mu.Lock() watcher.mu.Lock()
defer watcher.mu.Unlock() defer watcher.mu.Unlock()
ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context {
return c.Str("watch_file", filePath)
})
// already watching // already watching
if _, ok := watcher.filePaths[filePath]; ok { if _, ok := watcher.filePaths[filePath]; ok {
return return
@ -37,18 +43,18 @@ func (watcher *Watcher) Add(filePath string) {
ch := make(chan notify.EventInfo, 1) ch := make(chan notify.EventInfo, 1)
go func() { go func() {
for evt := range ch { for evt := range ch {
log.Info().Str("path", evt.Path()).Str("event", evt.Event().String()).Msg("filemgr: detected file change") log.Info(ctx).Str("event", evt.Event().String()).Msg("filemgr: detected file change")
watcher.Signal.Broadcast() watcher.Signal.Broadcast(ctx)
} }
}() }()
err := notify.Watch(filePath, ch, notify.All) err := notify.Watch(filePath, ch, notify.All)
if err != nil { if err != nil {
log.Error().Err(err).Str("path", filePath).Msg("filemgr: error watching file path") log.Error(ctx).Err(err).Msg("filemgr: error watching file path")
notify.Stop(ch) notify.Stop(ch)
close(ch) close(ch)
return return
} }
log.Debug().Str("path", filePath).Msg("filemgr: watching file for changes") log.Debug(ctx).Msg("filemgr: watching file for changes")
watcher.filePaths[filePath] = ch watcher.filePaths[filePath] = ch
} }

View file

@ -2,6 +2,7 @@
package reproxy package reproxy
import ( import (
"context"
"encoding/base64" "encoding/base64"
"errors" "errors"
"math/rand" "math/rand"
@ -120,7 +121,7 @@ func (h *Handler) Middleware(next http.Handler) http.Handler {
} }
// Update updates the handler with new configuration. // Update updates the handler with new configuration.
func (h *Handler) Update(cfg *config.Config) { func (h *Handler) Update(ctx context.Context, cfg *config.Config) {
h.mu.Lock() h.mu.Lock()
defer h.mu.Unlock() defer h.mu.Unlock()
@ -130,7 +131,7 @@ func (h *Handler) Update(cfg *config.Config) {
for i, p := range cfg.Options.Policies { for i, p := range cfg.Options.Policies {
id, err := p.RouteID() id, err := p.RouteID()
if err != nil { if err != nil {
log.Warn().Err(err).Msg("reproxy: error getting route id") log.Warn(ctx).Err(err).Msg("reproxy: error getting route id")
continue continue
} }
h.policies[id] = &cfg.Options.Policies[i] h.policies[id] = &cfg.Options.Policies[i]

View file

@ -1,6 +1,7 @@
package reproxy package reproxy
import ( import (
"context"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -58,7 +59,7 @@ func TestMiddleware(t *testing.T) {
}}, }},
}, },
} }
h.Update(cfg) h.Update(context.Background(), cfg)
policyID, _ := cfg.Options.Policies[0].RouteID() policyID, _ := cfg.Options.Policies[0].RouteID()

View file

@ -97,8 +97,8 @@ func Shutdown(srv *http.Server) {
rec := <-sigint rec := <-sigint
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
log.Info().Str("signal", rec.String()).Msg("internal/httputil: shutting down servers") log.Info(context.TODO()).Str("signal", rec.String()).Msg("internal/httputil: shutting down servers")
if err := srv.Shutdown(ctx); err != nil { if err := srv.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("internal/httputil: shutdown failed") log.Error(context.TODO()).Err(err).Msg("internal/httputil: shutdown failed")
} }
} }

View file

@ -47,7 +47,6 @@ type (
// A Manager refreshes identity information using session and user data. // A Manager refreshes identity information using session and user data.
type Manager struct { type Manager struct {
cfg *atomicConfig cfg *atomicConfig
log zerolog.Logger
sessionScheduler *scheduler.Scheduler sessionScheduler *scheduler.Scheduler
userScheduler *scheduler.Scheduler userScheduler *scheduler.Scheduler
@ -68,7 +67,6 @@ func New(
) *Manager { ) *Manager {
mgr := &Manager{ mgr := &Manager{
cfg: newAtomicConfig(newConfig()), cfg: newAtomicConfig(newConfig()),
log: log.With().Str("service", "identity_manager").Logger(),
sessionScheduler: scheduler.New(), sessionScheduler: scheduler.New(),
userScheduler: scheduler.New(), userScheduler: scheduler.New(),
@ -79,6 +77,12 @@ func New(
return mgr return mgr
} }
func withLog(ctx context.Context) context.Context {
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "identity_manager")
})
}
// UpdateConfig updates the manager with the new options. // UpdateConfig updates the manager with the new options.
func (mgr *Manager) UpdateConfig(options ...Option) { func (mgr *Manager) UpdateConfig(options ...Option) {
mgr.cfg.Store(newConfig(options...)) mgr.cfg.Store(newConfig(options...))
@ -86,10 +90,11 @@ func (mgr *Manager) UpdateConfig(options ...Option) {
// Run runs the manager. This method blocks until an error occurs or the given context is canceled. // Run runs the manager. This method blocks until an error occurs or the given context is canceled.
func (mgr *Manager) Run(ctx context.Context) error { func (mgr *Manager) Run(ctx context.Context) error {
ctx = withLog(ctx)
update := make(chan updateRecordsMessage, 1) update := make(chan updateRecordsMessage, 1)
clear := make(chan struct{}, 1) clear := make(chan struct{}, 1)
syncer := newDataBrokerSyncer(mgr.cfg, mgr.log, update, clear) syncer := newDataBrokerSyncer(ctx, mgr.cfg, update, clear)
eg, ctx := errgroup.WithContext(ctx) eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error { eg.Go(func() error {
@ -119,7 +124,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
mgr.onUpdateRecords(ctx, msg) mgr.onUpdateRecords(ctx, msg)
} }
mgr.log.Info(). log.Info(ctx).
Int("directory_groups", len(mgr.directoryGroups)). Int("directory_groups", len(mgr.directoryGroups)).
Int("directory_users", len(mgr.directoryUsers)). Int("directory_users", len(mgr.directoryUsers)).
Int("sessions", mgr.sessions.Len()). Int("sessions", mgr.sessions.Len()).
@ -196,7 +201,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
} }
func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) { func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) {
mgr.log.Info().Msg("refreshing directory users") log.Info(ctx).Msg("refreshing directory users")
ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.Load().groupRefreshTimeout) ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.Load().groupRefreshTimeout)
defer clearTimeout() defer clearTimeout()
@ -208,7 +213,7 @@ func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) {
msg += ". You may need to increase the identity provider directory timeout setting" msg += ". You may need to increase the identity provider directory timeout setting"
msg += "(https://www.pomerium.io/reference/#identity-provider-refresh-directory-settings)" msg += "(https://www.pomerium.io/reference/#identity-provider-refresh-directory-settings)"
} }
mgr.log.Warn().Err(err).Msg(msg) log.Warn(ctx).Err(err).Msg(msg)
return return
} }
@ -232,7 +237,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
id := newDG.GetId() id := newDG.GetId()
any, err := anypb.New(newDG) any, err := anypb.New(newDG)
if err != nil { if err != nil {
mgr.log.Warn().Err(err).Msg("failed to marshal directory group") log.Warn(ctx).Err(err).Msg("failed to marshal directory group")
return return
} }
eg.Go(func() error { eg.Go(func() error {
@ -262,7 +267,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
id := curDG.GetId() id := curDG.GetId()
any, err := anypb.New(curDG) any, err := anypb.New(curDG)
if err != nil { if err != nil {
mgr.log.Warn().Err(err).Msg("failed to marshal directory group") log.Warn(ctx).Err(err).Msg("failed to marshal directory group")
return return
} }
eg.Go(func() error { eg.Go(func() error {
@ -287,7 +292,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
} }
if err := eg.Wait(); err != nil { if err := eg.Wait(); err != nil {
mgr.log.Warn().Err(err).Msg("manager: failed to merge groups") log.Warn(ctx).Err(err).Msg("manager: failed to merge groups")
} }
} }
@ -305,7 +310,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
id := newDU.GetId() id := newDU.GetId()
any, err := anypb.New(newDU) any, err := anypb.New(newDU)
if err != nil { if err != nil {
mgr.log.Warn().Err(err).Msg("failed to marshal directory user") log.Warn(ctx).Err(err).Msg("failed to marshal directory user")
return return
} }
eg.Go(func() error { eg.Go(func() error {
@ -335,7 +340,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
id := curDU.GetId() id := curDU.GetId()
any, err := anypb.New(curDU) any, err := anypb.New(curDU)
if err != nil { if err != nil {
mgr.log.Warn().Err(err).Msg("failed to marshal directory user") log.Warn(ctx).Err(err).Msg("failed to marshal directory user")
return return
} }
eg.Go(func() error { eg.Go(func() error {
@ -361,19 +366,19 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
} }
if err := eg.Wait(); err != nil { if err := eg.Wait(); err != nil {
mgr.log.Warn().Err(err).Msg("manager: failed to merge users") log.Warn(ctx).Err(err).Msg("manager: failed to merge users")
} }
} }
func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) { func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) {
mgr.log.Info(). log.Info(ctx).
Str("user_id", userID). Str("user_id", userID).
Str("session_id", sessionID). Str("session_id", sessionID).
Msg("refreshing session") Msg("refreshing session")
s, ok := mgr.sessions.Get(userID, sessionID) s, ok := mgr.sessions.Get(userID, sessionID)
if !ok { if !ok {
mgr.log.Warn(). log.Warn(ctx).
Str("user_id", userID). Str("user_id", userID).
Str("session_id", sessionID). Str("session_id", sessionID).
Msg("no session found for refresh") Msg("no session found for refresh")
@ -382,7 +387,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
expiry := s.GetExpiresAt().AsTime() expiry := s.GetExpiresAt().AsTime()
if !expiry.After(time.Now()) { if !expiry.After(time.Now()) {
mgr.log.Info(). log.Info(ctx).
Str("user_id", userID). Str("user_id", userID).
Str("session_id", sessionID). Str("session_id", sessionID).
Msg("deleting expired session") Msg("deleting expired session")
@ -391,7 +396,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
} }
if s.Session == nil || s.Session.OauthToken == nil { if s.Session == nil || s.Session.OauthToken == nil {
mgr.log.Warn(). log.Warn(ctx).
Str("user_id", userID). Str("user_id", userID).
Str("session_id", sessionID). Str("session_id", sessionID).
Msg("no session oauth2 token found for refresh") Msg("no session oauth2 token found for refresh")
@ -400,13 +405,13 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
newToken, err := mgr.cfg.Load().authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s) newToken, err := mgr.cfg.Load().authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s)
if isTemporaryError(err) { if isTemporaryError(err) {
mgr.log.Error().Err(err). log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to refresh oauth2 token") Msg("failed to refresh oauth2 token")
return return
} else if err != nil { } else if err != nil {
mgr.log.Error().Err(err). log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to refresh oauth2 token, deleting session") Msg("failed to refresh oauth2 token, deleting session")
@ -417,13 +422,13 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
err = mgr.cfg.Load().authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &s) err = mgr.cfg.Load().authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &s)
if isTemporaryError(err) { if isTemporaryError(err) {
mgr.log.Error().Err(err). log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to update user info") Msg("failed to update user info")
return return
} else if err != nil { } else if err != nil {
mgr.log.Error().Err(err). log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to update user info, deleting session") Msg("failed to update user info, deleting session")
@ -433,7 +438,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
res, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session) res, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session)
if err != nil { if err != nil {
mgr.log.Error().Err(err). log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to update session") Msg("failed to update session")
@ -444,13 +449,13 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
} }
func (mgr *Manager) refreshUser(ctx context.Context, userID string) { func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
mgr.log.Info(). log.Info(ctx).
Str("user_id", userID). Str("user_id", userID).
Msg("refreshing user") Msg("refreshing user")
u, ok := mgr.users.Get(userID) u, ok := mgr.users.Get(userID)
if !ok { if !ok {
mgr.log.Warn(). log.Warn(ctx).
Str("user_id", userID). Str("user_id", userID).
Msg("no user found for refresh") Msg("no user found for refresh")
return return
@ -460,7 +465,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
for _, s := range mgr.sessions.GetSessionsForUser(userID) { for _, s := range mgr.sessions.GetSessionsForUser(userID) {
if s.Session == nil || s.Session.OauthToken == nil { if s.Session == nil || s.Session.OauthToken == nil {
mgr.log.Warn(). log.Warn(ctx).
Str("user_id", userID). Str("user_id", userID).
Msg("no session oauth2 token found for refresh") Msg("no session oauth2 token found for refresh")
continue continue
@ -468,13 +473,13 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
err := mgr.cfg.Load().authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u) err := mgr.cfg.Load().authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u)
if isTemporaryError(err) { if isTemporaryError(err) {
mgr.log.Error().Err(err). log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to update user info") Msg("failed to update user info")
return return
} else if err != nil { } else if err != nil {
mgr.log.Error().Err(err). log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to update user info, deleting session") Msg("failed to update user info, deleting session")
@ -484,7 +489,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
record, err := user.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User) record, err := user.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
if err != nil { if err != nil {
mgr.log.Error().Err(err). log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to update user") Msg("failed to update user")
@ -502,7 +507,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
var pbDirectoryGroup directory.Group var pbDirectoryGroup directory.Group
err := record.GetData().UnmarshalTo(&pbDirectoryGroup) err := record.GetData().UnmarshalTo(&pbDirectoryGroup)
if err != nil { if err != nil {
mgr.log.Warn().Msgf("error unmarshaling directory group: %s", err) log.Warn(ctx).Msgf("error unmarshaling directory group: %s", err)
continue continue
} }
mgr.onUpdateDirectoryGroup(ctx, &pbDirectoryGroup) mgr.onUpdateDirectoryGroup(ctx, &pbDirectoryGroup)
@ -510,7 +515,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
var pbDirectoryUser directory.User var pbDirectoryUser directory.User
err := record.GetData().UnmarshalTo(&pbDirectoryUser) err := record.GetData().UnmarshalTo(&pbDirectoryUser)
if err != nil { if err != nil {
mgr.log.Warn().Msgf("error unmarshaling directory user: %s", err) log.Warn(ctx).Msgf("error unmarshaling directory user: %s", err)
continue continue
} }
mgr.onUpdateDirectoryUser(ctx, &pbDirectoryUser) mgr.onUpdateDirectoryUser(ctx, &pbDirectoryUser)
@ -518,7 +523,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
var pbSession session.Session var pbSession session.Session
err := record.GetData().UnmarshalTo(&pbSession) err := record.GetData().UnmarshalTo(&pbSession)
if err != nil { if err != nil {
mgr.log.Warn().Msgf("error unmarshaling session: %s", err) log.Warn(ctx).Msgf("error unmarshaling session: %s", err)
continue continue
} }
mgr.onUpdateSession(ctx, record, &pbSession) mgr.onUpdateSession(ctx, record, &pbSession)
@ -526,7 +531,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
var pbUser user.User var pbUser user.User
err := record.GetData().UnmarshalTo(&pbUser) err := record.GetData().UnmarshalTo(&pbUser)
if err != nil { if err != nil {
mgr.log.Warn().Msgf("error unmarshaling user: %s", err) log.Warn(ctx).Msgf("error unmarshaling user: %s", err)
continue continue
} }
} }
@ -578,7 +583,7 @@ func (mgr *Manager) onUpdateDirectoryGroup(_ context.Context, pbDirectoryGroup *
func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Session) { func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Session) {
err := session.Delete(ctx, mgr.cfg.Load().dataBrokerClient, pbSession.GetId()) err := session.Delete(ctx, mgr.cfg.Load().dataBrokerClient, pbSession.GetId())
if err != nil { if err != nil {
mgr.log.Error().Err(err). log.Error(ctx).Err(err).
Str("session_id", pbSession.GetId()). Str("session_id", pbSession.GetId()).
Msg("failed to delete session") Msg("failed to delete session")
} }

View file

@ -3,14 +3,11 @@ package manager
import ( import (
"context" "context"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
) )
type dataBrokerSyncer struct { type dataBrokerSyncer struct {
cfg *atomicConfig cfg *atomicConfig
log zerolog.Logger
update chan<- updateRecordsMessage update chan<- updateRecordsMessage
clear chan<- struct{} clear chan<- struct{}
@ -19,14 +16,13 @@ type dataBrokerSyncer struct {
} }
func newDataBrokerSyncer( func newDataBrokerSyncer(
ctx context.Context,
cfg *atomicConfig, cfg *atomicConfig,
log zerolog.Logger,
update chan<- updateRecordsMessage, update chan<- updateRecordsMessage,
clear chan<- struct{}, clear chan<- struct{},
) *dataBrokerSyncer { ) *dataBrokerSyncer {
syncer := &dataBrokerSyncer{ syncer := &dataBrokerSyncer{
cfg: cfg, cfg: cfg,
log: log,
update: update, update: update,
clear: clear, clear: clear,

View file

@ -142,7 +142,7 @@ func (p *Provider) userEmail(ctx context.Context, t *oauth2.Token, v interface{}
Email string `json:"email"` Email string `json:"email"`
Verified bool `json:"email_verified"` Verified bool `json:"email_verified"`
} }
log.Debug().Interface("emails", response).Msg("github: user emails") log.Debug(ctx).Interface("emails", response).Msg("github: user emails")
for _, email := range response { for _, email := range response {
if email.Primary && email.Verified { if email.Primary && email.Verified {
out.Email = email.Email out.Email = email.Email

View file

@ -88,35 +88,55 @@ func With() zerolog.Context {
} }
// Level creates a child logger with the minimum accepted level set to level. // Level creates a child logger with the minimum accepted level set to level.
func Level(level zerolog.Level) zerolog.Logger { func Level(ctx context.Context, level zerolog.Level) *zerolog.Logger {
return Logger().Level(level) l := contextLogger(ctx).Level(level)
return &l
} }
// Debug starts a new message with debug level. // Debug starts a new message with debug level.
// //
// You must call Msg on the returned event in order to send the event. // You must call Msg on the returned event in order to send the event.
func Debug() *zerolog.Event { func Debug(ctx context.Context) *zerolog.Event {
return Logger().Debug() return contextLogger(ctx).Debug()
} }
// Info starts a new message with info level. // Info starts a new message with info level.
// //
// You must call Msg on the returned event in order to send the event. // You must call Msg on the returned event in order to send the event.
func Info() *zerolog.Event { func Info(ctx context.Context) *zerolog.Event {
return Logger().Info() return contextLogger(ctx).Info()
} }
// Warn starts a new message with warn level. // Warn starts a new message with warn level.
// //
// You must call Msg on the returned event in order to send the event. // You must call Msg on the returned event in order to send the event.
func Warn() *zerolog.Event { func Warn(ctx context.Context) *zerolog.Event {
return Logger().Warn() return contextLogger(ctx).Warn()
}
func contextLogger(ctx context.Context) *zerolog.Logger {
global := Logger()
if global.GetLevel() == zerolog.Disabled {
return global
}
l := zerolog.Ctx(ctx)
if l.GetLevel() == zerolog.Disabled { // no logger associated with context
return global
}
return l
}
// WithContext returns a context that has an associated logger and extra fields set via update
func WithContext(ctx context.Context, update func(c zerolog.Context) zerolog.Context) context.Context {
l := contextLogger(ctx).With().Logger()
l.UpdateContext(update)
return l.WithContext(ctx)
} }
// Error starts a new message with error level. // Error starts a new message with error level.
// //
// You must call Msg on the returned event in order to send the event. // You must call Msg on the returned event in order to send the event.
func Error() *zerolog.Event { func Error(ctx context.Context) *zerolog.Event {
return Logger().Error() return Logger().Error()
} }
@ -136,18 +156,11 @@ func Panic() *zerolog.Event {
return Logger().Panic() return Logger().Panic()
} }
// WithLevel starts a new message with level.
//
// You must call Msg on the returned event in order to send the event.
func WithLevel(level zerolog.Level) *zerolog.Event {
return Logger().WithLevel(level)
}
// Log starts a new message with no level. Setting zerolog.GlobalLevel to // Log starts a new message with no level. Setting zerolog.GlobalLevel to
// zerolog.Disabled will still disable events produced by this method. // zerolog.Disabled will still disable events produced by this method.
// //
// You must call Msg on the returned event in order to send the event. // You must call Msg on the returned event in order to send the event.
func Log() *zerolog.Event { func Log(ctx context.Context) *zerolog.Event {
return Logger().Log() return Logger().Log()
} }

View file

@ -1,6 +1,7 @@
package log_test package log_test
import ( import (
"context"
"errors" "errors"
"flag" "flag"
"time" "time"
@ -54,7 +55,7 @@ func ExamplePrintf() {
// Example of a log with no particular "level" // Example of a log with no particular "level"
func ExampleLog() { func ExampleLog() {
setup() setup()
log.Log().Msg("hello world") log.Log(context.Background()).Msg("hello world")
// Output: {"time":1199811905,"message":"hello world"} // Output: {"time":1199811905,"message":"hello world"}
} }
@ -62,7 +63,7 @@ func ExampleLog() {
// Example of a log at a particular "level" (in this case, "debug") // Example of a log at a particular "level" (in this case, "debug")
func ExampleDebug() { func ExampleDebug() {
setup() setup()
log.Debug().Msg("hello world") log.Debug(context.Background()).Msg("hello world")
// Output: {"level":"debug","time":1199811905,"message":"hello world"} // Output: {"level":"debug","time":1199811905,"message":"hello world"}
} }
@ -70,7 +71,7 @@ func ExampleDebug() {
// Example of a log at a particular "level" (in this case, "info") // Example of a log at a particular "level" (in this case, "info")
func ExampleInfo() { func ExampleInfo() {
setup() setup()
log.Info().Msg("hello world") log.Info(context.Background()).Msg("hello world")
// Output: {"level":"info","time":1199811905,"message":"hello world"} // Output: {"level":"info","time":1199811905,"message":"hello world"}
} }
@ -78,7 +79,7 @@ func ExampleInfo() {
// Example of a log at a particular "level" (in this case, "warn") // Example of a log at a particular "level" (in this case, "warn")
func ExampleWarn() { func ExampleWarn() {
setup() setup()
log.Warn().Msg("hello world") log.Warn(context.Background()).Msg("hello world")
// Output: {"level":"warn","time":1199811905,"message":"hello world"} // Output: {"level":"warn","time":1199811905,"message":"hello world"}
} }
@ -86,7 +87,7 @@ func ExampleWarn() {
// Example of a log at a particular "level" (in this case, "error") // Example of a log at a particular "level" (in this case, "error")
func ExampleError() { func ExampleError() {
setup() setup()
log.Error().Msg("hello world") log.Error(context.Background()).Msg("hello world")
// Output: {"level":"error","time":1199811905,"message":"hello world"} // Output: {"level":"error","time":1199811905,"message":"hello world"}
} }
@ -119,10 +120,10 @@ func Example() {
zerolog.SetGlobalLevel(zerolog.DebugLevel) zerolog.SetGlobalLevel(zerolog.DebugLevel)
} }
log.Debug().Msg("This message appears only when log level set to Debug") log.Debug(context.Background()).Msg("This message appears only when log level set to Debug")
log.Info().Msg("This message appears when log level set to Debug or Info") log.Info(context.Background()).Msg("This message appears when log level set to Debug or Info")
if e := log.Debug(); e.Enabled() { if e := log.Debug(context.Background()); e.Enabled() {
// Compute log output only if enabled. // Compute log output only if enabled.
value := "bar" value := "bar"
e.Str("foo", value).Msg("some debug message") e.Str("foo", value).Msg("some debug message")
@ -134,19 +135,19 @@ func Example() {
func ExampleSetLevel() { func ExampleSetLevel() {
setup() setup()
log.SetLevel("info") log.SetLevel("info")
log.Debug().Msg("Debug") log.Debug(context.Background()).Msg("Debug")
log.Info().Msg("Debug or Info") log.Info(context.Background()).Msg("Debug or Info")
log.SetLevel("warn") log.SetLevel("warn")
log.Debug().Msg("Debug") log.Debug(context.Background()).Msg("Debug")
log.Info().Msg("Debug or Info") log.Info(context.Background()).Msg("Debug or Info")
log.Warn().Msg("Debug or Info or Warn") log.Warn(context.Background()).Msg("Debug or Info or Warn")
log.SetLevel("error") log.SetLevel("error")
log.Debug().Msg("Debug") log.Debug(context.Background()).Msg("Debug")
log.Info().Msg("Debug or Info") log.Info(context.Background()).Msg("Debug or Info")
log.Warn().Msg("Debug or Info or Warn") log.Warn(context.Background()).Msg("Debug or Info or Warn")
log.Error().Msg("Debug or Info or Warn or Error") log.Error(context.Background()).Msg("Debug or Info or Warn or Error")
log.SetLevel("default-fall-through") log.SetLevel("default-fall-through")
log.Debug().Msg("Debug") log.Debug(context.Background()).Msg("Debug")
// Output: // Output:
// {"level":"info","time":1199811905,"message":"Debug or Info"} // {"level":"info","time":1199811905,"message":"Debug or Info"}
@ -154,3 +155,32 @@ func ExampleSetLevel() {
// {"level":"error","time":1199811905,"message":"Debug or Info or Warn or Error"} // {"level":"error","time":1199811905,"message":"Debug or Info or Warn or Error"}
// {"level":"debug","time":1199811905,"message":"Debug"} // {"level":"debug","time":1199811905,"message":"Debug"}
} }
func ExampleContext() {
setup()
bg := context.Background()
ctx1 := log.WithContext(bg, func(c zerolog.Context) zerolog.Context {
return c.Str("param_one", "one")
})
ctx2 := log.WithContext(ctx1, func(c zerolog.Context) zerolog.Context {
return c.Str("param_two", "two")
})
log.Warn(bg).Str("non_context_param", "value").Msg("background")
log.Warn(ctx1).Str("non_context_param", "value").Msg("first")
log.Warn(ctx2).Str("non_context_param", "value").Msg("second")
for i := 0; i < 10; i++ {
ctx1 = log.WithContext(ctx1, func(c zerolog.Context) zerolog.Context {
return c.Int("counter", i)
})
}
log.Info(ctx1).Str("non_ctx_param", "value").Msg("after counter")
/*
{"level":"warn","ctx":"one","param":"first","time":1199811905,"message":"first"}
{"level":"warn","ctx":"two","param":"second","time":1199811905,"message":"second"}
{"level":"warn","param":"third","time":1199811905,"message":"third"}
*/
}

View file

@ -51,7 +51,7 @@ func (s *inMemoryServer) periodicCheck(ctx context.Context) {
return return
case <-time.After(after): case <-time.After(after):
if s.lockAndRmExpired() { if s.lockAndRmExpired() {
s.onchange.Broadcast() s.onchange.Broadcast(ctx)
} }
} }
} }
@ -70,7 +70,7 @@ func (s *inMemoryServer) Report(ctx context.Context, req *pb.RegisterRequest) (*
} }
if updated { if updated {
s.onchange.Broadcast() s.onchange.Broadcast(ctx)
} }
return &pb.RegisterResponse{ return &pb.RegisterResponse{
@ -171,7 +171,7 @@ func (s *inMemoryServer) Watch(req *pb.ListRequest, srv pb.Registry_WatchServer)
} }
} }
func (s *inMemoryServer) getServiceUpdates(ctx context.Context, kinds map[pb.ServiceKind]bool, updates chan struct{}) ([]*pb.Service, error) { func (s *inMemoryServer) getServiceUpdates(ctx context.Context, kinds map[pb.ServiceKind]bool, updates chan context.Context) ([]*pb.Service, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()

View file

@ -23,25 +23,25 @@ type Reporter struct {
} }
// OnConfigChange applies configuration changes to the reporter // OnConfigChange applies configuration changes to the reporter
func (r *Reporter) OnConfigChange(cfg *config.Config) { func (r *Reporter) OnConfigChange(ctx context.Context, cfg *config.Config) {
if r.cancel != nil { if r.cancel != nil {
r.cancel() r.cancel()
} }
services, err := getReportedServices(cfg) services, err := getReportedServices(cfg)
if err != nil { if err != nil {
log.Warn().Err(err).Msg("metrics announce to service registry is disabled") log.Warn(ctx).Err(err).Msg("metrics announce to service registry is disabled")
} }
sharedKey, err := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) sharedKey, err := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
if err != nil { if err != nil {
log.Error().Err(err).Msg("decoding shared key") log.Error(ctx).Err(err).Msg("decoding shared key")
return return
} }
urls, err := cfg.Options.GetDataBrokerURLs() urls, err := cfg.Options.GetDataBrokerURLs()
if err != nil { if err != nil {
log.Error().Err(err).Msg("invalid databroker urls") log.Error(ctx).Err(err).Msg("invalid databroker urls")
return return
} }
@ -58,7 +58,7 @@ func (r *Reporter) OnConfigChange(cfg *config.Config) {
SignedJWTKey: sharedKey, SignedJWTKey: sharedKey,
}) })
if err != nil { if err != nil {
log.Error().Err(err).Msg("connecting to registry") log.Error(ctx).Err(err).Msg("connecting to registry")
return return
} }
@ -145,7 +145,7 @@ func runReporter(
after = resp.CallBackAfter.AsDuration() after = resp.CallBackAfter.AsDuration()
backoff.Reset() backoff.Reset()
case <-ctx.Done(): case <-ctx.Done():
log.Info().Msg("service registry reporter stopping") log.Info(ctx).Msg("service registry reporter stopping")
return return
} }
} }

View file

@ -2,28 +2,29 @@
package signal package signal
import ( import (
"context"
"sync" "sync"
) )
// A Signal is used to let multiple listeners know when something happened. // A Signal is used to let multiple listeners know when something happened.
type Signal struct { type Signal struct {
mu sync.Mutex mu sync.Mutex
chs map[chan struct{}]struct{} chs map[chan context.Context]struct{}
} }
// New creates a new Signal. // New creates a new Signal.
func New() *Signal { func New() *Signal {
return &Signal{ return &Signal{
chs: make(map[chan struct{}]struct{}), chs: make(map[chan context.Context]struct{}),
} }
} }
// Broadcast signals all the listeners. Broadcast never blocks. // Broadcast signals all the listeners. Broadcast never blocks.
func (s *Signal) Broadcast() { func (s *Signal) Broadcast(ctx context.Context) {
s.mu.Lock() s.mu.Lock()
for ch := range s.chs { for ch := range s.chs {
select { select {
case ch <- struct{}{}: case ch <- ctx:
default: default:
} }
} }
@ -32,8 +33,8 @@ func (s *Signal) Broadcast() {
// Bind creates a new listening channel bound to the signal. The channel used has a size of 1 // Bind creates a new listening channel bound to the signal. The channel used has a size of 1
// and any given broadcast will signal at least one event, but may signal more than one. // and any given broadcast will signal at least one event, but may signal more than one.
func (s *Signal) Bind() chan struct{} { func (s *Signal) Bind() chan context.Context {
ch := make(chan struct{}, 1) ch := make(chan context.Context, 1)
s.mu.Lock() s.mu.Lock()
s.chs[ch] = struct{}{} s.chs[ch] = struct{}{}
s.mu.Unlock() s.mu.Unlock()
@ -41,7 +42,7 @@ func (s *Signal) Bind() chan struct{} {
} }
// Unbind stops the listening channel bound to the signal. // Unbind stops the listening channel bound to the signal.
func (s *Signal) Unbind(ch chan struct{}) { func (s *Signal) Unbind(ch chan context.Context) {
s.mu.Lock() s.mu.Lock()
delete(s.chs, ch) delete(s.chs, ch)
s.mu.Unlock() s.mu.Unlock()

View file

@ -1,6 +1,7 @@
package tcptunnel package tcptunnel
import ( import (
"context"
"crypto/tls" "crypto/tls"
"github.com/pomerium/pomerium/internal/cliutil" "github.com/pomerium/pomerium/internal/cliutil"
@ -19,7 +20,7 @@ func getConfig(options ...Option) *config {
if jwtCache, err := cliutil.NewLocalJWTCache(); err == nil { if jwtCache, err := cliutil.NewLocalJWTCache(); err == nil {
WithJWTCache(jwtCache)(cfg) WithJWTCache(jwtCache)(cfg)
} else { } else {
log.Error().Err(err).Msg("tcptunnel: error creating local JWT cache, using in-memory JWT cache") log.Error(context.TODO()).Err(err).Msg("tcptunnel: error creating local JWT cache, using in-memory JWT cache")
WithJWTCache(cliutil.NewMemoryJWTCache())(cfg) WithJWTCache(cliutil.NewMemoryJWTCache())(cfg)
} }
for _, o := range options { for _, o := range options {

View file

@ -43,7 +43,7 @@ func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) erro
return err return err
} }
defer func() { _ = li.Close() }() defer func() { _ = li.Close() }()
log.Info().Msg("tcptunnel: listening on " + li.Addr().String()) log.Info(ctx).Msg("tcptunnel: listening on " + li.Addr().String())
go func() { go func() {
<-ctx.Done() <-ctx.Done()
@ -62,7 +62,7 @@ func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) erro
} }
if nerr, ok := err.(net.Error); ok && nerr.Temporary() { if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
log.Warn().Err(err).Msg("tcptunnel: temporarily failed to accept local connection") log.Warn(ctx).Err(err).Msg("tcptunnel: temporarily failed to accept local connection")
select { select {
case <-time.After(bo.NextBackOff()): case <-time.After(bo.NextBackOff()):
case <-ctx.Done(): case <-ctx.Done():
@ -79,7 +79,7 @@ func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) erro
err := tun.Run(ctx, conn) err := tun.Run(ctx, conn)
if err != nil { if err != nil {
log.Error().Err(err).Msg("tcptunnel: error serving local connection") log.Error(ctx).Err(err).Msg("tcptunnel: error serving local connection")
} }
}() }()
} }
@ -102,7 +102,7 @@ func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter) error {
} }
func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string, retryCount int) error { func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string, retryCount int) error {
log.Info(). log.Info(ctx).
Str("dst", tun.cfg.dstHost). Str("dst", tun.cfg.dstHost).
Str("proxy", tun.cfg.proxyHost). Str("proxy", tun.cfg.proxyHost).
Bool("secure", tun.cfg.tlsConfig != nil). Bool("secure", tun.cfg.tlsConfig != nil).
@ -132,7 +132,7 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string,
} }
defer func() { defer func() {
_ = remote.Close() _ = remote.Close()
log.Info().Msg("tcptunnel: connection closed") log.Info(ctx).Msg("tcptunnel: connection closed")
}() }()
if done := ctx.Done(); done != nil { if done := ctx.Done(); done != nil {
go func() { go func() {
@ -189,7 +189,7 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string,
return fmt.Errorf("tcptunnel: invalid http response code: %d", res.StatusCode) return fmt.Errorf("tcptunnel: invalid http response code: %d", res.StatusCode)
} }
log.Info().Msg("tcptunnel: connection established") log.Info(ctx).Msg("tcptunnel: connection established")
errc := make(chan error, 2) errc := make(chan error, 2)
go func() { go func() {

View file

@ -142,7 +142,7 @@ func GRPCClientInterceptor(service string) grpc.UnaryClientInterceptor {
tag.Upsert(TagKeyGRPCService, rpcService), tag.Upsert(TagKeyGRPCService, rpcService),
) )
if tagErr != nil { if tagErr != nil {
log.Warn().Err(tagErr).Str("context", "GRPCClientInterceptor").Msg("telemetry/metrics: failed to create context") log.Warn(ctx).Err(tagErr).Str("context", "GRPCClientInterceptor").Msg("telemetry/metrics: failed to create context")
return invoker(ctx, method, req, reply, cc, opts...) return invoker(ctx, method, req, reply, cc, opts...)
} }
@ -181,7 +181,7 @@ func (h *GRPCServerMetricsHandler) TagRPC(ctx context.Context, tagInfo *grpcstat
tag.Upsert(TagKeyGRPCService, rpcService), tag.Upsert(TagKeyGRPCService, rpcService),
) )
if tagErr != nil { if tagErr != nil {
log.Warn().Err(tagErr).Str("context", "GRPCServerStatsHandler").Msg("telemetry/metrics: failed to create context") log.Warn(ctx).Err(tagErr).Str("context", "GRPCServerStatsHandler").Msg("telemetry/metrics: failed to create context")
return ctx return ctx
} }

View file

@ -120,7 +120,7 @@ func HTTPMetricsHandler(getInstallationID func() string, service string) func(ne
tag.Upsert(TagKeyHTTPMethod, r.Method), tag.Upsert(TagKeyHTTPMethod, r.Method),
) )
if tagErr != nil { if tagErr != nil {
log.Warn().Err(tagErr).Str("context", "HTTPMetricsHandler").Msg("telemetry/metrics: failed to create metrics tag") log.Warn(ctx).Err(tagErr).Str("context", "HTTPMetricsHandler").Msg("telemetry/metrics: failed to create metrics tag")
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
@ -148,7 +148,7 @@ func HTTPMetricsRoundTripper(getInstallationID func() string, service string, de
tag.Upsert(TagKeyDestination, destination), tag.Upsert(TagKeyDestination, destination),
) )
if tagErr != nil { if tagErr != nil {
log.Warn().Err(tagErr).Str("context", "HTTPMetricsRoundTripper").Msg("telemetry/metrics: failed to create metrics tag") log.Warn(ctx).Err(tagErr).Str("context", "HTTPMetricsRoundTripper").Msg("telemetry/metrics: failed to create metrics tag")
return next.RoundTrip(r) return next.RoundTrip(r)
} }

View file

@ -104,8 +104,8 @@ func RecordIdentityManagerLastRefresh() {
// SetDBConfigInfo records status, databroker version and error count while parsing // SetDBConfigInfo records status, databroker version and error count while parsing
// the configuration from a databroker // the configuration from a databroker
func SetDBConfigInfo(service, configID string, version uint64, errCount int64) { func SetDBConfigInfo(ctx context.Context, service, configID string, version uint64, errCount int64) {
log.Info(). log.Info(ctx).
Str("service", service). Str("service", service).
Str("config_id", configID). Str("config_id", configID).
Uint64("version", version). Uint64("version", version).
@ -113,14 +113,14 @@ func SetDBConfigInfo(service, configID string, version uint64, errCount int64) {
Msg("set db config info") Msg("set db config info")
if err := stats.RecordWithTags( if err := stats.RecordWithTags(
context.Background(), ctx,
[]tag.Mutator{ []tag.Mutator{
tag.Insert(TagKeyService, service), tag.Insert(TagKeyService, service),
tag.Insert(TagConfigID, configID), tag.Insert(TagConfigID, configID),
}, },
configDBVersion.M(int64(version)), configDBVersion.M(int64(version)),
); err != nil { ); err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to record config version number") log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to record config version number")
} }
if err := stats.RecordWithTags( if err := stats.RecordWithTags(
@ -131,20 +131,20 @@ func SetDBConfigInfo(service, configID string, version uint64, errCount int64) {
}, },
configDBErrors.M(errCount), configDBErrors.M(errCount),
); err != nil { ); err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to record config error count") log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to record config error count")
} }
} }
// SetDBConfigRejected records that a certain databroker config version has been rejected // SetDBConfigRejected records that a certain databroker config version has been rejected
func SetDBConfigRejected(service, configID string, version uint64, err error) { func SetDBConfigRejected(ctx context.Context, service, configID string, version uint64, err error) {
log.Warn().Err(err).Msg("databroker: invalid config detected, ignoring") log.Warn(ctx).Err(err).Msg("databroker: invalid config detected, ignoring")
SetDBConfigInfo(service, configID, version, -1) SetDBConfigInfo(ctx, service, configID, version, -1)
} }
// SetConfigInfo records the status, checksum and timestamp of a configuration // SetConfigInfo records the status, checksum and timestamp of a configuration
// reload. You must register InfoViews or the related config views before calling // reload. You must register InfoViews or the related config views before calling
func SetConfigInfo(service, configName string, checksum uint64, success bool) { func SetConfigInfo(ctx context.Context, service, configName string, checksum uint64, success bool) {
if success { if success {
registry.setConfigChecksum(service, configName, checksum) registry.setConfigChecksum(service, configName, checksum)
@ -154,7 +154,7 @@ func SetConfigInfo(service, configName string, checksum uint64, success bool) {
[]tag.Mutator{serviceTag}, []tag.Mutator{serviceTag},
configLastReload.M(time.Now().Unix()), configLastReload.M(time.Now().Unix()),
); err != nil { ); err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to record config checksum timestamp") log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to record config checksum timestamp")
} }
if err := stats.RecordWithTags( if err := stats.RecordWithTags(
@ -162,12 +162,12 @@ func SetConfigInfo(service, configName string, checksum uint64, success bool) {
[]tag.Mutator{serviceTag}, []tag.Mutator{serviceTag},
configLastReloadSuccess.M(1), configLastReloadSuccess.M(1),
); err != nil { ); err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to record config reload") log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to record config reload")
} }
} else { } else {
stats.Record(context.Background(), configLastReloadSuccess.M(0)) stats.Record(context.Background(), configLastReloadSuccess.M(0))
} }
log.Info(). log.Info(ctx).
Str("service", service). Str("service", service).
Str("config", configName). Str("config", configName).
Str("checksum", fmt.Sprintf("%x", checksum)). Str("checksum", fmt.Sprintf("%x", checksum)).

View file

@ -1,6 +1,7 @@
package metrics package metrics
import ( import (
"context"
"fmt" "fmt"
"runtime" "runtime"
"testing" "testing"
@ -28,7 +29,7 @@ func Test_SetConfigInfo(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
view.Unregister(InfoViews...) view.Unregister(InfoViews...)
view.Register(InfoViews...) view.Register(InfoViews...)
SetConfigInfo("test_service", "test config", 0, tt.success) SetConfigInfo(context.Background(), "test_service", "test config", 0, tt.success)
testDataRetrieval(ConfigLastReloadView, t, tt.wantLastReload) testDataRetrieval(ConfigLastReloadView, t, tt.wantLastReload)
testDataRetrieval(ConfigLastReloadSuccessView, t, tt.wantLastReloadSuccess) testDataRetrieval(ConfigLastReloadSuccessView, t, tt.wantLastReloadSuccess)
@ -55,7 +56,7 @@ func Test_SetDBConfigInfo(t *testing.T) {
t.Run(fmt.Sprintf("version=%d errors=%d", tt.version, tt.errCount), func(t *testing.T) { t.Run(fmt.Sprintf("version=%d errors=%d", tt.version, tt.errCount), func(t *testing.T) {
view.Unregister(InfoViews...) view.Unregister(InfoViews...)
view.Register(InfoViews...) view.Register(InfoViews...)
SetDBConfigInfo("test_service", "test_config", tt.version, tt.errCount) SetDBConfigInfo(context.TODO(), "test_service", "test_config", tt.version, tt.errCount)
testDataRetrieval(ConfigDBVersionView, t, tt.wantVersion) testDataRetrieval(ConfigDBVersionView, t, tt.wantVersion)
testDataRetrieval(ConfigDBErrorsView, t, tt.wantErrors) testDataRetrieval(ConfigDBErrorsView, t, tt.wantErrors)

View file

@ -89,26 +89,26 @@ func newProxyMetricsHandler(exporter *ocprom.Exporter, envoyURL url.URL, install
err := writeMetricsWithInstallationID(w, rec.Body, installationID) err := writeMetricsWithInstallationID(w, rec.Body, installationID)
if err != nil { if err != nil {
log.Error().Err(err).Send() log.Error(r.Context()).Err(err).Send()
return return
} }
req, err := http.NewRequestWithContext(r.Context(), "GET", envoyURL.String(), nil) req, err := http.NewRequestWithContext(r.Context(), "GET", envoyURL.String(), nil)
if err != nil { if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to create request for envoy") log.Error(r.Context()).Err(err).Msg("telemetry/metrics: failed to create request for envoy")
return return
} }
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: fail to fetch proxy metrics") log.Error(r.Context()).Err(err).Msg("telemetry/metrics: fail to fetch proxy metrics")
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
err = writeMetricsWithInstallationID(w, resp.Body, installationID) err = writeMetricsWithInstallationID(w, resp.Body, installationID)
if err != nil { if err != nil {
log.Error().Err(err).Send() log.Error(r.Context()).Err(err).Send()
return return
} }
} }

View file

@ -1,6 +1,7 @@
package metrics package metrics
import ( import (
"context"
"runtime" "runtime"
"sync" "sync"
@ -34,6 +35,7 @@ func newMetricRegistry() *metricRegistry {
} }
func (r *metricRegistry) init() { func (r *metricRegistry) init() {
ctx := context.TODO()
r.Do( r.Do(
func() { func() {
r.registry = metric.NewRegistry() r.registry = metric.NewRegistry()
@ -49,7 +51,7 @@ func (r *metricRegistry) init() {
), ),
) )
if err != nil { if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to register build info metric") log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to register build info metric")
} }
r.configChecksum, err = r.registry.AddFloat64Gauge(metrics.ConfigChecksumDecimal, r.configChecksum, err = r.registry.AddFloat64Gauge(metrics.ConfigChecksumDecimal,
@ -57,7 +59,7 @@ func (r *metricRegistry) init() {
metric.WithLabelKeys(metrics.ServiceLabel, metrics.ConfigLabel), metric.WithLabelKeys(metrics.ServiceLabel, metrics.ConfigLabel),
) )
if err != nil { if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to register config checksum metric") log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to register config checksum metric")
} }
r.policyCount, err = r.registry.AddInt64DerivedGauge(metrics.PolicyCountTotal, r.policyCount, err = r.registry.AddInt64DerivedGauge(metrics.PolicyCountTotal,
@ -65,12 +67,12 @@ func (r *metricRegistry) init() {
metric.WithLabelKeys(metrics.ServiceLabel), metric.WithLabelKeys(metrics.ServiceLabel),
) )
if err != nil { if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to register policy count metric") log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to register policy count metric")
} }
err = registerAutocertMetrics(r.registry) err = registerAutocertMetrics(r.registry)
if err != nil { if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to register autocert metrics") log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to register autocert metrics")
} }
}) })
} }
@ -89,7 +91,7 @@ func (r *metricRegistry) setBuildInfo(service, hostname string) {
metricdata.NewLabelValue(hostname), metricdata.NewLabelValue(hostname),
) )
if err != nil { if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to get build info metric") log.Error(context.TODO()).Err(err).Msg("telemetry/metrics: failed to get build info metric")
} }
// This sets our build_info metric to a constant 1 per // This sets our build_info metric to a constant 1 per
@ -103,7 +105,7 @@ func (r *metricRegistry) addPolicyCountCallback(service string, f func() int64)
} }
err := r.policyCount.UpsertEntry(f, metricdata.NewLabelValue(service)) err := r.policyCount.UpsertEntry(f, metricdata.NewLabelValue(service))
if err != nil { if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to get policy count metric") log.Error(context.TODO()).Err(err).Msg("telemetry/metrics: failed to get policy count metric")
} }
} }
@ -113,7 +115,7 @@ func (r *metricRegistry) setConfigChecksum(service string, configName string, ch
} }
m, err := r.configChecksum.GetEntry(metricdata.NewLabelValue(service), metricdata.NewLabelValue(configName)) m, err := r.configChecksum.GetEntry(metricdata.NewLabelValue(service), metricdata.NewLabelValue(configName))
if err != nil { if err != nil {
log.Error().Err(err).Msg("telemetry/metrics: failed to get config checksum metric") log.Error(context.TODO()).Err(err).Msg("telemetry/metrics: failed to get config checksum metric")
} }
m.Set(float64(checksum)) m.Set(float64(checksum))
} }
@ -122,13 +124,13 @@ func (r *metricRegistry) addInt64DerivedGaugeMetric(name, desc, service string,
m, err := r.registry.AddInt64DerivedGauge(name, metric.WithDescription(desc), m, err := r.registry.AddInt64DerivedGauge(name, metric.WithDescription(desc),
metric.WithLabelKeys(metrics.ServiceLabel)) metric.WithLabelKeys(metrics.ServiceLabel))
if err != nil { if err != nil {
log.Error().Err(err).Str("service", service).Msg("telemetry/metrics: failed to register metric") log.Error(context.TODO()).Err(err).Str("service", service).Msg("telemetry/metrics: failed to register metric")
return return
} }
err = m.UpsertEntry(f, metricdata.NewLabelValue(service)) err = m.UpsertEntry(f, metricdata.NewLabelValue(service))
if err != nil { if err != nil {
log.Error().Err(err).Str("service", service).Msg("telemetry/metrics: failed to update metric") log.Error(context.TODO()).Err(err).Str("service", service).Msg("telemetry/metrics: failed to update metric")
return return
} }
} }
@ -137,13 +139,13 @@ func (r *metricRegistry) addInt64DerivedCumulativeMetric(name, desc, service str
m, err := r.registry.AddInt64DerivedCumulative(name, metric.WithDescription(desc), m, err := r.registry.AddInt64DerivedCumulative(name, metric.WithDescription(desc),
metric.WithLabelKeys(metrics.ServiceLabel)) metric.WithLabelKeys(metrics.ServiceLabel))
if err != nil { if err != nil {
log.Error().Err(err).Str("service", service).Msg("telemetry/metrics: failed to register metric") log.Error(context.TODO()).Err(err).Str("service", service).Msg("telemetry/metrics: failed to register metric")
return return
} }
err = m.UpsertEntry(f, metricdata.NewLabelValue(service)) err = m.UpsertEntry(f, metricdata.NewLabelValue(service))
if err != nil { if err != nil {
log.Error().Err(err).Str("service", service).Msg("telemetry/metrics: failed to update metric") log.Error(context.TODO()).Err(err).Str("service", service).Msg("telemetry/metrics: failed to update metric")
return return
} }
} }

View file

@ -57,6 +57,6 @@ func RecordStorageOperation(ctx context.Context, tags *StorageOperationTags, dur
storageOperationDuration.M(duration.Milliseconds()), storageOperationDuration.M(duration.Milliseconds()),
) )
if err != nil { if err != nil {
log.Warn().Err(err).Msg("internal/telemetry/metrics: failed to record") log.Warn(ctx).Err(err).Msg("internal/telemetry/metrics: failed to record")
} }
} }

View file

@ -77,7 +77,7 @@ func RegisterTracing(opts *TracingOptions) (trace.Exporter, error) {
} }
trace.ApplyConfig(trace.Config{DefaultSampler: trace.ProbabilitySampler(opts.SampleRate)}) trace.ApplyConfig(trace.Config{DefaultSampler: trace.ProbabilitySampler(opts.SampleRate)})
log.Debug().Interface("Opts", opts).Msg("telemetry/trace: exporter created") log.Debug(context.TODO()).Interface("Opts", opts).Msg("telemetry/trace: exporter created")
return exporter, nil return exporter, nil
} }

View file

@ -1,6 +1,7 @@
package cryptutil package cryptutil
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
@ -14,9 +15,10 @@ import (
// GetCertPool gets a cert pool for the given CA or CAFile. // GetCertPool gets a cert pool for the given CA or CAFile.
func GetCertPool(ca, caFile string) (*x509.CertPool, error) { func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
ctx := context.TODO()
rootCAs, err := x509.SystemCertPool() rootCAs, err := x509.SystemCertPool()
if err != nil { if err != nil {
log.Error().Err(err).Msg("pkg/cryptutil: failed getting system cert pool making new one") log.Error(ctx).Err(err).Msg("pkg/cryptutil: failed getting system cert pool making new one")
rootCAs = x509.NewCertPool() rootCAs = x509.NewCertPool()
} }
if ca == "" && caFile == "" { if ca == "" && caFile == "" {
@ -38,7 +40,7 @@ func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
if ok := rootCAs.AppendCertsFromPEM(data); !ok { if ok := rootCAs.AppendCertsFromPEM(data); !ok {
return nil, fmt.Errorf("failed to append any PEM-encoded certificates") return nil, fmt.Errorf("failed to append any PEM-encoded certificates")
} }
log.Debug().Msg("pkg/cryptutil: added custom certificate authority") log.Debug(ctx).Msg("pkg/cryptutil: added custom certificate authority")
return rootCAs, nil return rootCAs, nil
} }

View file

@ -60,6 +60,7 @@ type Options struct {
// NewGRPCClientConn returns a new gRPC pomerium service client connection. // NewGRPCClientConn returns a new gRPC pomerium service client connection.
func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) { func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
ctx := context.TODO()
if len(opts.Addrs) == 0 { if len(opts.Addrs) == 0 {
return nil, errors.New("internal/grpc: connection address required") return nil, errors.New("internal/grpc: connection address required")
} }
@ -105,7 +106,7 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
} }
if opts.WithInsecure { if opts.WithInsecure {
log.Info().Str("addr", connAddr).Msg("internal/grpc: grpc with insecure") log.Info(ctx).Str("addr", connAddr).Msg("internal/grpc: grpc with insecure")
dialOptions = append(dialOptions, grpc.WithInsecure()) dialOptions = append(dialOptions, grpc.WithInsecure())
} else { } else {
rootCAs, err := cryptutil.GetCertPool(opts.CA, opts.CAFile) rootCAs, err := cryptutil.GetCertPool(opts.CA, opts.CAFile)
@ -117,7 +118,7 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
// override allowed certificate name string, typically used when doing behind ingress connection // override allowed certificate name string, typically used when doing behind ingress connection
if opts.OverrideCertificateName != "" { if opts.OverrideCertificateName != "" {
log.Debug().Str("cert-override-name", opts.OverrideCertificateName).Msg("internal/grpc: grpc") log.Debug(ctx).Str("cert-override-name", opts.OverrideCertificateName).Msg("internal/grpc: grpc")
err := cert.OverrideServerName(opts.OverrideCertificateName) err := cert.OverrideServerName(opts.OverrideCertificateName)
if err != nil { if err != nil {
return nil, err return nil, err
@ -169,7 +170,7 @@ func GetGRPCClientConn(name string, opts *Options) (*grpc.ClientConn, error) {
err := current.conn.Close() err := current.conn.Close()
if err != nil { if err != nil {
log.Error().Err(err).Msg("grpc: failed to close existing connection") log.Error(context.TODO()).Err(err).Msg("grpc: failed to close existing connection")
} }
} }

View file

@ -101,7 +101,7 @@ func (syncer *Syncer) Run(ctx context.Context) error {
} }
if err != nil { if err != nil {
syncer.log().Error().Err(err).Msg("sync") log.Error(syncer.logCtx(ctx)).Err(err).Msg("sync")
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
@ -112,42 +112,42 @@ func (syncer *Syncer) Run(ctx context.Context) error {
} }
func (syncer *Syncer) init(ctx context.Context) error { func (syncer *Syncer) init(ctx context.Context) error {
syncer.log().Info().Msg("initial sync") log.Info(syncer.logCtx(ctx)).Msg("initial sync")
records, recordVersion, serverVersion, err := InitialSync(ctx, syncer.handler.GetDataBrokerServiceClient(), &SyncLatestRequest{ records, recordVersion, serverVersion, err := InitialSync(syncer.logCtx(ctx), syncer.handler.GetDataBrokerServiceClient(), &SyncLatestRequest{
Type: syncer.cfg.typeURL, Type: syncer.cfg.typeURL,
}) })
if err != nil { if err != nil {
syncer.log().Error().Err(err).Msg("error during initial sync") log.Error(syncer.logCtx(ctx)).Err(err).Msg("error during initial sync")
return err return err
} }
syncer.backoff.Reset() syncer.backoff.Reset()
// reset the records as we have to sync latest // reset the records as we have to sync latest
syncer.handler.ClearRecords(ctx) syncer.handler.ClearRecords(syncer.logCtx(ctx))
syncer.recordVersion = recordVersion syncer.recordVersion = recordVersion
syncer.serverVersion = serverVersion syncer.serverVersion = serverVersion
syncer.handler.UpdateRecords(ctx, serverVersion, records) syncer.handler.UpdateRecords(syncer.logCtx(ctx), serverVersion, records)
return nil return nil
} }
func (syncer *Syncer) sync(ctx context.Context) error { func (syncer *Syncer) sync(ctx context.Context) error {
stream, err := syncer.handler.GetDataBrokerServiceClient().Sync(ctx, &SyncRequest{ stream, err := syncer.handler.GetDataBrokerServiceClient().Sync(syncer.logCtx(ctx), &SyncRequest{
ServerVersion: syncer.serverVersion, ServerVersion: syncer.serverVersion,
RecordVersion: syncer.recordVersion, RecordVersion: syncer.recordVersion,
}) })
if err != nil { if err != nil {
syncer.log().Error().Err(err).Msg("error during sync") log.Error(syncer.logCtx(ctx)).Err(err).Msg("error during sync")
return err return err
} }
syncer.log().Info().Msg("listening for updates") log.Info(syncer.logCtx(ctx)).Msg("listening for updates")
for { for {
res, err := stream.Recv() res, err := stream.Recv()
if status.Code(err) == codes.Aborted { if status.Code(err) == codes.Aborted {
syncer.log().Error().Err(err).Msg("aborted sync due to mismatched server version") log.Error(syncer.logCtx(ctx)).Err(err).Msg("aborted sync due to mismatched server version")
// server version changed, so re-init // server version changed, so re-init
syncer.serverVersion = 0 syncer.serverVersion = 0
return nil return nil
@ -155,14 +155,13 @@ func (syncer *Syncer) sync(ctx context.Context) error {
return err return err
} }
syncer.log().Debug(). log.Debug(syncer.logCtx(ctx)).
Uint("version", uint(res.Record.GetVersion())). Uint("version", uint(res.Record.GetVersion())).
Str("type", res.Record.Type).
Str("id", res.Record.Id). Str("id", res.Record.Id).
Msg("syncer got record") Msg("syncer got record")
if syncer.recordVersion != res.GetRecord().GetVersion()-1 { if syncer.recordVersion != res.GetRecord().GetVersion()-1 {
syncer.log().Error().Err(err). log.Error(syncer.logCtx(ctx)).Err(err).
Uint64("received", res.GetRecord().GetVersion()). Uint64("received", res.GetRecord().GetVersion()).
Msg("aborted sync due to missing record") Msg("aborted sync due to missing record")
syncer.serverVersion = 0 syncer.serverVersion = 0
@ -170,15 +169,17 @@ func (syncer *Syncer) sync(ctx context.Context) error {
} }
syncer.recordVersion = res.GetRecord().GetVersion() syncer.recordVersion = res.GetRecord().GetVersion()
if syncer.cfg.typeURL == "" || syncer.cfg.typeURL == res.GetRecord().GetType() { if syncer.cfg.typeURL == "" || syncer.cfg.typeURL == res.GetRecord().GetType() {
syncer.handler.UpdateRecords(ctx, syncer.serverVersion, []*Record{res.GetRecord()}) syncer.handler.UpdateRecords(syncer.logCtx(ctx), syncer.serverVersion, []*Record{res.GetRecord()})
} }
} }
} }
func (syncer *Syncer) log() *zerolog.Logger { // logCtx adds log params to context which
l := log.With().Str("syncer_id", syncer.id). func (syncer *Syncer) logCtx(ctx context.Context) context.Context {
Str("type", syncer.cfg.typeURL). return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
Uint64("server_version", syncer.serverVersion). return c.Str("syncer_id", syncer.id).
Uint64("record_version", syncer.recordVersion).Logger() Str("type", syncer.cfg.typeURL).
return &l Uint64("server_version", syncer.serverVersion).
Uint64("record_version", syncer.recordVersion)
})
} }

View file

@ -9,9 +9,11 @@ import (
"time" "time"
"github.com/google/btree" "github.com/google/btree"
"github.com/rs/zerolog"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/signal" "github.com/pomerium/pomerium/internal/signal"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
@ -141,14 +143,20 @@ func (backend *Backend) GetAll(_ context.Context) ([]*databroker.Record, uint64,
} }
// Put puts a record into the in-memory store. // Put puts a record into the in-memory store.
func (backend *Backend) Put(_ context.Context, record *databroker.Record) error { func (backend *Backend) Put(ctx context.Context, record *databroker.Record) error {
if record == nil { if record == nil {
return fmt.Errorf("records cannot be nil") return fmt.Errorf("records cannot be nil")
} }
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("db_op", "put").
Str("db_id", record.Id).
Str("db_type", record.Type)
})
backend.mu.Lock() backend.mu.Lock()
defer backend.mu.Unlock() defer backend.mu.Unlock()
defer backend.onChange.Broadcast() defer backend.onChange.Broadcast(ctx)
record.ModifiedAt = timestamppb.Now() record.ModifiedAt = timestamppb.Now()
record.Version = backend.nextVersion() record.Version = backend.nextVersion()

View file

@ -12,7 +12,7 @@ type recordStream struct {
ctx context.Context ctx context.Context
backend *Backend backend *Backend
changed chan struct{} changed chan context.Context
ready []*databroker.Record ready []*databroker.Record
version uint64 version uint64

View file

@ -15,7 +15,7 @@ type logger struct {
} }
func (l logger) Printf(ctx context.Context, format string, v ...interface{}) { func (l logger) Printf(ctx context.Context, format string, v ...interface{}) {
log.Info().Str("service", "redis").Msgf(format, v...) log.Info(ctx).Str("service", "redis").Msgf(format, v...)
} }
func init() { func init() {

View file

@ -51,6 +51,7 @@ type Backend struct {
// New creates a new redis storage backend. // New creates a new redis storage backend.
func New(rawURL string, options ...Option) (*Backend, error) { func New(rawURL string, options ...Option) (*Backend, error) {
ctx := context.TODO()
cfg := getConfig(options...) cfg := getConfig(options...)
backend := &Backend{ backend := &Backend{
cfg: cfg, cfg: cfg,
@ -63,7 +64,7 @@ func New(rawURL string, options ...Option) (*Backend, error) {
return nil, err return nil, err
} }
metrics.AddRedisMetrics(backend.client.PoolStats) metrics.AddRedisMetrics(backend.client.PoolStats)
go backend.listenForVersionChanges() go backend.listenForVersionChanges(ctx)
if cfg.expiry != 0 { if cfg.expiry != 0 {
go func() { go func() {
ticker := time.NewTicker(time.Minute) ticker := time.NewTicker(time.Minute)
@ -75,7 +76,7 @@ func New(rawURL string, options ...Option) (*Backend, error) {
case <-ticker.C: case <-ticker.C:
} }
backend.removeChangesBefore(time.Now().Add(-cfg.expiry)) backend.removeChangesBefore(ctx, time.Now().Add(-cfg.expiry))
} }
}() }()
} }
@ -146,7 +147,7 @@ func (backend *Backend) GetAll(ctx context.Context) (records []*databroker.Recor
var record databroker.Record var record databroker.Record
err := proto.Unmarshal([]byte(result), &record) err := proto.Unmarshal([]byte(result), &record)
if err != nil { if err != nil {
log.Warn().Err(err).Msg("redis: invalid record detected") log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
continue continue
} }
records = append(records, &record) records = append(records, &record)
@ -246,8 +247,8 @@ func (backend *Backend) incrementVersion(ctx context.Context,
return ErrExceededMaxRetries return ErrExceededMaxRetries
} }
func (backend *Backend) listenForVersionChanges() { func (backend *Backend) listenForVersionChanges(ctx context.Context) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(ctx)
go func() { go func() {
<-backend.closed <-backend.closed
cancel() cancel()
@ -274,14 +275,13 @@ outer:
switch msg.(type) { switch msg.(type) {
case *redis.Message: case *redis.Message:
backend.onChange.Broadcast() backend.onChange.Broadcast(ctx)
} }
} }
} }
} }
func (backend *Backend) removeChangesBefore(cutoff time.Time) { func (backend *Backend) removeChangesBefore(ctx context.Context, cutoff time.Time) {
ctx := context.Background()
for { for {
cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{ cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{
Min: "-inf", Min: "-inf",
@ -291,7 +291,7 @@ func (backend *Backend) removeChangesBefore(cutoff time.Time) {
}) })
results, err := cmd.Result() results, err := cmd.Result()
if err != nil { if err != nil {
log.Error().Err(err).Msg("redis: error retrieving changes for expiration") log.Error(ctx).Err(err).Msg("redis: error retrieving changes for expiration")
return return
} }
@ -303,7 +303,7 @@ func (backend *Backend) removeChangesBefore(cutoff time.Time) {
var record databroker.Record var record databroker.Record
err = proto.Unmarshal([]byte(results[0]), &record) err = proto.Unmarshal([]byte(results[0]), &record)
if err != nil { if err != nil {
log.Warn().Err(err).Msg("redis: invalid record detected") log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
record.ModifiedAt = timestamppb.New(cutoff.Add(-time.Second)) // set the modified so will delete it record.ModifiedAt = timestamppb.New(cutoff.Add(-time.Second)) // set the modified so will delete it
} }
@ -315,7 +315,7 @@ func (backend *Backend) removeChangesBefore(cutoff time.Time) {
// remove the record // remove the record
err = backend.client.ZRem(ctx, changesSetKey, results[0]).Err() err = backend.client.ZRem(ctx, changesSetKey, results[0]).Err()
if err != nil { if err != nil {
log.Error().Err(err).Msg("redis: error removing member") log.Error(ctx).Err(err).Msg("redis: error removing member")
return return
} }
} }

View file

@ -181,7 +181,7 @@ func TestExpiry(t *testing.T) {
_ = stream.Close() _ = stream.Close()
require.Len(t, records, 1000) require.Len(t, records, 1000)
backend.removeChangesBefore(time.Now().Add(time.Second)) backend.removeChangesBefore(ctx, time.Now().Add(time.Second))
stream, err = backend.Sync(ctx, 0) stream, err = backend.Sync(ctx, 0)
require.NoError(t, err) require.NoError(t, err)

View file

@ -17,7 +17,7 @@ type recordStream struct {
ctx context.Context ctx context.Context
backend *Backend backend *Backend
changed chan struct{} changed chan context.Context
version uint64 version uint64
record *databroker.Record record *databroker.Record
err error err error
@ -63,6 +63,7 @@ func (stream *recordStream) Next(block bool) bool {
ticker := time.NewTicker(watchPollInterval) ticker := time.NewTicker(watchPollInterval)
defer ticker.Stop() defer ticker.Stop()
changeCtx := context.Background()
for { for {
cmd := stream.backend.client.ZRangeByScore(stream.ctx, changesSetKey, &redis.ZRangeBy{ cmd := stream.backend.client.ZRangeByScore(stream.ctx, changesSetKey, &redis.ZRangeBy{
Min: fmt.Sprintf("(%d", stream.version), Min: fmt.Sprintf("(%d", stream.version),
@ -81,7 +82,7 @@ func (stream *recordStream) Next(block bool) bool {
var record databroker.Record var record databroker.Record
err = proto.Unmarshal([]byte(result), &record) err = proto.Unmarshal([]byte(result), &record)
if err != nil { if err != nil {
log.Warn().Err(err).Msg("redis: invalid record detected") log.Warn(changeCtx).Err(err).Msg("redis: invalid record detected")
} else { } else {
stream.record = &record stream.record = &record
} }
@ -97,7 +98,7 @@ func (stream *recordStream) Next(block bool) bool {
case <-stream.closed: case <-stream.closed:
return false return false
case <-ticker.C: // check again case <-ticker.C: // check again
case <-stream.changed: // check again case changeCtx = <-stream.changed: // check again
} }
} else { } else {
return false return false

View file

@ -57,7 +57,7 @@ func MatchAny(any *anypb.Any, query string) bool {
msg, err := any.UnmarshalNew() msg, err := any.UnmarshalNew()
if err != nil { if err != nil {
// ignore invalid any types // ignore invalid any types
log.Error().Err(err).Msg("storage: invalid any type") log.Error(context.TODO()).Err(err).Msg("storage: invalid any type")
return false return false
} }

View file

@ -35,6 +35,7 @@ func (m *mockCheckClient) Check(ctx context.Context, in *envoy_service_auth_v2.C
} }
func TestProxy_ForwardAuth(t *testing.T) { func TestProxy_ForwardAuth(t *testing.T) {
ctx := context.Background()
t.Parallel() t.Parallel()
allowClient := &mockCheckClient{ allowClient := &mockCheckClient{
@ -85,7 +86,7 @@ func TestProxy_ForwardAuth(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
p.OnConfigChange(&config.Config{Options: tt.options}) p.OnConfigChange(ctx, &config.Config{Options: tt.options})
state := p.state.Load() state := p.state.Load()
state.sessionStore = tt.sessionStore state.sessionStore = tt.sessionStore
signer, err := jws.NewHS256Signer(nil) signer, err := jws.NewHS256Signer(nil)

View file

@ -2,6 +2,7 @@ package proxy
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -240,7 +241,7 @@ func TestProxy_Callback(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
p.OnConfigChange(&config.Config{Options: tt.options}) p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
state := p.state.Load() state := p.state.Load()
state.encoder = tt.cipher state.encoder = tt.cipher
state.sessionStore = tt.sessionStore state.sessionStore = tt.sessionStore
@ -486,7 +487,7 @@ func TestProxy_ProgrammaticCallback(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
p.OnConfigChange(&config.Config{Options: tt.options}) p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
state := p.state.Load() state := p.state.Load()
state.encoder = tt.cipher state.encoder = tt.cipher
state.sessionStore = tt.sessionStore state.sessionStore = tt.sessionStore

View file

@ -5,6 +5,7 @@
package proxy package proxy
import ( import (
"context"
"fmt" "fmt"
"html/template" "html/template"
"net/http" "net/http"
@ -76,17 +77,17 @@ func New(cfg *config.Config) (*Proxy, error) {
} }
// OnConfigChange updates internal structures based on config.Options // OnConfigChange updates internal structures based on config.Options
func (p *Proxy) OnConfigChange(cfg *config.Config) { func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
if p == nil { if p == nil {
return return
} }
p.currentOptions.Store(cfg.Options) p.currentOptions.Store(cfg.Options)
if err := p.setHandlers(cfg.Options); err != nil { if err := p.setHandlers(cfg.Options); err != nil {
log.Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings") log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
} }
if state, err := newProxyStateFromConfig(cfg); err != nil { if state, err := newProxyStateFromConfig(cfg); err != nil {
log.Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings") log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy state from configuration settings")
} else { } else {
p.state.Store(state) p.state.Store(state)
} }
@ -94,7 +95,7 @@ func (p *Proxy) OnConfigChange(cfg *config.Config) {
func (p *Proxy) setHandlers(opts *config.Options) error { func (p *Proxy) setHandlers(opts *config.Options) error {
if len(opts.GetAllPolicies()) == 0 { if len(opts.GetAllPolicies()) == 0 {
log.Warn().Msg("proxy: configuration has no policies") log.Warn(context.TODO()).Msg("proxy: configuration has no policies")
} }
r := httputil.NewRouter() r := httputil.NewRouter()
r.NotFoundHandler = httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { r.NotFoundHandler = httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {

View file

@ -1,6 +1,7 @@
package proxy package proxy
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -197,7 +198,7 @@ func Test_UpdateOptions(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
p.OnConfigChange(&config.Config{Options: tt.updatedOptions}) p.OnConfigChange(context.Background(), &config.Config{Options: tt.updatedOptions})
r := httptest.NewRequest("GET", tt.host, nil) r := httptest.NewRequest("GET", tt.host, nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
p.ServeHTTP(w, r) p.ServeHTTP(w, r)
@ -210,5 +211,5 @@ func Test_UpdateOptions(t *testing.T) {
// Test nil // Test nil
var p *Proxy var p *Proxy
p.OnConfigChange(&config.Config{}) p.OnConfigChange(context.Background(), &config.Config{})
} }