mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-20 04:27:19 +02:00
log context (#2107)
This commit is contained in:
parent
e7995954ff
commit
e0c09a0998
87 changed files with 714 additions and 524 deletions
|
@ -3,6 +3,7 @@
|
|||
package authenticate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
|
@ -76,19 +77,19 @@ func New(cfg *config.Config) (*Authenticate, error) {
|
|||
}
|
||||
|
||||
// OnConfigChange updates internal structures based on config.Options
|
||||
func (a *Authenticate) OnConfigChange(cfg *config.Config) {
|
||||
func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||
if a == nil {
|
||||
return
|
||||
}
|
||||
|
||||
a.options.Store(cfg.Options)
|
||||
if state, err := newAuthenticateStateFromConfig(cfg); err != nil {
|
||||
log.Error().Err(err).Msg("authenticate: failed to update state")
|
||||
log.Error(ctx).Err(err).Msg("authenticate: failed to update state")
|
||||
} else {
|
||||
a.state.Store(state)
|
||||
}
|
||||
if err := a.updateProvider(cfg); err != nil {
|
||||
log.Error().Err(err).Msg("authenticate: failed to update identity provider")
|
||||
log.Error(ctx).Err(err).Msg("authenticate: failed to update identity provider")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -275,7 +275,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
|
|||
endSessionURL.RawQuery = params.Encode()
|
||||
redirectString = endSessionURL.String()
|
||||
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
|
||||
log.Warn().Err(err).Msg("authenticate.SignOut: failed getting session")
|
||||
log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session")
|
||||
}
|
||||
if redirectString != "" {
|
||||
httputil.Redirect(w, r, redirectString, http.StatusFound)
|
||||
|
@ -558,7 +558,7 @@ func (a *Authenticate) saveSessionToDataBroker(
|
|||
AccessToken: accessToken.AccessToken,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("directory: failed to refresh user data")
|
||||
log.Error(ctx).Err(err).Msg("directory: failed to refresh user data")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -56,7 +56,7 @@ func (a *Authorize) WaitForInitialSync(ctx context.Context) error {
|
|||
return ctx.Err()
|
||||
case <-a.dataBrokerInitialSync:
|
||||
}
|
||||
log.Info().Msg("initial sync from databroker complete")
|
||||
log.Info(ctx).Msg("initial sync from databroker complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -82,10 +82,10 @@ func newPolicyEvaluator(opts *config.Options, store *evaluator.Store) (*evaluato
|
|||
}
|
||||
|
||||
// OnConfigChange updates internal structures based on config.Options
|
||||
func (a *Authorize) OnConfigChange(cfg *config.Config) {
|
||||
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||
a.currentOptions.Store(cfg.Options)
|
||||
if state, err := newAuthorizeStateFromConfig(cfg, a.store); err != nil {
|
||||
log.Error().Err(err).Msg("authorize: error updating state")
|
||||
log.Error(ctx).Err(err).Msg("authorize: error updating state")
|
||||
} else {
|
||||
a.state.Store(state)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package authorize
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -116,7 +117,7 @@ func TestAuthorize_OnConfigChange(t *testing.T) {
|
|||
o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUhHNHZDWlJxUFgwNGtmSFQxeVVDM1pUQkF6MFRYWkNtZ043clpDcFE3cHJvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFbzQzdjAwQlR4c3pKZWpmdHhBOWNtVGVUSmtQQXVtOGt1b0UwVlRUZnlId2k3SHJlN2FRUgpHQVJ6Nm0wMjVRdGFiRGxqeDd5MjIyY1gxblhCQXo3MlF3PT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
|
||||
assertFunc = assert.False
|
||||
}
|
||||
a.OnConfigChange(cfg)
|
||||
a.OnConfigChange(context.Background(), cfg)
|
||||
assertFunc(t, oldPe == a.state.Load().evaluator)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package authorize
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
|
@ -103,7 +104,7 @@ func (a *Authorize) htmlDeniedResponse(
|
|||
})
|
||||
if err != nil {
|
||||
buf.WriteString(reason)
|
||||
log.Error().Err(err).Msg("error executing error template")
|
||||
log.Error(context.TODO()).Err(err).Msg("error executing error template")
|
||||
}
|
||||
|
||||
envoyHeaders := []*envoy_config_core_v3.HeaderValueOption{
|
||||
|
|
|
@ -162,7 +162,7 @@ func getJWK(options *config.Options) (*jose.JSONWebKey, error) {
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't generate signing key: %w", err)
|
||||
}
|
||||
log.Info().Str("Algorithm", jwk.Algorithm).
|
||||
log.Info(context.TODO()).Str("Algorithm", jwk.Algorithm).
|
||||
Str("KeyID", jwk.KeyID).
|
||||
Interface("Public Key", jwk.Public()).
|
||||
Msg("authorize: signing key")
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package evaluator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
|
@ -46,7 +47,7 @@ func isValidClientCertificate(ca, cert string) (bool, error) {
|
|||
valid := verifyErr == nil
|
||||
|
||||
if verifyErr != nil {
|
||||
log.Debug().Err(verifyErr).Msg("client certificate failed verification: %w")
|
||||
log.Debug(context.Background()).Err(verifyErr).Msg("client certificate failed verification: %w")
|
||||
}
|
||||
|
||||
isValidClientCertificateCache.Add(cacheKey, valid)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package evaluator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
|
@ -71,7 +72,7 @@ func getDenyVar(vars rego.Vars) []Result {
|
|||
|
||||
status, err := strconv.Atoi(fmt.Sprint(denial[0]))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("invalid type in deny")
|
||||
log.Error(context.TODO()).Err(err).Msg("invalid type in deny")
|
||||
continue
|
||||
}
|
||||
msg := fmt.Sprint(denial[1])
|
||||
|
|
|
@ -156,11 +156,12 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
|
|||
}
|
||||
|
||||
func (s *Store) write(rawPath string, value interface{}) {
|
||||
err := storage.Txn(context.Background(), s.Store, storage.WriteParams, func(txn storage.Transaction) error {
|
||||
ctx := context.TODO()
|
||||
err := storage.Txn(ctx, s.Store, storage.WriteParams, func(txn storage.Transaction) error {
|
||||
return s.writeTxn(txn, rawPath, value)
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("opa-store: error writing data")
|
||||
log.Error(ctx).Err(err).Msg("opa-store: error writing data")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,19 +46,19 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
|
||||
u, err := a.forceSync(ctx, sessionState)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("clearing session due to force sync failed")
|
||||
log.Warn(ctx).Err(err).Msg("clearing session due to force sync failed")
|
||||
sessionState = nil
|
||||
}
|
||||
|
||||
req, err := a.getEvaluatorRequestFromCheckRequest(in, sessionState)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("error building evaluator request")
|
||||
log.Warn(ctx).Err(err).Msg("error building evaluator request")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reply, err := state.evaluator.Evaluate(ctx, req)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error during OPA evaluation")
|
||||
log.Error(ctx).Err(err).Msg("error during OPA evaluation")
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
|
|
|
@ -26,7 +26,7 @@ func (a *Authorize) logAuthorizeCheck(
|
|||
|
||||
hdrs := getCheckRequestHeaders(in)
|
||||
hattrs := in.GetAttributes().GetRequest().GetHttp()
|
||||
evt := log.Info().Str("service", "authorize")
|
||||
evt := log.Info(ctx).Str("service", "authorize")
|
||||
// request
|
||||
evt = evt.Str("request-id", requestid.FromContext(ctx))
|
||||
evt = evt.Str("check-request-id", hdrs["X-Request-Id"])
|
||||
|
@ -66,10 +66,10 @@ func (a *Authorize) logAuthorizeCheck(
|
|||
}
|
||||
sealed, err := enc.Encrypt(record)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("authorize: error encrypting audit record")
|
||||
log.Warn(ctx).Err(err).Msg("authorize: error encrypting audit record")
|
||||
return
|
||||
}
|
||||
log.Info().
|
||||
log.Info(ctx).
|
||||
Str("request-id", requestid.FromContext(ctx)).
|
||||
EmbedObject(sealed).
|
||||
Msg("audit log")
|
||||
|
|
|
@ -150,7 +150,7 @@ func (a *Authorize) waitForRecordSync(ctx context.Context, recordTypeURL, record
|
|||
// record not found, so no need to wait
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
log.Error().
|
||||
log.Error(ctx).
|
||||
Err(err).
|
||||
Str("type", recordTypeURL).
|
||||
Str("id", recordID).
|
||||
|
@ -160,7 +160,7 @@ func (a *Authorize) waitForRecordSync(ctx context.Context, recordTypeURL, record
|
|||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Warn().
|
||||
log.Warn(ctx).
|
||||
Str("type", recordTypeURL).
|
||||
Str("id", recordID).
|
||||
Msg("authorize: first sync of record did not complete")
|
||||
|
|
|
@ -17,10 +17,11 @@ var (
|
|||
)
|
||||
|
||||
func main() {
|
||||
if err := run(context.Background()); !errors.Is(err, context.Canceled) {
|
||||
ctx := context.Background()
|
||||
if err := run(ctx); !errors.Is(err, context.Canceled) {
|
||||
log.Fatal().Err(err).Msg("cmd/pomerium")
|
||||
}
|
||||
log.Info().Msg("cmd/pomerium: exiting")
|
||||
log.Info(ctx).Msg("cmd/pomerium: exiting")
|
||||
}
|
||||
|
||||
func run(ctx context.Context) error {
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"io/ioutil"
|
||||
"sync"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/fileutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
@ -13,7 +16,7 @@ import (
|
|||
)
|
||||
|
||||
// A ChangeListener is called when configuration changes.
|
||||
type ChangeListener = func(*Config)
|
||||
type ChangeListener = func(context.Context, *Config)
|
||||
|
||||
// A ChangeDispatcher manages listeners on config changes.
|
||||
type ChangeDispatcher struct {
|
||||
|
@ -22,17 +25,17 @@ type ChangeDispatcher struct {
|
|||
}
|
||||
|
||||
// Trigger triggers a change.
|
||||
func (dispatcher *ChangeDispatcher) Trigger(cfg *Config) {
|
||||
func (dispatcher *ChangeDispatcher) Trigger(ctx context.Context, cfg *Config) {
|
||||
dispatcher.Lock()
|
||||
defer dispatcher.Unlock()
|
||||
|
||||
for _, li := range dispatcher.onConfigChangeListeners {
|
||||
li(cfg)
|
||||
li(ctx, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// OnConfigChange adds a listener.
|
||||
func (dispatcher *ChangeDispatcher) OnConfigChange(li ChangeListener) {
|
||||
func (dispatcher *ChangeDispatcher) OnConfigChange(ctx context.Context, li ChangeListener) {
|
||||
dispatcher.Lock()
|
||||
defer dispatcher.Unlock()
|
||||
dispatcher.onConfigChangeListeners = append(dispatcher.onConfigChangeListeners, li)
|
||||
|
@ -41,7 +44,7 @@ func (dispatcher *ChangeDispatcher) OnConfigChange(li ChangeListener) {
|
|||
// A Source gets configuration.
|
||||
type Source interface {
|
||||
GetConfig() *Config
|
||||
OnConfigChange(ChangeListener)
|
||||
OnConfigChange(context.Context, ChangeListener)
|
||||
}
|
||||
|
||||
// A StaticSource always returns the same config. Useful for testing.
|
||||
|
@ -65,18 +68,18 @@ func (src *StaticSource) GetConfig() *Config {
|
|||
}
|
||||
|
||||
// SetConfig sets the config.
|
||||
func (src *StaticSource) SetConfig(cfg *Config) {
|
||||
func (src *StaticSource) SetConfig(ctx context.Context, cfg *Config) {
|
||||
src.mu.Lock()
|
||||
defer src.mu.Unlock()
|
||||
|
||||
src.cfg = cfg
|
||||
for _, li := range src.lis {
|
||||
li(cfg)
|
||||
li(ctx, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// OnConfigChange is ignored for the StaticSource.
|
||||
func (src *StaticSource) OnConfigChange(li ChangeListener) {
|
||||
func (src *StaticSource) OnConfigChange(ctx context.Context, li ChangeListener) {
|
||||
src.mu.Lock()
|
||||
defer src.mu.Unlock()
|
||||
|
||||
|
@ -95,38 +98,48 @@ type FileOrEnvironmentSource struct {
|
|||
|
||||
// NewFileOrEnvironmentSource creates a new FileOrEnvironmentSource.
|
||||
func NewFileOrEnvironmentSource(configFile string) (*FileOrEnvironmentSource, error) {
|
||||
ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("config_file_source", configFile)
|
||||
})
|
||||
|
||||
options, err := newOptionsFromConfig(configFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := &Config{Options: options}
|
||||
metrics.SetConfigInfo(cfg.Options.Services, "local", cfg.Checksum(), true)
|
||||
metrics.SetConfigInfo(ctx, cfg.Options.Services, "local", cfg.Checksum(), true)
|
||||
|
||||
src := &FileOrEnvironmentSource{
|
||||
configFile: configFile,
|
||||
config: cfg,
|
||||
}
|
||||
options.viper.OnConfigChange(src.onConfigChange)
|
||||
options.viper.OnConfigChange(src.onConfigChange(ctx))
|
||||
go options.viper.WatchConfig()
|
||||
|
||||
return src, nil
|
||||
}
|
||||
|
||||
func (src *FileOrEnvironmentSource) onConfigChange(evt fsnotify.Event) {
|
||||
src.mu.Lock()
|
||||
cfg := src.config
|
||||
options, err := newOptionsFromConfig(src.configFile)
|
||||
if err == nil {
|
||||
cfg = &Config{Options: options}
|
||||
metrics.SetConfigInfo(cfg.Options.Services, "local", cfg.Checksum(), true)
|
||||
} else {
|
||||
log.Error().Err(err).Msg("config: error updating config")
|
||||
metrics.SetConfigInfo(cfg.Options.Services, "local", cfg.Checksum(), false)
|
||||
}
|
||||
src.mu.Unlock()
|
||||
func (src *FileOrEnvironmentSource) onConfigChange(ctx context.Context) func(fsnotify.Event) {
|
||||
return func(evt fsnotify.Event) {
|
||||
ctx := log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("config_change_id", uuid.New().String())
|
||||
})
|
||||
log.Info(ctx).Msg("config: file updated, reconfiguring...")
|
||||
src.mu.Lock()
|
||||
cfg := src.config
|
||||
options, err := newOptionsFromConfig(src.configFile)
|
||||
if err == nil {
|
||||
cfg = &Config{Options: options}
|
||||
metrics.SetConfigInfo(ctx, cfg.Options.Services, "local", cfg.Checksum(), true)
|
||||
} else {
|
||||
log.Error(ctx).Err(err).Msg("config: error updating config")
|
||||
metrics.SetConfigInfo(ctx, cfg.Options.Services, "local", cfg.Checksum(), false)
|
||||
}
|
||||
src.mu.Unlock()
|
||||
|
||||
src.Trigger(cfg)
|
||||
src.Trigger(ctx, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfig gets the config.
|
||||
|
@ -148,7 +161,7 @@ type FileWatcherSource struct {
|
|||
ChangeDispatcher
|
||||
}
|
||||
|
||||
// NewFileWatcherSource creates a new FileWatcherSource.
|
||||
// NewFileWatcherSource creates a new FileWatcherSource
|
||||
func NewFileWatcherSource(underlying Source) *FileWatcherSource {
|
||||
src := &FileWatcherSource{
|
||||
underlying: underlying,
|
||||
|
@ -158,13 +171,13 @@ func NewFileWatcherSource(underlying Source) *FileWatcherSource {
|
|||
ch := src.watcher.Bind()
|
||||
go func() {
|
||||
for range ch {
|
||||
src.check(underlying.GetConfig())
|
||||
src.check(context.TODO(), underlying.GetConfig())
|
||||
}
|
||||
}()
|
||||
underlying.OnConfigChange(func(cfg *Config) {
|
||||
src.check(cfg)
|
||||
underlying.OnConfigChange(context.TODO(), func(ctx context.Context, cfg *Config) {
|
||||
src.check(ctx, cfg)
|
||||
})
|
||||
src.check(underlying.GetConfig())
|
||||
src.check(context.TODO(), underlying.GetConfig())
|
||||
|
||||
return src
|
||||
}
|
||||
|
@ -176,7 +189,7 @@ func (src *FileWatcherSource) GetConfig() *Config {
|
|||
return src.computedConfig
|
||||
}
|
||||
|
||||
func (src *FileWatcherSource) check(cfg *Config) {
|
||||
func (src *FileWatcherSource) check(ctx context.Context, cfg *Config) {
|
||||
if cfg == nil || cfg.Options == nil {
|
||||
return
|
||||
}
|
||||
|
@ -218,5 +231,5 @@ func (src *FileWatcherSource) check(cfg *Config) {
|
|||
src.computedConfig = cfg.Clone()
|
||||
|
||||
// trigger a change
|
||||
src.Trigger(src.computedConfig)
|
||||
src.Trigger(ctx, src.computedConfig)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
@ -13,6 +14,8 @@ import (
|
|||
)
|
||||
|
||||
func TestFileWatcherSource(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tmpdir := filepath.Join(os.TempDir(), uuid.New().String())
|
||||
err := os.MkdirAll(tmpdir, 0o755)
|
||||
if !assert.NoError(t, err) {
|
||||
|
@ -33,7 +36,7 @@ func TestFileWatcherSource(t *testing.T) {
|
|||
src := NewFileWatcherSource(ssrc)
|
||||
var closeOnce sync.Once
|
||||
ch := make(chan struct{})
|
||||
src.OnConfigChange(func(cfg *Config) {
|
||||
src.OnConfigChange(context.Background(), func(ctx context.Context, cfg *Config) {
|
||||
closeOnce.Do(func() {
|
||||
close(ch)
|
||||
})
|
||||
|
@ -50,7 +53,7 @@ func TestFileWatcherSource(t *testing.T) {
|
|||
t.Error("expected OnConfigChange to be fired after modifying a file")
|
||||
}
|
||||
|
||||
ssrc.SetConfig(&Config{
|
||||
ssrc.SetConfig(ctx, &Config{
|
||||
Options: &Options{
|
||||
CAFile: filepath.Join(tmpdir, "example.txt"),
|
||||
},
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package envoyconfig
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -22,7 +23,7 @@ import (
|
|||
)
|
||||
|
||||
// BuildClusters builds envoy clusters from the given config.
|
||||
func (b *Builder) BuildClusters(cfg *config.Config) ([]*envoy_config_cluster_v3.Cluster, error) {
|
||||
func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*envoy_config_cluster_v3.Cluster, error) {
|
||||
grpcURL := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: b.localGRPCAddress,
|
||||
|
@ -36,15 +37,15 @@ func (b *Builder) BuildClusters(cfg *config.Config) ([]*envoy_config_cluster_v3.
|
|||
return nil, err
|
||||
}
|
||||
|
||||
controlGRPC, err := b.buildInternalCluster(cfg.Options, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, true)
|
||||
controlGRPC, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
controlHTTP, err := b.buildInternalCluster(cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, false)
|
||||
controlHTTP, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authZ, err := b.buildInternalCluster(cfg.Options, "pomerium-authorize", authzURLs, true)
|
||||
authZ, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-authorize", authzURLs, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -62,7 +63,7 @@ func (b *Builder) BuildClusters(cfg *config.Config) ([]*envoy_config_cluster_v3.
|
|||
policy.EnvoyOpts = newDefaultEnvoyClusterConfig()
|
||||
}
|
||||
if len(policy.To) > 0 {
|
||||
cluster, err := b.buildPolicyCluster(cfg.Options, &policy)
|
||||
cluster, err := b.buildPolicyCluster(ctx, cfg.Options, &policy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("policy #%d: %w", i, err)
|
||||
}
|
||||
|
@ -78,12 +79,18 @@ func (b *Builder) BuildClusters(cfg *config.Config) ([]*envoy_config_cluster_v3.
|
|||
return clusters, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildInternalCluster(options *config.Options, name string, dsts []*url.URL, forceHTTP2 bool) (*envoy_config_cluster_v3.Cluster, error) {
|
||||
func (b *Builder) buildInternalCluster(
|
||||
ctx context.Context,
|
||||
options *config.Options,
|
||||
name string,
|
||||
dsts []*url.URL,
|
||||
forceHTTP2 bool,
|
||||
) (*envoy_config_cluster_v3.Cluster, error) {
|
||||
cluster := newDefaultEnvoyClusterConfig()
|
||||
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
|
||||
var endpoints []Endpoint
|
||||
for _, dst := range dsts {
|
||||
ts, err := b.buildInternalTransportSocket(options, dst)
|
||||
ts, err := b.buildInternalTransportSocket(ctx, options, dst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -95,14 +102,14 @@ func (b *Builder) buildInternalCluster(options *config.Options, name string, dst
|
|||
return cluster, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildPolicyCluster(options *config.Options, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) {
|
||||
func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Options, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) {
|
||||
cluster := new(envoy_config_cluster_v3.Cluster)
|
||||
proto.Merge(cluster, policy.EnvoyOpts)
|
||||
|
||||
cluster.AltStatName = getClusterStatsName(policy)
|
||||
|
||||
name := getClusterID(policy)
|
||||
endpoints, err := b.buildPolicyEndpoints(policy)
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, policy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -122,10 +129,10 @@ func (b *Builder) buildPolicyCluster(options *config.Options, policy *config.Pol
|
|||
return cluster, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildPolicyEndpoints(policy *config.Policy) ([]Endpoint, error) {
|
||||
func (b *Builder) buildPolicyEndpoints(ctx context.Context, policy *config.Policy) ([]Endpoint, error) {
|
||||
var endpoints []Endpoint
|
||||
for _, dst := range policy.To {
|
||||
ts, err := b.buildPolicyTransportSocket(policy, dst.URL)
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, policy, dst.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -134,7 +141,7 @@ func (b *Builder) buildPolicyEndpoints(policy *config.Policy) ([]Endpoint, error
|
|||
return endpoints, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildInternalTransportSocket(options *config.Options, endpoint *url.URL) (*envoy_config_core_v3.TransportSocket, error) {
|
||||
func (b *Builder) buildInternalTransportSocket(ctx context.Context, options *config.Options, endpoint *url.URL) (*envoy_config_core_v3.TransportSocket, error) {
|
||||
if endpoint.Scheme != "https" {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -154,13 +161,13 @@ func (b *Builder) buildInternalTransportSocket(options *config.Options, endpoint
|
|||
} else if options.CA != "" {
|
||||
bs, err := base64.StdEncoding.DecodeString(options.CA)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("invalid custom CA certificate")
|
||||
log.Error(ctx).Err(err).Msg("invalid custom CA certificate")
|
||||
}
|
||||
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
|
||||
} else {
|
||||
rootCA, err := getRootCertificateAuthority()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||
} else {
|
||||
validationContext.TrustedCa = b.filemgr.FileDataSource(rootCA)
|
||||
}
|
||||
|
@ -183,12 +190,12 @@ func (b *Builder) buildInternalTransportSocket(options *config.Options, endpoint
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildPolicyTransportSocket(policy *config.Policy, dst url.URL) (*envoy_config_core_v3.TransportSocket, error) {
|
||||
func (b *Builder) buildPolicyTransportSocket(ctx context.Context, policy *config.Policy, dst url.URL) (*envoy_config_core_v3.TransportSocket, error) {
|
||||
if dst.Scheme != "https" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
vc, err := b.buildPolicyValidationContext(policy, dst)
|
||||
vc, err := b.buildPolicyValidationContext(ctx, policy, dst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -216,7 +223,7 @@ func (b *Builder) buildPolicyTransportSocket(policy *config.Policy, dst url.URL)
|
|||
}
|
||||
if policy.ClientCertificate != nil {
|
||||
tlsContext.CommonTlsContext.TlsCertificates = append(tlsContext.CommonTlsContext.TlsCertificates,
|
||||
b.envoyTLSCertificateFromGoTLSCertificate(policy.ClientCertificate))
|
||||
b.envoyTLSCertificateFromGoTLSCertificate(ctx, policy.ClientCertificate))
|
||||
}
|
||||
|
||||
tlsConfig := marshalAny(tlsContext)
|
||||
|
@ -229,6 +236,7 @@ func (b *Builder) buildPolicyTransportSocket(policy *config.Policy, dst url.URL)
|
|||
}
|
||||
|
||||
func (b *Builder) buildPolicyValidationContext(
|
||||
ctx context.Context,
|
||||
policy *config.Policy, dst url.URL,
|
||||
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
|
||||
sni := dst.Hostname()
|
||||
|
@ -247,13 +255,13 @@ func (b *Builder) buildPolicyValidationContext(
|
|||
} else if policy.TLSCustomCA != "" {
|
||||
bs, err := base64.StdEncoding.DecodeString(policy.TLSCustomCA)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("invalid custom CA certificate")
|
||||
log.Error(ctx).Err(err).Msg("invalid custom CA certificate")
|
||||
}
|
||||
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
|
||||
} else {
|
||||
rootCA, err := getRootCertificateAuthority()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||
} else {
|
||||
validationContext.TrustedCa = b.filemgr.FileDataSource(rootCA)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package envoyconfig
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
@ -18,6 +19,7 @@ import (
|
|||
)
|
||||
|
||||
func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cacheDir, _ := os.UserCacheDir()
|
||||
customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-32484c314b584447463735303142374c31414145374650305a525539554938594d524855353757313942494d473847535231.pem")
|
||||
|
||||
|
@ -26,14 +28,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
rootCA := b.filemgr.FileDataSource(rootCAPath).GetFilename()
|
||||
|
||||
t.Run("insecure", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(&config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://example.com"),
|
||||
}, *mustParseURL(t, "http://example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, ts)
|
||||
})
|
||||
t.Run("host as sni", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(&config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
require.NoError(t, err)
|
||||
|
@ -67,7 +69,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
`, ts)
|
||||
})
|
||||
t.Run("tls_server_name as sni", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(&config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
TLSServerName: "use-this-name.example.com",
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
|
@ -102,7 +104,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
`, ts)
|
||||
})
|
||||
t.Run("tls_skip_verify", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(&config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
TLSSkipVerify: true,
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
|
@ -138,7 +140,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
`, ts)
|
||||
})
|
||||
t.Run("custom ca", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(&config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}),
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
|
@ -174,7 +176,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
})
|
||||
t.Run("client certificate", func(t *testing.T) {
|
||||
clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey)
|
||||
ts, err := b.buildPolicyTransportSocket(&config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
ClientCertificate: clientCert,
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
|
@ -219,11 +221,12 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_buildCluster(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
b := New("local-grpc", "local-http", filemgr.NewManager(), nil)
|
||||
rootCAPath, _ := getRootCertificateAuthority()
|
||||
rootCA := b.filemgr.FileDataSource(rootCAPath).GetFilename()
|
||||
t.Run("insecure", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(&config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -278,7 +281,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
`, cluster)
|
||||
})
|
||||
t.Run("secure", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(&config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
|
||||
To: mustParseWeightedURLs(t,
|
||||
"https://example.com",
|
||||
"https://example.com",
|
||||
|
@ -406,7 +409,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
`, cluster)
|
||||
})
|
||||
t.Run("ip addresses", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(&config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -459,7 +462,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
`, cluster)
|
||||
})
|
||||
t.Run("weights", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(&config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -514,7 +517,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
`, cluster)
|
||||
})
|
||||
t.Run("localhost", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(&config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://localhost"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -557,7 +560,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
`, cluster)
|
||||
})
|
||||
t.Run("outlier", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(&config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://example.com"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -3,6 +3,7 @@ package envoyconfig
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
|
@ -132,7 +133,10 @@ func buildAddress(hostport string, defaultPort int) *envoy_config_core_v3.Addres
|
|||
}
|
||||
}
|
||||
|
||||
func (b *Builder) envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate) *envoy_extensions_transport_sockets_tls_v3.TlsCertificate {
|
||||
func (b *Builder) envoyTLSCertificateFromGoTLSCertificate(
|
||||
ctx context.Context,
|
||||
cert *tls.Certificate,
|
||||
) *envoy_extensions_transport_sockets_tls_v3.TlsCertificate {
|
||||
envoyCert := &envoy_extensions_transport_sockets_tls_v3.TlsCertificate{}
|
||||
var chain bytes.Buffer
|
||||
for _, cbs := range cert.Certificate {
|
||||
|
@ -153,7 +157,7 @@ func (b *Builder) envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate)
|
|||
},
|
||||
))
|
||||
} else {
|
||||
log.Warn().Err(err).Msg("failed to marshal private key for tls config")
|
||||
log.Warn(ctx).Err(err).Msg("failed to marshal private key for tls config")
|
||||
}
|
||||
for _, scts := range cert.SignedCertificateTimestamps {
|
||||
envoyCert.SignedCertificateTimestamp = append(envoyCert.SignedCertificateTimestamp,
|
||||
|
@ -185,10 +189,10 @@ func getRootCertificateAuthority() (string, error) {
|
|||
}
|
||||
}
|
||||
if rootCABundle.value == "" {
|
||||
log.Error().Strs("known-locations", knownRootLocations).
|
||||
log.Error(context.TODO()).Strs("known-locations", knownRootLocations).
|
||||
Msgf("no root certificates were found in any of the known locations")
|
||||
} else {
|
||||
log.Info().Msgf("using %s as the system root certificate authority bundle", rootCABundle.value)
|
||||
log.Info(context.TODO()).Msgf("using %s as the system root certificate authority bundle", rootCABundle.value)
|
||||
}
|
||||
})
|
||||
if rootCABundle.value == "" {
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
package filemgr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
@ -34,7 +35,7 @@ func (mgr *Manager) BytesDataSource(fileName string, data []byte) *envoy_config_
|
|||
fileName = fmt.Sprintf("%s-%x%s", fileName[:len(fileName)-len(ext)], h, ext)
|
||||
|
||||
if err := os.MkdirAll(mgr.cfg.cacheDir, 0o700); err != nil {
|
||||
log.Error().Err(err).Msg("filemgr: error creating cache directory, falling back to inline bytes")
|
||||
log.Error(context.TODO()).Err(err).Msg("filemgr: error creating cache directory, falling back to inline bytes")
|
||||
return inlineBytes(data)
|
||||
}
|
||||
|
||||
|
@ -42,11 +43,11 @@ func (mgr *Manager) BytesDataSource(fileName string, data []byte) *envoy_config_
|
|||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
err = ioutil.WriteFile(filePath, data, 0o600)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("filemgr: error writing cache file, falling back to inline bytes")
|
||||
log.Error(context.TODO()).Err(err).Msg("filemgr: error writing cache file, falling back to inline bytes")
|
||||
return inlineBytes(data)
|
||||
}
|
||||
} else if err != nil {
|
||||
log.Error().Err(err).Msg("filemgr: error reading cache file, falling back to inline bytes")
|
||||
log.Error(context.TODO()).Err(err).Msg("filemgr: error reading cache file, falling back to inline bytes")
|
||||
return inlineBytes(data)
|
||||
}
|
||||
|
||||
|
@ -62,7 +63,7 @@ func (mgr *Manager) ClearCache() {
|
|||
return os.Remove(p)
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to clear envoy file cache")
|
||||
log.Error(context.TODO()).Err(err).Msg("failed to clear envoy file cache")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package envoyconfig
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -187,7 +188,7 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen
|
|||
CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{
|
||||
TlsParams: tlsParams,
|
||||
TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{
|
||||
b.envoyTLSCertificateFromGoTLSCertificate(cert),
|
||||
b.envoyTLSCertificateFromGoTLSCertificate(context.TODO(), cert),
|
||||
},
|
||||
AlpnProtocols: []string{"h2", "http/1.1"},
|
||||
},
|
||||
|
@ -645,19 +646,21 @@ func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_con
|
|||
}
|
||||
|
||||
func (b *Builder) buildDownstreamTLSContext(cfg *config.Config, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
|
||||
ctx := context.TODO()
|
||||
|
||||
certs, err := cfg.AllCertificates()
|
||||
if err != nil {
|
||||
log.Warn().Str("domain", domain).Err(err).Msg("failed to get all certificates from config")
|
||||
log.Warn(ctx).Str("domain", domain).Err(err).Msg("failed to get all certificates from config")
|
||||
return nil
|
||||
}
|
||||
|
||||
cert, err := cryptutil.GetCertificateForDomain(certs, domain)
|
||||
if err != nil {
|
||||
log.Warn().Str("domain", domain).Err(err).Msg("failed to get certificate for domain")
|
||||
log.Warn(ctx).Str("domain", domain).Err(err).Msg("failed to get certificate for domain")
|
||||
return nil
|
||||
}
|
||||
|
||||
envoyCert := b.envoyTLSCertificateFromGoTLSCertificate(cert)
|
||||
envoyCert := b.envoyTLSCertificateFromGoTLSCertificate(context.TODO(), cert)
|
||||
return &envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext{
|
||||
CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{
|
||||
TlsParams: tlsParams,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
@ -8,6 +9,8 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/tripper"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type httpTransport struct {
|
||||
|
@ -18,16 +21,19 @@ type httpTransport struct {
|
|||
// NewHTTPTransport creates a new http transport. If CA or CAFile is set, the transport will
|
||||
// add the CA to system cert pool.
|
||||
func NewHTTPTransport(src Source) http.RoundTripper {
|
||||
ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context {
|
||||
return c.Caller()
|
||||
})
|
||||
t := new(httpTransport)
|
||||
t.underlying, _ = http.DefaultTransport.(*http.Transport)
|
||||
src.OnConfigChange(func(cfg *Config) {
|
||||
t.update(cfg.Options)
|
||||
src.OnConfigChange(ctx, func(ctx context.Context, cfg *Config) {
|
||||
t.update(ctx, cfg.Options)
|
||||
})
|
||||
t.update(src.GetConfig().Options)
|
||||
t.update(ctx, src.GetConfig().Options)
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *httpTransport) update(options *Options) {
|
||||
func (t *httpTransport) update(ctx context.Context, options *Options) {
|
||||
nt := new(http.Transport)
|
||||
if t.underlying != nil {
|
||||
nt = t.underlying.Clone()
|
||||
|
@ -40,7 +46,7 @@ func (t *httpTransport) update(options *Options) {
|
|||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
} else {
|
||||
log.Error().Err(err).Msg("config: error getting cert pool")
|
||||
log.Error(ctx).Err(err).Msg("config: error getting cert pool")
|
||||
}
|
||||
}
|
||||
t.transport.Store(nt)
|
||||
|
@ -78,7 +84,7 @@ func NewPolicyHTTPTransport(options *Options, policy *Policy) http.RoundTripper
|
|||
tlsClientConfig.MinVersion = tls.VersionTLS12
|
||||
isCustomClientConfig = true
|
||||
} else {
|
||||
log.Error().Err(err).Msg("config: error getting ca cert pool")
|
||||
log.Error(context.TODO()).Err(err).Msg("config: error getting ca cert pool")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -89,7 +95,7 @@ func NewPolicyHTTPTransport(options *Options, policy *Policy) http.RoundTripper
|
|||
tlsClientConfig.MinVersion = tls.VersionTLS12
|
||||
isCustomClientConfig = true
|
||||
} else {
|
||||
log.Error().Err(err).Msg("config: error getting custom ca cert pool")
|
||||
log.Error(context.TODO()).Err(err).Msg("config: error getting custom ca cert pool")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
@ -12,10 +13,10 @@ type LogManager struct {
|
|||
}
|
||||
|
||||
// NewLogManager creates a new LogManager.
|
||||
func NewLogManager(src Source) *LogManager {
|
||||
func NewLogManager(ctx context.Context, src Source) *LogManager {
|
||||
mgr := &LogManager{}
|
||||
src.OnConfigChange(mgr.OnConfigChange)
|
||||
mgr.OnConfigChange(src.GetConfig())
|
||||
src.OnConfigChange(ctx, mgr.OnConfigChange)
|
||||
mgr.OnConfigChange(ctx, src.GetConfig())
|
||||
return mgr
|
||||
}
|
||||
|
||||
|
@ -25,7 +26,7 @@ func (mgr *LogManager) Close() error {
|
|||
}
|
||||
|
||||
// OnConfigChange is called whenever configuration changes.
|
||||
func (mgr *LogManager) OnConfigChange(cfg *Config) {
|
||||
func (mgr *LogManager) OnConfigChange(ctx context.Context, cfg *Config) {
|
||||
if cfg == nil || cfg.Options == nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
|
@ -9,6 +10,8 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/telemetry"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// A MetricsManager manages metrics for a given configuration.
|
||||
|
@ -22,11 +25,14 @@ type MetricsManager struct {
|
|||
}
|
||||
|
||||
// NewMetricsManager creates a new MetricsManager.
|
||||
func NewMetricsManager(src Source) *MetricsManager {
|
||||
func NewMetricsManager(ctx context.Context, src Source) *MetricsManager {
|
||||
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("service", "metrics_manager")
|
||||
})
|
||||
mgr := &MetricsManager{}
|
||||
metrics.RegisterInfoMetrics()
|
||||
src.OnConfigChange(mgr.OnConfigChange)
|
||||
mgr.OnConfigChange(src.GetConfig())
|
||||
src.OnConfigChange(ctx, mgr.OnConfigChange)
|
||||
mgr.OnConfigChange(ctx, src.GetConfig())
|
||||
return mgr
|
||||
}
|
||||
|
||||
|
@ -36,7 +42,7 @@ func (mgr *MetricsManager) Close() error {
|
|||
}
|
||||
|
||||
// OnConfigChange updates the metrics manager when configuration is changed.
|
||||
func (mgr *MetricsManager) OnConfigChange(cfg *Config) {
|
||||
func (mgr *MetricsManager) OnConfigChange(ctx context.Context, cfg *Config) {
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
|
||||
|
@ -63,7 +69,7 @@ func (mgr *MetricsManager) updateInfo(cfg *Config) {
|
|||
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("telemetry/metrics: failed to get OS hostname")
|
||||
log.Error(context.TODO()).Err(err).Msg("telemetry/metrics: failed to get OS hostname")
|
||||
hostname = "__unknown__"
|
||||
}
|
||||
|
||||
|
@ -84,13 +90,13 @@ func (mgr *MetricsManager) updateServer(cfg *Config) {
|
|||
mgr.handler = nil
|
||||
|
||||
if mgr.addr == "" {
|
||||
log.Info().Msg("metrics: http server disabled")
|
||||
log.Info(context.TODO()).Msg("metrics: http server disabled")
|
||||
return
|
||||
}
|
||||
|
||||
handler, err := metrics.PrometheusHandler(EnvoyAdminURL, mgr.installationID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("metrics: failed to create prometheus handler")
|
||||
log.Error(context.TODO()).Err(err).Msg("metrics: failed to create prometheus handler")
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -12,12 +13,13 @@ import (
|
|||
)
|
||||
|
||||
func TestMetricsManager(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
src := NewStaticSource(&Config{
|
||||
Options: &Options{
|
||||
MetricsAddr: "ADDRESS",
|
||||
},
|
||||
})
|
||||
mgr := NewMetricsManager(src)
|
||||
mgr := NewMetricsManager(ctx, src)
|
||||
srv1 := httptest.NewServer(mgr)
|
||||
defer srv1.Close()
|
||||
srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -42,7 +44,7 @@ func TestMetricsManagerBasicAuth(t *testing.T) {
|
|||
MetricsBasicAuth: base64.StdEncoding.EncodeToString([]byte("x:y")),
|
||||
},
|
||||
})
|
||||
mgr := NewMetricsManager(src)
|
||||
mgr := NewMetricsManager(context.Background(), src)
|
||||
srv1 := httptest.NewServer(mgr)
|
||||
defer srv1.Close()
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package config
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
|
@ -427,7 +428,7 @@ func (o *Options) viperIsSet(key string) bool {
|
|||
|
||||
// parseHeaders handles unmarshalling any custom headers correctly from the
|
||||
// environment or viper's parsed keys
|
||||
func (o *Options) parseHeaders() error {
|
||||
func (o *Options) parseHeaders(ctx context.Context) error {
|
||||
var headers map[string]string
|
||||
if o.HeadersEnv != "" {
|
||||
// Handle JSON by default via viper
|
||||
|
@ -450,7 +451,7 @@ func (o *Options) parseHeaders() error {
|
|||
}
|
||||
|
||||
if o.viperIsSet("headers") {
|
||||
log.Warn().Msg("config: headers has been renamed to set_response_headers, it will be removed in v0.16")
|
||||
log.Warn(ctx).Msg("config: headers has been renamed to set_response_headers, it will be removed in v0.16")
|
||||
}
|
||||
|
||||
// option was renamed from `headers` to `set_response_headers`. Both config settings are supported.
|
||||
|
@ -507,6 +508,7 @@ func bindEnvs(o *Options, v *viper.Viper) error {
|
|||
|
||||
// Validate ensures the Options fields are valid, and hydrated.
|
||||
func (o *Options) Validate() error {
|
||||
ctx := context.TODO()
|
||||
if !IsValidService(o.Services) {
|
||||
return fmt.Errorf("config: %s is an invalid service type", o.Services)
|
||||
}
|
||||
|
@ -591,7 +593,7 @@ func (o *Options) Validate() error {
|
|||
return fmt.Errorf("config: failed to parse policy: %w", err)
|
||||
}
|
||||
|
||||
if err := o.parseHeaders(); err != nil {
|
||||
if err := o.parseHeaders(ctx); err != nil {
|
||||
return fmt.Errorf("config: failed to parse headers: %w", err)
|
||||
}
|
||||
|
||||
|
@ -669,7 +671,7 @@ func (o *Options) Validate() error {
|
|||
// GoogleCloudServerlessAuthenticationServiceAccount
|
||||
if o.Provider == "google" && o.GoogleCloudServerlessAuthenticationServiceAccount == "" {
|
||||
o.GoogleCloudServerlessAuthenticationServiceAccount = o.ServiceAccount
|
||||
log.Info().Msg("defaulting to idp_service_account for google_cloud_serverless_authentication_service_account")
|
||||
log.Info(ctx).Msg("defaulting to idp_service_account for google_cloud_serverless_authentication_service_account")
|
||||
}
|
||||
|
||||
// strip quotes from redirect address (#811)
|
||||
|
@ -683,7 +685,7 @@ func (o *Options) Validate() error {
|
|||
switch o.Provider {
|
||||
case azure.Name, github.Name, gitlab.Name, google.Name, okta.Name, onelogin.Name:
|
||||
if len(o.Scopes) > 0 {
|
||||
log.Warn().Msg(idpCustomScopesWarnMsg)
|
||||
log.Warn(ctx).Msg(idpCustomScopesWarnMsg)
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
@ -193,7 +194,7 @@ func Test_parseHeaders(t *testing.T) {
|
|||
o.viperSet("headers", tt.viperHeaders)
|
||||
o.viperSet("HeadersEnv", tt.envHeaders)
|
||||
o.HeadersEnv = tt.envHeaders
|
||||
err := o.parseHeaders()
|
||||
err := o.parseHeaders(context.Background())
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Error condition unexpected: err=%s", err)
|
||||
|
|
|
@ -1,16 +1,18 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
octrace "go.opencensus.io/trace"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
octrace "go.opencensus.io/trace"
|
||||
)
|
||||
|
||||
// TracingOptions are the options for tracing.
|
||||
|
@ -60,10 +62,13 @@ type TraceManager struct {
|
|||
}
|
||||
|
||||
// NewTraceManager creates a new TraceManager.
|
||||
func NewTraceManager(src Source) *TraceManager {
|
||||
func NewTraceManager(ctx context.Context, src Source) *TraceManager {
|
||||
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("service", "trace_manager")
|
||||
})
|
||||
mgr := &TraceManager{}
|
||||
src.OnConfigChange(mgr.OnConfigChange)
|
||||
mgr.OnConfigChange(src.GetConfig())
|
||||
src.OnConfigChange(ctx, mgr.OnConfigChange)
|
||||
mgr.OnConfigChange(ctx, src.GetConfig())
|
||||
return mgr
|
||||
}
|
||||
|
||||
|
@ -79,18 +84,18 @@ func (mgr *TraceManager) Close() error {
|
|||
}
|
||||
|
||||
// OnConfigChange updates the manager whenever the configuration is changed.
|
||||
func (mgr *TraceManager) OnConfigChange(cfg *Config) {
|
||||
func (mgr *TraceManager) OnConfigChange(ctx context.Context, cfg *Config) {
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
|
||||
traceOpts, err := NewTracingOptions(cfg.Options)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("trace: failed to build tracing options")
|
||||
log.Error(ctx).Err(err).Msg("trace: failed to build tracing options")
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(traceOpts, mgr.traceOpts) {
|
||||
log.Debug().Msg("no change detected in trace options")
|
||||
log.Debug(ctx).Msg("no change detected in trace options")
|
||||
return
|
||||
}
|
||||
mgr.traceOpts = traceOpts
|
||||
|
@ -104,11 +109,11 @@ func (mgr *TraceManager) OnConfigChange(cfg *Config) {
|
|||
return
|
||||
}
|
||||
|
||||
log.Info().Interface("options", traceOpts).Msg("trace: starting exporter")
|
||||
log.Info(ctx).Interface("options", traceOpts).Msg("trace: starting exporter")
|
||||
|
||||
mgr.exporter, err = trace.RegisterTracing(traceOpts)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("trace: failed to register exporter")
|
||||
log.Error(ctx).Err(err).Msg("trace: failed to register exporter")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -123,12 +123,12 @@ func TestTraceManager(t *testing.T) {
|
|||
TracingSampleRate: 1,
|
||||
}})
|
||||
|
||||
_ = NewTraceManager(src)
|
||||
_ = NewTraceManager(ctx, src)
|
||||
|
||||
_, span := trace.StartSpan(ctx, "Example")
|
||||
span.End()
|
||||
|
||||
src.SetConfig(&Config{Options: &Options{
|
||||
src.SetConfig(ctx, &Config{Options: &Options{
|
||||
TracingProvider: "zipkin",
|
||||
ZipkinEndpoint: srv2.URL,
|
||||
TracingSampleRate: 1,
|
||||
|
|
|
@ -104,13 +104,13 @@ func New(cfg *config.Config) (*DataBroker, error) {
|
|||
}
|
||||
|
||||
// OnConfigChange is called whenever configuration is changed.
|
||||
func (c *DataBroker) OnConfigChange(cfg *config.Config) {
|
||||
func (c *DataBroker) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||
err := c.update(cfg)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("databroker: error updating configuration")
|
||||
log.Error(ctx).Err(err).Msg("databroker: error updating configuration")
|
||||
}
|
||||
|
||||
c.dataBrokerServer.OnConfigChange(cfg)
|
||||
c.dataBrokerServer.OnConfigChange(ctx, cfg)
|
||||
}
|
||||
|
||||
// Register registers all the gRPC services with the given server.
|
||||
|
|
|
@ -27,7 +27,7 @@ func newDataBrokerServer(cfg *config.Config) *dataBrokerServer {
|
|||
}
|
||||
|
||||
// OnConfigChange updates the underlying databroker server whenever configuration is changed.
|
||||
func (srv *dataBrokerServer) OnConfigChange(cfg *config.Config) {
|
||||
func (srv *dataBrokerServer) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||
srv.server.UpdateConfig(srv.getOptions(cfg)...)
|
||||
srv.setKey(cfg)
|
||||
}
|
||||
|
|
|
@ -5,7 +5,8 @@ import (
|
|||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
||||
"golang.org/x/net/publicsuffix"
|
||||
)
|
||||
|
||||
|
@ -51,6 +52,6 @@ type loggingRoundTripper struct {
|
|||
|
||||
func (rt *loggingRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err error) {
|
||||
res, err = rt.RoundTripper.RoundTrip(req)
|
||||
log.Debug().Str("method", req.Method).Str("url", req.URL.String()).Msg("http request")
|
||||
log.Debug(req.Context()).Str("method", req.Method).Str("url", req.URL.String()).Msg("http request")
|
||||
return res, err
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"os"
|
||||
"os/exec"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
type cmdOption func(*exec.Cmd)
|
||||
|
@ -53,7 +53,7 @@ func run(ctx context.Context, name string, options ...cmdOption) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to create stderr pipe for %s: %w", name, err)
|
||||
}
|
||||
go cmdLogger(stderr)
|
||||
go cmdLogger(ctx, stderr)
|
||||
defer stderr.Close()
|
||||
}
|
||||
if cmd.Stdout == nil {
|
||||
|
@ -61,17 +61,17 @@ func run(ctx context.Context, name string, options ...cmdOption) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdout pipe for %s: %w", name, err)
|
||||
}
|
||||
go cmdLogger(stdout)
|
||||
go cmdLogger(ctx, stdout)
|
||||
defer stdout.Close()
|
||||
}
|
||||
|
||||
log.Debug().Strs("args", cmd.Args).Msgf("running %s", name)
|
||||
log.Debug(ctx).Strs("args", cmd.Args).Msgf("running %s", name)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func cmdLogger(rdr io.Reader) {
|
||||
func cmdLogger(ctx context.Context, rdr io.Reader) {
|
||||
s := bufio.NewScanner(rdr)
|
||||
for s.Scan() {
|
||||
log.Debug().Msg(s.Text())
|
||||
log.Debug(ctx).Msg(s.Text())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,9 +15,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/go-jsonnet"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/pomerium/pomerium/integration/internal/netutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
var requiredDeployments = []string{
|
||||
|
@ -180,7 +180,7 @@ func applyManifests(ctx context.Context, jsonsrc string) error {
|
|||
return fmt.Errorf("error applying manifests: %w", err)
|
||||
}
|
||||
|
||||
log.Info().Msg("waiting for deployments to come up")
|
||||
log.Info(ctx).Msg("waiting for deployments to come up")
|
||||
ctx, clearTimeout := context.WithTimeout(ctx, 15*time.Minute)
|
||||
defer clearTimeout()
|
||||
ticker := time.NewTicker(time.Second * 5)
|
||||
|
@ -218,7 +218,7 @@ func applyManifests(ctx context.Context, jsonsrc string) error {
|
|||
for _, dep := range requiredDeployments {
|
||||
if byName[dep] < 1 {
|
||||
done = false
|
||||
log.Warn().Str("deployment", dep).Msg("deployment is not ready yet")
|
||||
log.Warn(ctx).Str("deployment", dep).Msg("deployment is not ready yet")
|
||||
}
|
||||
}
|
||||
if done {
|
||||
|
@ -233,7 +233,7 @@ func applyManifests(ctx context.Context, jsonsrc string) error {
|
|||
<-ticker.C
|
||||
}
|
||||
time.Sleep(time.Minute)
|
||||
log.Info().Msg("all deployments are ready")
|
||||
log.Info(ctx).Msg("all deployments are ready")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -69,15 +69,15 @@ func newManager(ctx context.Context,
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mgr.src.OnConfigChange(func(cfg *config.Config) {
|
||||
mgr.src.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) {
|
||||
err := mgr.update(cfg)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("autocert: error updating config")
|
||||
log.Error(ctx).Err(err).Msg("autocert: error updating config")
|
||||
return
|
||||
}
|
||||
|
||||
cfg = mgr.GetConfig()
|
||||
mgr.Trigger(cfg)
|
||||
mgr.Trigger(ctx, cfg)
|
||||
})
|
||||
go func() {
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
|
@ -90,7 +90,7 @@ func newManager(ctx context.Context,
|
|||
case <-ticker.C:
|
||||
err := mgr.renewConfigCerts()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("autocert: error updating config")
|
||||
log.Error(context.TODO()).Err(err).Msg("autocert: error updating config")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -153,7 +153,7 @@ func (mgr *Manager) renewConfigCerts() error {
|
|||
}
|
||||
|
||||
mgr.config = cfg
|
||||
mgr.Trigger(cfg)
|
||||
mgr.Trigger(context.TODO(), cfg)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -172,10 +172,10 @@ func (mgr *Manager) update(cfg *config.Config) error {
|
|||
func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.Certificate, error) {
|
||||
cert, err := cm.CacheManagedCertificate(domain)
|
||||
if err != nil {
|
||||
log.Info().Str("domain", domain).Msg("obtaining certificate")
|
||||
log.Info(context.TODO()).Str("domain", domain).Msg("obtaining certificate")
|
||||
err = cm.ObtainCert(context.Background(), domain, false)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("autocert failed to obtain client certificate")
|
||||
log.Error(context.TODO()).Err(err).Msg("autocert failed to obtain client certificate")
|
||||
return certmagic.Certificate{}, errObtainCertFailed
|
||||
}
|
||||
metrics.RecordAutocertRenewal()
|
||||
|
@ -187,13 +187,13 @@ func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.C
|
|||
// renewCert attempts to renew given certificate.
|
||||
func (mgr *Manager) renewCert(domain string, cert certmagic.Certificate, cm *certmagic.Config) (certmagic.Certificate, error) {
|
||||
expired := time.Now().After(cert.Leaf.NotAfter)
|
||||
log.Info().Str("domain", domain).Msg("renewing certificate")
|
||||
log.Info(context.TODO()).Str("domain", domain).Msg("renewing certificate")
|
||||
err := cm.RenewCert(context.Background(), domain, false)
|
||||
if err != nil {
|
||||
if expired {
|
||||
return certmagic.Certificate{}, errRenewCertFailed
|
||||
}
|
||||
log.Warn().Err(err).Msg("renew client certificated failed, use existing cert")
|
||||
log.Warn(context.TODO()).Err(err).Msg("renew client certificated failed, use existing cert")
|
||||
}
|
||||
return cm.CacheManagedCertificate(domain)
|
||||
}
|
||||
|
@ -220,11 +220,11 @@ func (mgr *Manager) updateAutocert(cfg *config.Config) error {
|
|||
return fmt.Errorf("autocert: failed to renew client certificate: %w", err)
|
||||
}
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("autocert: failed to obtain client certificate")
|
||||
log.Error(context.TODO()).Err(err).Msg("autocert: failed to obtain client certificate")
|
||||
continue
|
||||
}
|
||||
|
||||
log.Info().Strs("names", cert.Names).Msg("autocert: added certificate")
|
||||
log.Info(context.TODO()).Strs("names", cert.Names).Msg("autocert: added certificate")
|
||||
cfg.AutoCertificates = append(cfg.AutoCertificates, cert.Certificate)
|
||||
}
|
||||
|
||||
|
@ -260,10 +260,10 @@ func (mgr *Manager) updateServer(cfg *config.Config) {
|
|||
}),
|
||||
}
|
||||
go func() {
|
||||
log.Info().Str("addr", hsrv.Addr).Msg("starting http redirect server")
|
||||
log.Info(context.TODO()).Str("addr", hsrv.Addr).Msg("starting http redirect server")
|
||||
err := hsrv.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to run http redirect server")
|
||||
log.Error(context.TODO()).Err(err).Msg("failed to run http redirect server")
|
||||
}
|
||||
}()
|
||||
mgr.srv = hsrv
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"syscall"
|
||||
|
||||
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate"
|
||||
|
@ -32,7 +33,7 @@ import (
|
|||
|
||||
// Run runs the main pomerium application.
|
||||
func Run(ctx context.Context, configFile string) error {
|
||||
log.Info().Str("version", version.FullVersion()).Msg("cmd/pomerium")
|
||||
log.Info(ctx).Str("version", version.FullVersion()).Msg("cmd/pomerium")
|
||||
|
||||
var src config.Source
|
||||
|
||||
|
@ -41,8 +42,8 @@ func Run(ctx context.Context, configFile string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
src = databroker.NewConfigSource(src)
|
||||
logMgr := config.NewLogManager(src)
|
||||
src = databroker.NewConfigSource(ctx, src)
|
||||
logMgr := config.NewLogManager(ctx, src)
|
||||
defer logMgr.Close()
|
||||
|
||||
// trigger changes when underlying files are changed
|
||||
|
@ -56,9 +57,9 @@ func Run(ctx context.Context, configFile string) error {
|
|||
// override the default http transport so we can use the custom CA in the TLS client config (#1570)
|
||||
http.DefaultTransport = config.NewHTTPTransport(src)
|
||||
|
||||
metricsMgr := config.NewMetricsManager(src)
|
||||
metricsMgr := config.NewMetricsManager(ctx, src)
|
||||
defer metricsMgr.Close()
|
||||
traceMgr := config.NewTraceManager(src)
|
||||
traceMgr := config.NewTraceManager(ctx, src)
|
||||
defer traceMgr.Close()
|
||||
|
||||
// setup the control plane
|
||||
|
@ -66,43 +67,46 @@ func Run(ctx context.Context, configFile string) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("error creating control plane: %w", err)
|
||||
}
|
||||
src.OnConfigChange(func(cfg *config.Config) {
|
||||
if err := controlPlane.OnConfigChange(cfg); err != nil {
|
||||
log.Error().Err(err).Msg("config change")
|
||||
}
|
||||
})
|
||||
src.OnConfigChange(ctx,
|
||||
func(ctx context.Context, cfg *config.Config) {
|
||||
if err := controlPlane.OnConfigChange(ctx, cfg); err != nil {
|
||||
log.Error(ctx).Err(err).Msg("config change")
|
||||
}
|
||||
})
|
||||
|
||||
if err = controlPlane.OnConfigChange(src.GetConfig()); err != nil {
|
||||
if err = controlPlane.OnConfigChange(log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("config_file_source", configFile).Bool("bootstrap", true)
|
||||
}), src.GetConfig()); err != nil {
|
||||
return fmt.Errorf("applying config: %w", err)
|
||||
}
|
||||
|
||||
_, grpcPort, _ := net.SplitHostPort(controlPlane.GRPCListener.Addr().String())
|
||||
_, httpPort, _ := net.SplitHostPort(controlPlane.HTTPListener.Addr().String())
|
||||
|
||||
log.Info().Str("port", grpcPort).Msg("gRPC server started")
|
||||
log.Info().Str("port", httpPort).Msg("HTTP server started")
|
||||
log.Info(ctx).Str("port", grpcPort).Msg("gRPC server started")
|
||||
log.Info(ctx).Str("port", httpPort).Msg("HTTP server started")
|
||||
|
||||
// create envoy server
|
||||
envoyServer, err := envoy.NewServer(src, grpcPort, httpPort, controlPlane.Builder)
|
||||
envoyServer, err := envoy.NewServer(ctx, src, grpcPort, httpPort, controlPlane.Builder)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating envoy server: %w", err)
|
||||
}
|
||||
defer envoyServer.Close()
|
||||
|
||||
// add services
|
||||
if err := setupAuthenticate(src, controlPlane); err != nil {
|
||||
if err := setupAuthenticate(ctx, src, controlPlane); err != nil {
|
||||
return err
|
||||
}
|
||||
var authorizeServer *authorize.Authorize
|
||||
if config.IsAuthorize(src.GetConfig().Options.Services) {
|
||||
authorizeServer, err = setupAuthorize(src, controlPlane)
|
||||
authorizeServer, err = setupAuthorize(ctx, src, controlPlane)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
var dataBrokerServer *databroker_service.DataBroker
|
||||
if config.IsDataBroker(src.GetConfig().Options.Services) {
|
||||
dataBrokerServer, err = setupDataBroker(src, controlPlane)
|
||||
dataBrokerServer, err = setupDataBroker(ctx, src, controlPlane)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up databroker: %w", err)
|
||||
}
|
||||
|
@ -112,10 +116,10 @@ func Run(ctx context.Context, configFile string) error {
|
|||
}
|
||||
}
|
||||
|
||||
if err = setupRegistryReporter(src); err != nil {
|
||||
if err = setupRegistryReporter(ctx, src); err != nil {
|
||||
return fmt.Errorf("setting up registry reporter: %w", err)
|
||||
}
|
||||
if err := setupProxy(src, controlPlane); err != nil {
|
||||
if err := setupProxy(ctx, src, controlPlane); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -159,7 +163,7 @@ func Run(ctx context.Context, configFile string) error {
|
|||
return eg.Wait()
|
||||
}
|
||||
|
||||
func setupAuthenticate(src config.Source, controlPlane *controlplane.Server) error {
|
||||
func setupAuthenticate(ctx context.Context, src config.Source, controlPlane *controlplane.Server) error {
|
||||
if !config.IsAuthenticate(src.GetConfig().Options.Services) {
|
||||
return nil
|
||||
}
|
||||
|
@ -174,56 +178,56 @@ func setupAuthenticate(src config.Source, controlPlane *controlplane.Server) err
|
|||
return fmt.Errorf("error getting authenticate URL: %w", err)
|
||||
}
|
||||
|
||||
src.OnConfigChange(svc.OnConfigChange)
|
||||
svc.OnConfigChange(src.GetConfig())
|
||||
src.OnConfigChange(ctx, svc.OnConfigChange)
|
||||
svc.OnConfigChange(ctx, src.GetConfig())
|
||||
host := urlutil.StripPort(authenticateURL.Host)
|
||||
sr := controlPlane.HTTPRouter.Host(host).Subrouter()
|
||||
svc.Mount(sr)
|
||||
log.Info().Str("host", host).Msg("enabled authenticate service")
|
||||
log.Info(context.TODO()).Str("host", host).Msg("enabled authenticate service")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupAuthorize(src config.Source, controlPlane *controlplane.Server) (*authorize.Authorize, error) {
|
||||
func setupAuthorize(ctx context.Context, src config.Source, controlPlane *controlplane.Server) (*authorize.Authorize, error) {
|
||||
svc, err := authorize.New(src.GetConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating authorize service: %w", err)
|
||||
}
|
||||
envoy_service_auth_v3.RegisterAuthorizationServer(controlPlane.GRPCServer, svc)
|
||||
|
||||
log.Info().Msg("enabled authorize service")
|
||||
src.OnConfigChange(svc.OnConfigChange)
|
||||
svc.OnConfigChange(src.GetConfig())
|
||||
log.Info(context.TODO()).Msg("enabled authorize service")
|
||||
src.OnConfigChange(ctx, svc.OnConfigChange)
|
||||
svc.OnConfigChange(ctx, src.GetConfig())
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func setupDataBroker(src config.Source, controlPlane *controlplane.Server) (*databroker_service.DataBroker, error) {
|
||||
func setupDataBroker(ctx context.Context, src config.Source, controlPlane *controlplane.Server) (*databroker_service.DataBroker, error) {
|
||||
svc, err := databroker_service.New(src.GetConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating databroker service: %w", err)
|
||||
}
|
||||
svc.Register(controlPlane.GRPCServer)
|
||||
log.Info().Msg("enabled databroker service")
|
||||
src.OnConfigChange(svc.OnConfigChange)
|
||||
svc.OnConfigChange(src.GetConfig())
|
||||
log.Info(context.TODO()).Msg("enabled databroker service")
|
||||
src.OnConfigChange(ctx, svc.OnConfigChange)
|
||||
svc.OnConfigChange(ctx, src.GetConfig())
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func setupRegistryServer(src config.Source, controlPlane *controlplane.Server) error {
|
||||
svc := registry.NewInMemoryServer(context.TODO(), registryTTL)
|
||||
registry_pb.RegisterRegistryServer(controlPlane.GRPCServer, svc)
|
||||
log.Info().Msg("enabled service discovery")
|
||||
log.Info(context.TODO()).Msg("enabled service discovery")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupRegistryReporter(src config.Source) error {
|
||||
func setupRegistryReporter(ctx context.Context, src config.Source) error {
|
||||
reporter := new(registry.Reporter)
|
||||
src.OnConfigChange(reporter.OnConfigChange)
|
||||
reporter.OnConfigChange(src.GetConfig())
|
||||
src.OnConfigChange(ctx, reporter.OnConfigChange)
|
||||
reporter.OnConfigChange(ctx, src.GetConfig())
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupProxy(src config.Source, controlPlane *controlplane.Server) error {
|
||||
func setupProxy(ctx context.Context, src config.Source, controlPlane *controlplane.Server) error {
|
||||
if !config.IsProxy(src.GetConfig().Options.Services) {
|
||||
return nil
|
||||
}
|
||||
|
@ -234,9 +238,9 @@ func setupProxy(src config.Source, controlPlane *controlplane.Server) error {
|
|||
}
|
||||
controlPlane.HTTPRouter.PathPrefix("/").Handler(svc)
|
||||
|
||||
log.Info().Msg("enabled proxy service")
|
||||
src.OnConfigChange(svc.OnConfigChange)
|
||||
svc.OnConfigChange(src.GetConfig())
|
||||
log.Info(context.TODO()).Msg("enabled proxy service")
|
||||
src.OnConfigChange(ctx, svc.OnConfigChange)
|
||||
svc.OnConfigChange(ctx, src.GetConfig())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ func (srv *Server) StreamAccessLogs(stream envoy_service_accesslog_v3.AccessLogS
|
|||
for {
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("access log stream error, disconnecting")
|
||||
log.Error(stream.Context()).Err(err).Msg("access log stream error, disconnecting")
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -27,9 +27,9 @@ func (srv *Server) StreamAccessLogs(stream envoy_service_accesslog_v3.AccessLogS
|
|||
reqPath := entry.GetRequest().GetPath()
|
||||
var evt *zerolog.Event
|
||||
if reqPath == "/ping" || reqPath == "/healthz" {
|
||||
evt = log.Debug()
|
||||
evt = log.Debug(stream.Context())
|
||||
} else {
|
||||
evt = log.Info()
|
||||
evt = log.Info(stream.Context())
|
||||
}
|
||||
// common properties
|
||||
evt = evt.Str("service", "envoy")
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
@ -106,7 +107,11 @@ func NewServer(name string, metricsMgr *config.MetricsManager) (*Server, error)
|
|||
srv.reproxy,
|
||||
)
|
||||
|
||||
res, err := srv.buildDiscoveryResources()
|
||||
ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("server_name", name)
|
||||
})
|
||||
|
||||
res, err := srv.buildDiscoveryResources(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -123,7 +128,7 @@ func (srv *Server) Run(ctx context.Context) error {
|
|||
|
||||
// start the gRPC server
|
||||
eg.Go(func() error {
|
||||
log.Info().Str("addr", srv.GRPCListener.Addr().String()).Msg("starting control-plane gRPC server")
|
||||
log.Info(ctx).Str("addr", srv.GRPCListener.Addr().String()).Msg("starting control-plane gRPC server")
|
||||
return srv.GRPCServer.Serve(srv.GRPCListener)
|
||||
})
|
||||
|
||||
|
@ -160,7 +165,7 @@ func (srv *Server) Run(ctx context.Context) error {
|
|||
|
||||
// start the HTTP server
|
||||
eg.Go(func() error {
|
||||
log.Info().Str("addr", srv.HTTPListener.Addr().String()).Msg("starting control-plane HTTP server")
|
||||
log.Info(ctx).Str("addr", srv.HTTPListener.Addr().String()).Msg("starting control-plane HTTP server")
|
||||
return hsrv.Serve(srv.HTTPListener)
|
||||
})
|
||||
|
||||
|
@ -178,17 +183,17 @@ func (srv *Server) Run(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// OnConfigChange updates the pomerium config options.
|
||||
func (srv *Server) OnConfigChange(cfg *config.Config) error {
|
||||
srv.reproxy.Update(cfg)
|
||||
func (srv *Server) OnConfigChange(ctx context.Context, cfg *config.Config) error {
|
||||
srv.reproxy.Update(ctx, cfg)
|
||||
prev := srv.currentConfig.Load()
|
||||
srv.currentConfig.Store(versionedConfig{
|
||||
Config: cfg,
|
||||
version: prev.version + 1,
|
||||
})
|
||||
res, err := srv.buildDiscoveryResources()
|
||||
res, err := srv.buildDiscoveryResources(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
srv.xdsmgr.Update(res)
|
||||
srv.xdsmgr.Update(ctx, res)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package controlplane
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
|
||||
envoy_service_discovery_v3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
|
||||
|
@ -14,11 +15,11 @@ const (
|
|||
listenerTypeURL = "type.googleapis.com/envoy.config.listener.v3.Listener"
|
||||
)
|
||||
|
||||
func (srv *Server) buildDiscoveryResources() (map[string][]*envoy_service_discovery_v3.Resource, error) {
|
||||
func (srv *Server) buildDiscoveryResources(ctx context.Context) (map[string][]*envoy_service_discovery_v3.Resource, error) {
|
||||
resources := map[string][]*envoy_service_discovery_v3.Resource{}
|
||||
cfg := srv.currentConfig.Load()
|
||||
|
||||
clusters, err := srv.Builder.BuildClusters(cfg.Config)
|
||||
clusters, err := srv.Builder.BuildClusters(ctx, cfg.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
package xdsmgr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
|
@ -51,7 +52,7 @@ func (mgr *Manager) DeltaAggregatedResources(
|
|||
|
||||
stateByTypeURL := map[string]*streamState{}
|
||||
|
||||
getDeltaResponse := func(typeURL string) *envoy_service_discovery_v3.DeltaDiscoveryResponse {
|
||||
getDeltaResponse := func(ctx context.Context, typeURL string) *envoy_service_discovery_v3.DeltaDiscoveryResponse {
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
|
||||
|
@ -85,7 +86,7 @@ func (mgr *Manager) DeltaAggregatedResources(
|
|||
return res
|
||||
}
|
||||
|
||||
handleDeltaRequest := func(req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
|
||||
handleDeltaRequest := func(ctx context.Context, req *envoy_service_discovery_v3.DeltaDiscoveryRequest) {
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
|
||||
|
@ -109,8 +110,9 @@ func (mgr *Manager) DeltaAggregatedResources(
|
|||
case req.GetErrorDetail() != nil:
|
||||
// a NACK
|
||||
bs, _ := json.Marshal(req.ErrorDetail.Details)
|
||||
log.Error().
|
||||
log.Error(ctx).
|
||||
Err(errors.New(req.ErrorDetail.Message)).
|
||||
Str("nonce", req.ResponseNonce).
|
||||
Int32("code", req.ErrorDetail.Code).
|
||||
RawJSON("details", bs).Msg("error applying configuration")
|
||||
// - set the client resource versions to the current resource versions
|
||||
|
@ -121,12 +123,18 @@ func (mgr *Manager) DeltaAggregatedResources(
|
|||
case req.GetResponseNonce() == mgr.nonce:
|
||||
// an ACK for the last response
|
||||
// - set the client resource versions to the current resource versions
|
||||
log.Debug(ctx).
|
||||
Str("nonce", req.ResponseNonce).
|
||||
Msg("ACK")
|
||||
state.clientResourceVersions = make(map[string]string)
|
||||
for _, resource := range mgr.resources[req.GetTypeUrl()] {
|
||||
state.clientResourceVersions[resource.Name] = resource.Version
|
||||
}
|
||||
default:
|
||||
// an ACK for a response that's not the last response
|
||||
log.Debug(ctx).
|
||||
Str("nonce", req.ResponseNonce).
|
||||
Msg("stale ACK")
|
||||
}
|
||||
|
||||
// update subscriptions
|
||||
|
@ -168,15 +176,16 @@ func (mgr *Manager) DeltaAggregatedResources(
|
|||
})
|
||||
// 2. handle incoming requests or resource changes
|
||||
eg.Go(func() error {
|
||||
changeCtx := ctx
|
||||
for {
|
||||
var typeURLs []string
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case req := <-incoming:
|
||||
handleDeltaRequest(req)
|
||||
handleDeltaRequest(changeCtx, req)
|
||||
typeURLs = []string{req.GetTypeUrl()}
|
||||
case <-ch:
|
||||
case changeCtx = <-ch:
|
||||
mgr.mu.Lock()
|
||||
for typeURL := range mgr.resources {
|
||||
typeURLs = append(typeURLs, typeURL)
|
||||
|
@ -185,7 +194,7 @@ func (mgr *Manager) DeltaAggregatedResources(
|
|||
}
|
||||
|
||||
for _, typeURL := range typeURLs {
|
||||
res := getDeltaResponse(typeURL)
|
||||
res := getDeltaResponse(changeCtx, typeURL)
|
||||
if res == nil {
|
||||
continue
|
||||
}
|
||||
|
@ -194,6 +203,10 @@ func (mgr *Manager) DeltaAggregatedResources(
|
|||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case outgoing <- res:
|
||||
log.Info(changeCtx).
|
||||
Str("nounce", res.Nonce).
|
||||
Str("type", res.TypeUrl).
|
||||
Msg("send update")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -224,11 +237,11 @@ func (mgr *Manager) StreamAggregatedResources(
|
|||
|
||||
// Update updates the state of resources. If any changes are made they will be pushed to any listening
|
||||
// streams. For each TypeURL the list of resources should be the complete list of resources.
|
||||
func (mgr *Manager) Update(resources map[string][]*envoy_service_discovery_v3.Resource) {
|
||||
func (mgr *Manager) Update(ctx context.Context, resources map[string][]*envoy_service_discovery_v3.Resource) {
|
||||
mgr.mu.Lock()
|
||||
mgr.nonce = uuid.New().String()
|
||||
mgr.resources = resources
|
||||
mgr.mu.Unlock()
|
||||
|
||||
mgr.signal.Broadcast()
|
||||
mgr.signal.Broadcast(ctx)
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ func TestManager(t *testing.T) {
|
|||
origOnHandleDeltaRequest := onHandleDeltaRequest
|
||||
defer func() { onHandleDeltaRequest = origOnHandleDeltaRequest }()
|
||||
onHandleDeltaRequest = func(state *streamState) {
|
||||
stateChanged.Broadcast()
|
||||
stateChanged.Broadcast(ctx)
|
||||
}
|
||||
|
||||
srv := grpc.NewServer()
|
||||
|
@ -94,7 +94,7 @@ func TestManager(t *testing.T) {
|
|||
}, msg.GetResources())
|
||||
ack(msg.Nonce)
|
||||
|
||||
mgr.Update(map[string][]*envoy_service_discovery_v3.Resource{
|
||||
mgr.Update(ctx, map[string][]*envoy_service_discovery_v3.Resource{
|
||||
typeURL: {{Name: "r1", Version: "2"}},
|
||||
})
|
||||
|
||||
|
@ -105,7 +105,7 @@ func TestManager(t *testing.T) {
|
|||
}, msg.GetResources())
|
||||
ack(msg.Nonce)
|
||||
|
||||
mgr.Update(map[string][]*envoy_service_discovery_v3.Resource{
|
||||
mgr.Update(ctx, map[string][]*envoy_service_discovery_v3.Resource{
|
||||
typeURL: nil,
|
||||
})
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package databroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
@ -65,7 +66,7 @@ func WithSharedKey(sharedKey string) ServerOption {
|
|||
return func(cfg *serverConfig) {
|
||||
key, err := base64.StdEncoding.DecodeString(sharedKey)
|
||||
if err != nil || len(key) != cryptutil.DefaultKeySize {
|
||||
log.Error().Err(err).Msgf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize)
|
||||
log.Error(context.TODO()).Err(err).Msgf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize)
|
||||
return
|
||||
}
|
||||
cfg.secret = key
|
||||
|
|
|
@ -35,22 +35,22 @@ type dbConfig struct {
|
|||
}
|
||||
|
||||
// NewConfigSource creates a new ConfigSource.
|
||||
func NewConfigSource(underlying config.Source, listeners ...config.ChangeListener) *ConfigSource {
|
||||
func NewConfigSource(ctx context.Context, underlying config.Source, listeners ...config.ChangeListener) *ConfigSource {
|
||||
src := &ConfigSource{
|
||||
dbConfigs: map[string]dbConfig{},
|
||||
}
|
||||
for _, li := range listeners {
|
||||
src.OnConfigChange(li)
|
||||
src.OnConfigChange(ctx, li)
|
||||
}
|
||||
underlying.OnConfigChange(func(cfg *config.Config) {
|
||||
underlying.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) {
|
||||
src.mu.Lock()
|
||||
src.underlyingConfig = cfg.Clone()
|
||||
src.mu.Unlock()
|
||||
|
||||
src.rebuild(false)
|
||||
src.rebuild(ctx, firstTime(false))
|
||||
})
|
||||
src.underlyingConfig = underlying.GetConfig()
|
||||
src.rebuild(true)
|
||||
src.rebuild(ctx, firstTime(true))
|
||||
return src
|
||||
}
|
||||
|
||||
|
@ -62,8 +62,10 @@ func (src *ConfigSource) GetConfig() *config.Config {
|
|||
return src.computedConfig
|
||||
}
|
||||
|
||||
func (src *ConfigSource) rebuild(firstTime bool) {
|
||||
_, span := trace.StartSpan(context.Background(), "databroker.config_source.rebuild")
|
||||
type firstTime bool
|
||||
|
||||
func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.config_source.rebuild")
|
||||
defer span.End()
|
||||
|
||||
src.mu.Lock()
|
||||
|
@ -78,7 +80,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
|
|||
for _, policy := range cfg.Options.GetAllPolicies() {
|
||||
id, err := policy.RouteID()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).
|
||||
log.Warn(ctx).Err(err).
|
||||
Str("policy", policy.String()).
|
||||
Msg("databroker: invalid policy config, ignoring")
|
||||
return
|
||||
|
@ -95,7 +97,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
|
|||
|
||||
err := cfg.Options.Validate()
|
||||
if err != nil {
|
||||
metrics.SetDBConfigRejected(cfg.Options.Services, id, cfgpb.version, err)
|
||||
metrics.SetDBConfigRejected(ctx, cfg.Options.Services, id, cfgpb.version, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -103,7 +105,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
|
|||
policy, err := config.NewPolicyFromProto(routepb)
|
||||
if err != nil {
|
||||
errCount++
|
||||
log.Warn().Err(err).
|
||||
log.Warn(ctx).Err(err).
|
||||
Str("db_config_id", id).
|
||||
Msg("databroker: error converting protobuf into policy")
|
||||
continue
|
||||
|
@ -112,7 +114,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
|
|||
err = policy.Validate()
|
||||
if err != nil {
|
||||
errCount++
|
||||
log.Warn().Err(err).
|
||||
log.Warn(ctx).Err(err).
|
||||
Str("db_config_id", id).
|
||||
Str("policy", policy.String()).
|
||||
Msg("databroker: invalid policy, ignoring")
|
||||
|
@ -122,7 +124,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
|
|||
routeID, err := policy.RouteID()
|
||||
if err != nil {
|
||||
errCount++
|
||||
log.Warn().Err(err).
|
||||
log.Warn(ctx).Err(err).
|
||||
Str("db_config_id", id).
|
||||
Str("policy", policy.String()).
|
||||
Msg("databroker: cannot establish policy route ID, ignoring")
|
||||
|
@ -131,7 +133,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
|
|||
|
||||
if _, ok := seen[routeID]; ok {
|
||||
errCount++
|
||||
log.Warn().Err(err).
|
||||
log.Warn(ctx).Err(err).
|
||||
Str("db_config_id", id).
|
||||
Str("policy", policy.String()).
|
||||
Msg("databroker: duplicate policy detected, ignoring")
|
||||
|
@ -141,7 +143,7 @@ func (src *ConfigSource) rebuild(firstTime bool) {
|
|||
|
||||
additionalPolicies = append(additionalPolicies, *policy)
|
||||
}
|
||||
metrics.SetDBConfigInfo(cfg.Options.Services, id, cfgpb.version, int64(errCount))
|
||||
metrics.SetDBConfigInfo(ctx, cfg.Options.Services, id, cfgpb.version, int64(errCount))
|
||||
}
|
||||
|
||||
// add the additional policies here since calling `Validate` will reset them.
|
||||
|
@ -149,10 +151,10 @@ func (src *ConfigSource) rebuild(firstTime bool) {
|
|||
|
||||
src.computedConfig = cfg
|
||||
if !firstTime {
|
||||
src.Trigger(cfg)
|
||||
src.Trigger(ctx, cfg)
|
||||
}
|
||||
|
||||
metrics.SetConfigInfo(cfg.Options.Services, "databroker", cfg.Checksum(), true)
|
||||
metrics.SetConfigInfo(ctx, cfg.Options.Services, "databroker", cfg.Checksum(), true)
|
||||
}
|
||||
|
||||
func (src *ConfigSource) runUpdater(cfg *config.Config) {
|
||||
|
@ -191,7 +193,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
|
|||
|
||||
cc, err := grpc.NewGRPCClientConn(connectionOptions)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("databroker: failed to create gRPC connection to data broker")
|
||||
log.Error(context.TODO()).Err(err).Msg("databroker: failed to create gRPC connection to data broker")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -237,7 +239,7 @@ func (s *syncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64,
|
|||
var cfgpb configpb.Config
|
||||
err := record.GetData().UnmarshalTo(&cfgpb)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("databroker: error decoding config")
|
||||
log.Warn(ctx).Err(err).Msg("databroker: error decoding config")
|
||||
delete(s.src.dbConfigs, record.GetId())
|
||||
continue
|
||||
}
|
||||
|
@ -246,5 +248,5 @@ func (s *syncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64,
|
|||
}
|
||||
s.src.mu.Unlock()
|
||||
|
||||
s.src.rebuild(false)
|
||||
s.src.rebuild(ctx, firstTime(false))
|
||||
}
|
||||
|
|
|
@ -37,9 +37,9 @@ func TestConfigSource(t *testing.T) {
|
|||
base.InsecureServer = true
|
||||
base.GRPCInsecure = true
|
||||
|
||||
src := NewConfigSource(config.NewStaticSource(&config.Config{
|
||||
src := NewConfigSource(ctx, config.NewStaticSource(&config.Config{
|
||||
Options: base,
|
||||
}), func(cfg *config.Config) {
|
||||
}), func(_ context.Context, cfg *config.Config) {
|
||||
cfgs <- cfg
|
||||
})
|
||||
cfgs <- src.GetConfig()
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/rs/zerolog"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
@ -35,7 +34,6 @@ const (
|
|||
// Server implements the databroker service using an in memory database.
|
||||
type Server struct {
|
||||
cfg *serverConfig
|
||||
log zerolog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
version uint64
|
||||
|
@ -44,33 +42,31 @@ type Server struct {
|
|||
|
||||
// New creates a new server.
|
||||
func New(options ...ServerOption) *Server {
|
||||
srv := &Server{
|
||||
log: log.With().Str("service", "databroker").Logger(),
|
||||
}
|
||||
srv := &Server{}
|
||||
srv.UpdateConfig(options...)
|
||||
return srv
|
||||
}
|
||||
|
||||
func (srv *Server) initVersion() {
|
||||
func (srv *Server) initVersion(ctx context.Context) {
|
||||
db, _, err := srv.getBackendLocked()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to init server version")
|
||||
log.Error(ctx).Err(err).Msg("failed to init server version")
|
||||
return
|
||||
}
|
||||
|
||||
// Get version from storage first.
|
||||
r, err := db.Get(context.Background(), recordTypeServerVersion, serverVersionKey)
|
||||
r, err := db.Get(ctx, recordTypeServerVersion, serverVersionKey)
|
||||
switch {
|
||||
case err == nil:
|
||||
var sv wrapperspb.UInt64Value
|
||||
if err := r.GetData().UnmarshalTo(&sv); err == nil {
|
||||
srv.log.Debug().Uint64("server_version", sv.Value).Msg("got db version from Backend")
|
||||
log.Debug(ctx).Uint64("server_version", sv.Value).Msg("got db version from Backend")
|
||||
srv.version = sv.Value
|
||||
}
|
||||
return
|
||||
case errors.Is(err, storage.ErrNotFound): // no server version, so we'll create a new one
|
||||
case err != nil:
|
||||
log.Error().Err(err).Msg("failed to retrieve server version")
|
||||
log.Error(ctx).Err(err).Msg("failed to retrieve server version")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -81,7 +77,7 @@ func (srv *Server) initVersion() {
|
|||
Id: serverVersionKey,
|
||||
Data: data,
|
||||
}); err != nil {
|
||||
srv.log.Warn().Err(err).Msg("failed to save server version.")
|
||||
log.Warn(ctx).Err(err).Msg("failed to save server version.")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -90,9 +86,11 @@ func (srv *Server) UpdateConfig(options ...ServerOption) {
|
|||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
ctx := context.TODO()
|
||||
|
||||
cfg := newServerConfig(options...)
|
||||
if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) {
|
||||
log.Debug().Msg("databroker: no changes detected, re-using existing DBs")
|
||||
log.Debug(ctx).Msg("databroker: no changes detected, re-using existing DBs")
|
||||
return
|
||||
}
|
||||
srv.cfg = cfg
|
||||
|
@ -100,19 +98,19 @@ func (srv *Server) UpdateConfig(options ...ServerOption) {
|
|||
if srv.backend != nil {
|
||||
err := srv.backend.Close()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("databroker: error closing backend")
|
||||
log.Error(ctx).Err(err).Msg("databroker: error closing backend")
|
||||
}
|
||||
srv.backend = nil
|
||||
}
|
||||
|
||||
srv.initVersion()
|
||||
srv.initVersion(ctx)
|
||||
}
|
||||
|
||||
// Get gets a record from the in-memory list.
|
||||
func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databroker.GetResponse, error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.grpc.Get")
|
||||
defer span.End()
|
||||
srv.log.Info().
|
||||
log.Info(ctx).
|
||||
Str("peer", grpcutil.GetPeerAddr(ctx)).
|
||||
Str("type", req.GetType()).
|
||||
Str("id", req.GetId()).
|
||||
|
@ -141,7 +139,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
|
|||
func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*databroker.QueryResponse, error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.grpc.Query")
|
||||
defer span.End()
|
||||
srv.log.Info().
|
||||
log.Info(ctx).
|
||||
Str("peer", grpcutil.GetPeerAddr(ctx)).
|
||||
Str("type", req.GetType()).
|
||||
Str("query", req.GetQuery()).
|
||||
|
@ -185,7 +183,7 @@ func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databr
|
|||
defer span.End()
|
||||
record := req.GetRecord()
|
||||
|
||||
srv.log.Info().
|
||||
log.Info(ctx).
|
||||
Str("peer", grpcutil.GetPeerAddr(ctx)).
|
||||
Str("type", record.GetType()).
|
||||
Str("id", record.GetId()).
|
||||
|
@ -208,7 +206,7 @@ func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databr
|
|||
func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error {
|
||||
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.Sync")
|
||||
defer span.End()
|
||||
srv.log.Info().
|
||||
log.Info(stream.Context()).
|
||||
Str("peer", grpcutil.GetPeerAddr(stream.Context())).
|
||||
Uint64("server_version", req.GetServerVersion()).
|
||||
Uint64("record_version", req.GetRecordVersion()).
|
||||
|
@ -251,7 +249,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
|
|||
func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error {
|
||||
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.SyncLatest")
|
||||
defer span.End()
|
||||
srv.log.Info().
|
||||
log.Info(stream.Context()).
|
||||
Str("peer", grpcutil.GetPeerAddr(stream.Context())).
|
||||
Str("type", req.GetType()).
|
||||
Msg("sync latest")
|
||||
|
@ -333,9 +331,10 @@ func (srv *Server) getBackendLocked() (backend storage.Backend, version uint64,
|
|||
}
|
||||
|
||||
func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
|
||||
ctx := context.Background()
|
||||
caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("failed to read databroker CA file")
|
||||
log.Warn(ctx).Err(err).Msg("failed to read databroker CA file")
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
RootCAs: caCertPool,
|
||||
|
@ -348,10 +347,10 @@ func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
|
|||
|
||||
switch srv.cfg.storageType {
|
||||
case config.StorageInMemoryName:
|
||||
srv.log.Info().Msg("using in-memory store")
|
||||
log.Info(ctx).Msg("using in-memory store")
|
||||
return inmemory.New(), nil
|
||||
case config.StorageRedisName:
|
||||
srv.log.Info().Msg("using redis store")
|
||||
log.Info(ctx).Msg("using redis store")
|
||||
backend, err = redis.New(
|
||||
srv.cfg.storageConnectionString,
|
||||
redis.WithTLSConfig(tlsConfig),
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
)
|
||||
|
@ -19,7 +18,6 @@ func newServer(cfg *serverConfig) *Server {
|
|||
return &Server{
|
||||
version: 11,
|
||||
cfg: cfg,
|
||||
log: log.With().Str("service", "databroker").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -69,19 +69,25 @@ func getConfig(options ...Option) *config {
|
|||
// The Provider retrieves users and groups from gitlab.
|
||||
type Provider struct {
|
||||
cfg *config
|
||||
log zerolog.Logger
|
||||
}
|
||||
|
||||
// New creates a new Provider.
|
||||
func New(options ...Option) *Provider {
|
||||
return &Provider{
|
||||
cfg: getConfig(options...),
|
||||
log: log.With().Str("service", "directory").Str("provider", "gitlab").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
func withLog(ctx context.Context) context.Context {
|
||||
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("service", "directory").Str("provider", "gitlab")
|
||||
})
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
ctx = withLog(ctx)
|
||||
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
@ -107,11 +113,13 @@ func (p *Provider) User(ctx context.Context, userID, accessToken string) (*direc
|
|||
|
||||
// UserGroups gets the directory user groups for gitlab.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
ctx = withLog(ctx)
|
||||
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, nil, fmt.Errorf("gitlab: service account not defined")
|
||||
}
|
||||
|
||||
p.log.Info().Msg("getting user groups")
|
||||
log.Info(ctx).Msg("getting user groups")
|
||||
|
||||
groups, err := p.listGroups(ctx, "")
|
||||
if err != nil {
|
||||
|
|
|
@ -93,7 +93,6 @@ func getConfig(options ...Option) *config {
|
|||
// A Provider is an Okta user group directory provider.
|
||||
type Provider struct {
|
||||
cfg *config
|
||||
log zerolog.Logger
|
||||
lastUpdated *time.Time
|
||||
groups map[string]*directory.Group
|
||||
}
|
||||
|
@ -102,13 +101,20 @@ type Provider struct {
|
|||
func New(options ...Option) *Provider {
|
||||
return &Provider{
|
||||
cfg: getConfig(options...),
|
||||
log: log.With().Str("service", "directory").Str("provider", "okta").Logger(),
|
||||
groups: make(map[string]*directory.Group),
|
||||
}
|
||||
}
|
||||
|
||||
func withLog(ctx context.Context) context.Context {
|
||||
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("service", "directory").Str("provider", "okta")
|
||||
})
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
ctx = withLog(ctx)
|
||||
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, ErrServiceAccountNotDefined
|
||||
}
|
||||
|
@ -139,11 +145,13 @@ func (p *Provider) User(ctx context.Context, userID, accessToken string) (*direc
|
|||
// UserGroups fetches the groups of which the user is a member
|
||||
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
ctx = withLog(ctx)
|
||||
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, nil, ErrServiceAccountNotDefined
|
||||
}
|
||||
|
||||
p.log.Info().Msg("getting user groups")
|
||||
log.Info(ctx).Msg("getting user groups")
|
||||
|
||||
if p.cfg.providerURL == nil {
|
||||
return nil, nil, ErrProviderURLNotDefined
|
||||
|
@ -164,7 +172,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
// the cached lookup and the local groups list
|
||||
var apiErr *APIError
|
||||
if errors.As(err, &apiErr) && apiErr.HTTPStatusCode == http.StatusNotFound {
|
||||
log.Debug().Str("group", group.Id).Msg("okta: group was removed")
|
||||
log.Debug(ctx).Str("group", group.Id).Msg("okta: group was removed")
|
||||
delete(p.groups, group.Id)
|
||||
groups = append(groups[:i], groups[i+1:]...)
|
||||
i--
|
||||
|
|
|
@ -78,7 +78,6 @@ func getConfig(options ...Option) *config {
|
|||
// The Provider retrieves users and groups from onelogin.
|
||||
type Provider struct {
|
||||
cfg *config
|
||||
log zerolog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
token *oauth2.Token
|
||||
|
@ -89,10 +88,15 @@ func New(options ...Option) *Provider {
|
|||
cfg := getConfig(options...)
|
||||
return &Provider{
|
||||
cfg: cfg,
|
||||
log: log.With().Str("service", "directory").Str("provider", "onelogin").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
func withLog(ctx context.Context) context.Context {
|
||||
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("service", "directory").Str("provider", "onelogin")
|
||||
})
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
|
@ -102,6 +106,8 @@ func (p *Provider) User(ctx context.Context, userID, accessToken string) (*direc
|
|||
Id: userID,
|
||||
}
|
||||
|
||||
ctx = withLog(ctx)
|
||||
|
||||
token, err := p.getToken(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -124,7 +130,9 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
return nil, nil, fmt.Errorf("onelogin: service account not defined")
|
||||
}
|
||||
|
||||
p.log.Info().Msg("getting user groups")
|
||||
ctx = withLog(ctx)
|
||||
|
||||
log.Info(ctx).Msg("getting user groups")
|
||||
|
||||
token, err := p.getToken(ctx)
|
||||
if err != nil {
|
||||
|
@ -252,7 +260,7 @@ func (p *Provider) apiGet(ctx context.Context, accessToken string, uri string, o
|
|||
return "", err
|
||||
}
|
||||
|
||||
p.log.Info().
|
||||
log.Info(ctx).
|
||||
Str("url", uri).
|
||||
Interface("result", result).
|
||||
Msg("api request")
|
||||
|
|
|
@ -50,8 +50,9 @@ func GetProvider(options Options) (provider Provider) {
|
|||
globalProvider.Lock()
|
||||
defer globalProvider.Unlock()
|
||||
|
||||
ctx := context.TODO()
|
||||
if globalProvider.provider != nil && cmp.Equal(globalProvider.options, options) {
|
||||
log.Debug().Str("provider", options.Provider).Msg("directory: no change detected, reusing existing directory provider")
|
||||
log.Debug(ctx).Str("provider", options.Provider).Msg("directory: no change detected, reusing existing directory provider")
|
||||
return globalProvider.provider
|
||||
}
|
||||
defer func() {
|
||||
|
@ -72,7 +73,7 @@ func GetProvider(options Options) (provider Provider) {
|
|||
auth0.WithDomain(options.ProviderURL),
|
||||
auth0.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
log.Warn().
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
|
@ -82,7 +83,7 @@ func GetProvider(options Options) (provider Provider) {
|
|||
if err == nil {
|
||||
return azure.New(azure.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
log.Warn().
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
|
@ -92,7 +93,7 @@ func GetProvider(options Options) (provider Provider) {
|
|||
if err == nil {
|
||||
return github.New(github.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
log.Warn().
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
|
@ -107,7 +108,7 @@ func GetProvider(options Options) (provider Provider) {
|
|||
gitlab.WithURL(providerURL),
|
||||
gitlab.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
log.Warn().
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
|
@ -117,7 +118,7 @@ func GetProvider(options Options) (provider Provider) {
|
|||
if err == nil {
|
||||
return google.New(google.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
log.Warn().
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
|
@ -129,7 +130,7 @@ func GetProvider(options Options) (provider Provider) {
|
|||
okta.WithProviderURL(providerURL),
|
||||
okta.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
log.Warn().
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
|
@ -139,7 +140,7 @@ func GetProvider(options Options) (provider Provider) {
|
|||
if err == nil {
|
||||
return onelogin.New(onelogin.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
log.Warn().
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
|
@ -151,14 +152,14 @@ func GetProvider(options Options) (provider Provider) {
|
|||
ping.WithProviderURL(providerURL),
|
||||
ping.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
log.Warn().
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
Msg("invalid service account for ping directory provider")
|
||||
}
|
||||
|
||||
log.Warn().
|
||||
log.Warn(ctx).
|
||||
Str("provider", options.Provider).
|
||||
Msg("no directory provider implementation found, disabling support for groups")
|
||||
return nullProvider{}
|
||||
|
|
|
@ -67,7 +67,7 @@ type Server struct {
|
|||
}
|
||||
|
||||
// NewServer creates a new server with traffic routed by envoy.
|
||||
func NewServer(src config.Source, grpcPort, httpPort string, builder *envoyconfig.Builder) (*Server, error) {
|
||||
func NewServer(ctx context.Context, src config.Source, grpcPort, httpPort string, builder *envoyconfig.Builder) (*Server, error) {
|
||||
wd := filepath.Join(os.TempDir(), workingDirectoryName)
|
||||
err := os.MkdirAll(wd, embeddedEnvoyPermissions)
|
||||
if err != nil {
|
||||
|
@ -76,7 +76,7 @@ func NewServer(src config.Source, grpcPort, httpPort string, builder *envoyconfi
|
|||
|
||||
envoyPath, err := extractEmbeddedEnvoy()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Send()
|
||||
log.Warn(ctx).Err(err).Send()
|
||||
envoyPath = "envoy"
|
||||
}
|
||||
|
||||
|
@ -98,7 +98,7 @@ func NewServer(src config.Source, grpcPort, httpPort string, builder *envoyconfi
|
|||
return nil, fmt.Errorf("invalid envoy binary, expected %s but got %s", Checksum, s)
|
||||
}
|
||||
} else {
|
||||
log.Info().Msg("no checksum defined, envoy binary will not be verified!")
|
||||
log.Info(ctx).Msg("no checksum defined, envoy binary will not be verified!")
|
||||
}
|
||||
|
||||
srv := &Server{
|
||||
|
@ -108,12 +108,12 @@ func NewServer(src config.Source, grpcPort, httpPort string, builder *envoyconfi
|
|||
httpPort: httpPort,
|
||||
envoyPath: envoyPath,
|
||||
}
|
||||
go srv.runProcessCollector()
|
||||
go srv.runProcessCollector(ctx)
|
||||
|
||||
src.OnConfigChange(srv.onConfigChange)
|
||||
srv.onConfigChange(src.GetConfig())
|
||||
src.OnConfigChange(ctx, srv.onConfigChange)
|
||||
srv.onConfigChange(ctx, src.GetConfig())
|
||||
|
||||
log.Info().
|
||||
log.Info(ctx).
|
||||
Str("path", envoyPath).
|
||||
Str("checksum", Checksum).
|
||||
Msg("running envoy")
|
||||
|
@ -130,7 +130,7 @@ func (srv *Server) Close() error {
|
|||
if srv.cmd != nil && srv.cmd.Process != nil {
|
||||
err = srv.cmd.Process.Kill()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("service", "envoy").Msg("envoy: failed to kill process on close")
|
||||
log.Error(context.TODO()).Err(err).Str("service", "envoy").Msg("envoy: failed to kill process on close")
|
||||
}
|
||||
srv.cmd = nil
|
||||
}
|
||||
|
@ -138,17 +138,17 @@ func (srv *Server) Close() error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (srv *Server) onConfigChange(cfg *config.Config) {
|
||||
srv.update(cfg)
|
||||
func (srv *Server) onConfigChange(ctx context.Context, cfg *config.Config) {
|
||||
srv.update(ctx, cfg)
|
||||
}
|
||||
|
||||
func (srv *Server) update(cfg *config.Config) {
|
||||
func (srv *Server) update(ctx context.Context, cfg *config.Config) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
tracingOptions, err := config.NewTracingOptions(cfg.Options)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("service", "envoy").Msg("invalid tracing config")
|
||||
log.Error(ctx).Err(err).Str("service", "envoy").Msg("invalid tracing config")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -159,24 +159,24 @@ func (srv *Server) update(cfg *config.Config) {
|
|||
}
|
||||
|
||||
if cmp.Equal(srv.options, options, cmp.AllowUnexported(serverOptions{})) {
|
||||
log.Debug().Str("service", "envoy").Msg("envoy: no config changes detected")
|
||||
log.Debug(ctx).Str("service", "envoy").Msg("envoy: no config changes detected")
|
||||
return
|
||||
}
|
||||
srv.options = options
|
||||
|
||||
if err := srv.writeConfig(cfg); err != nil {
|
||||
log.Error().Err(err).Str("service", "envoy").Msg("envoy: failed to write envoy config")
|
||||
if err := srv.writeConfig(ctx, cfg); err != nil {
|
||||
log.Error(ctx).Err(err).Str("service", "envoy").Msg("envoy: failed to write envoy config")
|
||||
return
|
||||
}
|
||||
|
||||
log.Info().Msg("envoy: starting envoy process")
|
||||
if err := srv.run(); err != nil {
|
||||
log.Error().Err(err).Str("service", "envoy").Msg("envoy: failed to run envoy process")
|
||||
log.Info(ctx).Msg("envoy: starting envoy process")
|
||||
if err := srv.run(ctx); err != nil {
|
||||
log.Error(ctx).Err(err).Str("service", "envoy").Msg("envoy: failed to run envoy process")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) run() error {
|
||||
func (srv *Server) run(ctx context.Context) error {
|
||||
args := []string{
|
||||
"-c", configFileName,
|
||||
"--log-level", srv.options.logLevel,
|
||||
|
@ -198,13 +198,13 @@ func (srv *Server) run() error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("error creating stderr pipe for envoy: %w", err)
|
||||
}
|
||||
go srv.handleLogs(stderr)
|
||||
go srv.handleLogs(ctx, stderr)
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating stderr pipe for envoy: %w", err)
|
||||
}
|
||||
go srv.handleLogs(stdout)
|
||||
go srv.handleLogs(ctx, stdout)
|
||||
|
||||
// make sure envoy is killed if we're killed
|
||||
cmd.SysProcAttr = sysProcAttr
|
||||
|
@ -216,10 +216,10 @@ func (srv *Server) run() error {
|
|||
|
||||
// release the previous process so we can hot-reload
|
||||
if srv.cmd != nil && srv.cmd.Process != nil {
|
||||
log.Info().Msg("envoy: releasing envoy process for hot-reload")
|
||||
log.Info(ctx).Msg("envoy: releasing envoy process for hot-reload")
|
||||
err := srv.cmd.Process.Release()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("service", "envoy").Msg("envoy: failed to release envoy process for hot-reload")
|
||||
log.Warn(ctx).Err(err).Str("service", "envoy").Msg("envoy: failed to release envoy process for hot-reload")
|
||||
}
|
||||
}
|
||||
srv.cmd = cmd
|
||||
|
@ -227,14 +227,14 @@ func (srv *Server) run() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (srv *Server) writeConfig(cfg *config.Config) error {
|
||||
func (srv *Server) writeConfig(ctx context.Context, cfg *config.Config) error {
|
||||
confBytes, err := srv.buildBootstrapConfig(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfgPath := filepath.Join(srv.wd, configFileName)
|
||||
log.Debug().Str("service", "envoy").Str("location", cfgPath).Msg("wrote config file to location")
|
||||
log.Debug(ctx).Str("service", "envoy").Str("location", cfgPath).Msg("wrote config file to location")
|
||||
|
||||
return atomic.WriteFile(cfgPath, bytes.NewReader(confBytes))
|
||||
}
|
||||
|
@ -313,9 +313,10 @@ func (srv *Server) parseLog(line string) (name string, logLevel string, msg stri
|
|||
return
|
||||
}
|
||||
|
||||
func (srv *Server) handleLogs(rc io.ReadCloser) {
|
||||
func (srv *Server) handleLogs(ctx context.Context, rc io.ReadCloser) {
|
||||
defer rc.Close()
|
||||
|
||||
l := log.With().Str("service", "envoy").Logger()
|
||||
bo := backoff.NewExponentialBackOff()
|
||||
|
||||
s := bufio.NewReader(rc)
|
||||
|
@ -325,7 +326,7 @@ func (srv *Server) handleLogs(rc io.ReadCloser) {
|
|||
if errors.Is(err, io.EOF) || errors.Is(err, os.ErrClosed) {
|
||||
break
|
||||
}
|
||||
log.Error().Err(err).Msg("failed to read log")
|
||||
log.Error(ctx).Err(err).Msg("failed to read log")
|
||||
time.Sleep(bo.NextBackOff())
|
||||
continue
|
||||
}
|
||||
|
@ -336,7 +337,8 @@ func (srv *Server) handleLogs(rc io.ReadCloser) {
|
|||
if name == "" {
|
||||
name = "envoy"
|
||||
}
|
||||
lvl := zerolog.DebugLevel
|
||||
|
||||
lvl := zerolog.ErrorLevel
|
||||
if x, err := zerolog.ParseLevel(logLevel); err == nil {
|
||||
lvl = x
|
||||
}
|
||||
|
@ -354,14 +356,13 @@ func (srv *Server) handleLogs(rc io.ReadCloser) {
|
|||
continue
|
||||
}
|
||||
|
||||
log.WithLevel(lvl).
|
||||
Str("service", "envoy").
|
||||
l.WithLevel(lvl).
|
||||
Str("name", name).
|
||||
Msg(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) runProcessCollector() {
|
||||
func (srv *Server) runProcessCollector(ctx context.Context) {
|
||||
// macos is not supported
|
||||
if runtime.GOOS != "linux" {
|
||||
return
|
||||
|
@ -369,7 +370,7 @@ func (srv *Server) runProcessCollector() {
|
|||
|
||||
pc := metrics.NewProcessCollector("envoy")
|
||||
if err := view.Register(pc.Views()...); err != nil {
|
||||
log.Error().Err(err).Msg("failed to register envoy process metric views")
|
||||
log.Error(ctx).Err(err).Msg("failed to register envoy process metric views")
|
||||
}
|
||||
|
||||
const collectInterval = time.Second * 10
|
||||
|
@ -387,7 +388,7 @@ func (srv *Server) runProcessCollector() {
|
|||
if pid > 0 {
|
||||
err := pc.Measure(context.Background(), pid)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to measure envoy process metrics")
|
||||
log.Error(ctx).Err(err).Msg("failed to measure envoy process metrics")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package envoy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
@ -43,6 +44,6 @@ func Benchmark_handleLogs(b *testing.B) {
|
|||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
srv.handleLogs(rc)
|
||||
srv.handleLogs(context.Background(), rc)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
package fileutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/rjeczalik/notify"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
|
@ -29,6 +31,10 @@ func (watcher *Watcher) Add(filePath string) {
|
|||
watcher.mu.Lock()
|
||||
defer watcher.mu.Unlock()
|
||||
|
||||
ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("watch_file", filePath)
|
||||
})
|
||||
|
||||
// already watching
|
||||
if _, ok := watcher.filePaths[filePath]; ok {
|
||||
return
|
||||
|
@ -37,18 +43,18 @@ func (watcher *Watcher) Add(filePath string) {
|
|||
ch := make(chan notify.EventInfo, 1)
|
||||
go func() {
|
||||
for evt := range ch {
|
||||
log.Info().Str("path", evt.Path()).Str("event", evt.Event().String()).Msg("filemgr: detected file change")
|
||||
watcher.Signal.Broadcast()
|
||||
log.Info(ctx).Str("event", evt.Event().String()).Msg("filemgr: detected file change")
|
||||
watcher.Signal.Broadcast(ctx)
|
||||
}
|
||||
}()
|
||||
err := notify.Watch(filePath, ch, notify.All)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("path", filePath).Msg("filemgr: error watching file path")
|
||||
log.Error(ctx).Err(err).Msg("filemgr: error watching file path")
|
||||
notify.Stop(ch)
|
||||
close(ch)
|
||||
return
|
||||
}
|
||||
log.Debug().Str("path", filePath).Msg("filemgr: watching file for changes")
|
||||
log.Debug(ctx).Msg("filemgr: watching file for changes")
|
||||
|
||||
watcher.filePaths[filePath] = ch
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
package reproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"math/rand"
|
||||
|
@ -120,7 +121,7 @@ func (h *Handler) Middleware(next http.Handler) http.Handler {
|
|||
}
|
||||
|
||||
// Update updates the handler with new configuration.
|
||||
func (h *Handler) Update(cfg *config.Config) {
|
||||
func (h *Handler) Update(ctx context.Context, cfg *config.Config) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
|
@ -130,7 +131,7 @@ func (h *Handler) Update(cfg *config.Config) {
|
|||
for i, p := range cfg.Options.Policies {
|
||||
id, err := p.RouteID()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("reproxy: error getting route id")
|
||||
log.Warn(ctx).Err(err).Msg("reproxy: error getting route id")
|
||||
continue
|
||||
}
|
||||
h.policies[id] = &cfg.Options.Policies[i]
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package reproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
@ -58,7 +59,7 @@ func TestMiddleware(t *testing.T) {
|
|||
}},
|
||||
},
|
||||
}
|
||||
h.Update(cfg)
|
||||
h.Update(context.Background(), cfg)
|
||||
|
||||
policyID, _ := cfg.Options.Policies[0].RouteID()
|
||||
|
||||
|
|
|
@ -97,8 +97,8 @@ func Shutdown(srv *http.Server) {
|
|||
rec := <-sigint
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
log.Info().Str("signal", rec.String()).Msg("internal/httputil: shutting down servers")
|
||||
log.Info(context.TODO()).Str("signal", rec.String()).Msg("internal/httputil: shutting down servers")
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("internal/httputil: shutdown failed")
|
||||
log.Error(context.TODO()).Err(err).Msg("internal/httputil: shutdown failed")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,7 +47,6 @@ type (
|
|||
// A Manager refreshes identity information using session and user data.
|
||||
type Manager struct {
|
||||
cfg *atomicConfig
|
||||
log zerolog.Logger
|
||||
|
||||
sessionScheduler *scheduler.Scheduler
|
||||
userScheduler *scheduler.Scheduler
|
||||
|
@ -68,7 +67,6 @@ func New(
|
|||
) *Manager {
|
||||
mgr := &Manager{
|
||||
cfg: newAtomicConfig(newConfig()),
|
||||
log: log.With().Str("service", "identity_manager").Logger(),
|
||||
|
||||
sessionScheduler: scheduler.New(),
|
||||
userScheduler: scheduler.New(),
|
||||
|
@ -79,6 +77,12 @@ func New(
|
|||
return mgr
|
||||
}
|
||||
|
||||
func withLog(ctx context.Context) context.Context {
|
||||
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("service", "identity_manager")
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateConfig updates the manager with the new options.
|
||||
func (mgr *Manager) UpdateConfig(options ...Option) {
|
||||
mgr.cfg.Store(newConfig(options...))
|
||||
|
@ -86,10 +90,11 @@ func (mgr *Manager) UpdateConfig(options ...Option) {
|
|||
|
||||
// Run runs the manager. This method blocks until an error occurs or the given context is canceled.
|
||||
func (mgr *Manager) Run(ctx context.Context) error {
|
||||
ctx = withLog(ctx)
|
||||
update := make(chan updateRecordsMessage, 1)
|
||||
clear := make(chan struct{}, 1)
|
||||
|
||||
syncer := newDataBrokerSyncer(mgr.cfg, mgr.log, update, clear)
|
||||
syncer := newDataBrokerSyncer(ctx, mgr.cfg, update, clear)
|
||||
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error {
|
||||
|
@ -119,7 +124,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
|
|||
mgr.onUpdateRecords(ctx, msg)
|
||||
}
|
||||
|
||||
mgr.log.Info().
|
||||
log.Info(ctx).
|
||||
Int("directory_groups", len(mgr.directoryGroups)).
|
||||
Int("directory_users", len(mgr.directoryUsers)).
|
||||
Int("sessions", mgr.sessions.Len()).
|
||||
|
@ -196,7 +201,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
|
|||
}
|
||||
|
||||
func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) {
|
||||
mgr.log.Info().Msg("refreshing directory users")
|
||||
log.Info(ctx).Msg("refreshing directory users")
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.Load().groupRefreshTimeout)
|
||||
defer clearTimeout()
|
||||
|
@ -208,7 +213,7 @@ func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) {
|
|||
msg += ". You may need to increase the identity provider directory timeout setting"
|
||||
msg += "(https://www.pomerium.io/reference/#identity-provider-refresh-directory-settings)"
|
||||
}
|
||||
mgr.log.Warn().Err(err).Msg(msg)
|
||||
log.Warn(ctx).Err(err).Msg(msg)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -232,7 +237,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
|
|||
id := newDG.GetId()
|
||||
any, err := anypb.New(newDG)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
|
||||
log.Warn(ctx).Err(err).Msg("failed to marshal directory group")
|
||||
return
|
||||
}
|
||||
eg.Go(func() error {
|
||||
|
@ -262,7 +267,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
|
|||
id := curDG.GetId()
|
||||
any, err := anypb.New(curDG)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
|
||||
log.Warn(ctx).Err(err).Msg("failed to marshal directory group")
|
||||
return
|
||||
}
|
||||
eg.Go(func() error {
|
||||
|
@ -287,7 +292,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
|
|||
}
|
||||
|
||||
if err := eg.Wait(); err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("manager: failed to merge groups")
|
||||
log.Warn(ctx).Err(err).Msg("manager: failed to merge groups")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -305,7 +310,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
|
|||
id := newDU.GetId()
|
||||
any, err := anypb.New(newDU)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
|
||||
log.Warn(ctx).Err(err).Msg("failed to marshal directory user")
|
||||
return
|
||||
}
|
||||
eg.Go(func() error {
|
||||
|
@ -335,7 +340,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
|
|||
id := curDU.GetId()
|
||||
any, err := anypb.New(curDU)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
|
||||
log.Warn(ctx).Err(err).Msg("failed to marshal directory user")
|
||||
return
|
||||
}
|
||||
eg.Go(func() error {
|
||||
|
@ -361,19 +366,19 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
|
|||
}
|
||||
|
||||
if err := eg.Wait(); err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("manager: failed to merge users")
|
||||
log.Warn(ctx).Err(err).Msg("manager: failed to merge users")
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) {
|
||||
mgr.log.Info().
|
||||
log.Info(ctx).
|
||||
Str("user_id", userID).
|
||||
Str("session_id", sessionID).
|
||||
Msg("refreshing session")
|
||||
|
||||
s, ok := mgr.sessions.Get(userID, sessionID)
|
||||
if !ok {
|
||||
mgr.log.Warn().
|
||||
log.Warn(ctx).
|
||||
Str("user_id", userID).
|
||||
Str("session_id", sessionID).
|
||||
Msg("no session found for refresh")
|
||||
|
@ -382,7 +387,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
|
||||
expiry := s.GetExpiresAt().AsTime()
|
||||
if !expiry.After(time.Now()) {
|
||||
mgr.log.Info().
|
||||
log.Info(ctx).
|
||||
Str("user_id", userID).
|
||||
Str("session_id", sessionID).
|
||||
Msg("deleting expired session")
|
||||
|
@ -391,7 +396,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
}
|
||||
|
||||
if s.Session == nil || s.Session.OauthToken == nil {
|
||||
mgr.log.Warn().
|
||||
log.Warn(ctx).
|
||||
Str("user_id", userID).
|
||||
Str("session_id", sessionID).
|
||||
Msg("no session oauth2 token found for refresh")
|
||||
|
@ -400,13 +405,13 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
|
||||
newToken, err := mgr.cfg.Load().authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s)
|
||||
if isTemporaryError(err) {
|
||||
mgr.log.Error().Err(err).
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to refresh oauth2 token")
|
||||
return
|
||||
} else if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to refresh oauth2 token, deleting session")
|
||||
|
@ -417,13 +422,13 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
|
||||
err = mgr.cfg.Load().authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &s)
|
||||
if isTemporaryError(err) {
|
||||
mgr.log.Error().Err(err).
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update user info")
|
||||
return
|
||||
} else if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update user info, deleting session")
|
||||
|
@ -433,7 +438,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
|
||||
res, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session)
|
||||
if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update session")
|
||||
|
@ -444,13 +449,13 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
}
|
||||
|
||||
func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
||||
mgr.log.Info().
|
||||
log.Info(ctx).
|
||||
Str("user_id", userID).
|
||||
Msg("refreshing user")
|
||||
|
||||
u, ok := mgr.users.Get(userID)
|
||||
if !ok {
|
||||
mgr.log.Warn().
|
||||
log.Warn(ctx).
|
||||
Str("user_id", userID).
|
||||
Msg("no user found for refresh")
|
||||
return
|
||||
|
@ -460,7 +465,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|||
|
||||
for _, s := range mgr.sessions.GetSessionsForUser(userID) {
|
||||
if s.Session == nil || s.Session.OauthToken == nil {
|
||||
mgr.log.Warn().
|
||||
log.Warn(ctx).
|
||||
Str("user_id", userID).
|
||||
Msg("no session oauth2 token found for refresh")
|
||||
continue
|
||||
|
@ -468,13 +473,13 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|||
|
||||
err := mgr.cfg.Load().authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u)
|
||||
if isTemporaryError(err) {
|
||||
mgr.log.Error().Err(err).
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update user info")
|
||||
return
|
||||
} else if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update user info, deleting session")
|
||||
|
@ -484,7 +489,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|||
|
||||
record, err := user.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
|
||||
if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update user")
|
||||
|
@ -502,7 +507,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
|
|||
var pbDirectoryGroup directory.Group
|
||||
err := record.GetData().UnmarshalTo(&pbDirectoryGroup)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Msgf("error unmarshaling directory group: %s", err)
|
||||
log.Warn(ctx).Msgf("error unmarshaling directory group: %s", err)
|
||||
continue
|
||||
}
|
||||
mgr.onUpdateDirectoryGroup(ctx, &pbDirectoryGroup)
|
||||
|
@ -510,7 +515,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
|
|||
var pbDirectoryUser directory.User
|
||||
err := record.GetData().UnmarshalTo(&pbDirectoryUser)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Msgf("error unmarshaling directory user: %s", err)
|
||||
log.Warn(ctx).Msgf("error unmarshaling directory user: %s", err)
|
||||
continue
|
||||
}
|
||||
mgr.onUpdateDirectoryUser(ctx, &pbDirectoryUser)
|
||||
|
@ -518,7 +523,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
|
|||
var pbSession session.Session
|
||||
err := record.GetData().UnmarshalTo(&pbSession)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Msgf("error unmarshaling session: %s", err)
|
||||
log.Warn(ctx).Msgf("error unmarshaling session: %s", err)
|
||||
continue
|
||||
}
|
||||
mgr.onUpdateSession(ctx, record, &pbSession)
|
||||
|
@ -526,7 +531,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
|
|||
var pbUser user.User
|
||||
err := record.GetData().UnmarshalTo(&pbUser)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Msgf("error unmarshaling user: %s", err)
|
||||
log.Warn(ctx).Msgf("error unmarshaling user: %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
@ -578,7 +583,7 @@ func (mgr *Manager) onUpdateDirectoryGroup(_ context.Context, pbDirectoryGroup *
|
|||
func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Session) {
|
||||
err := session.Delete(ctx, mgr.cfg.Load().dataBrokerClient, pbSession.GetId())
|
||||
if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
log.Error(ctx).Err(err).
|
||||
Str("session_id", pbSession.GetId()).
|
||||
Msg("failed to delete session")
|
||||
}
|
||||
|
|
|
@ -3,14 +3,11 @@ package manager
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
type dataBrokerSyncer struct {
|
||||
cfg *atomicConfig
|
||||
log zerolog.Logger
|
||||
|
||||
update chan<- updateRecordsMessage
|
||||
clear chan<- struct{}
|
||||
|
@ -19,14 +16,13 @@ type dataBrokerSyncer struct {
|
|||
}
|
||||
|
||||
func newDataBrokerSyncer(
|
||||
ctx context.Context,
|
||||
cfg *atomicConfig,
|
||||
log zerolog.Logger,
|
||||
update chan<- updateRecordsMessage,
|
||||
clear chan<- struct{},
|
||||
) *dataBrokerSyncer {
|
||||
syncer := &dataBrokerSyncer{
|
||||
cfg: cfg,
|
||||
log: log,
|
||||
|
||||
update: update,
|
||||
clear: clear,
|
||||
|
|
|
@ -142,7 +142,7 @@ func (p *Provider) userEmail(ctx context.Context, t *oauth2.Token, v interface{}
|
|||
Email string `json:"email"`
|
||||
Verified bool `json:"email_verified"`
|
||||
}
|
||||
log.Debug().Interface("emails", response).Msg("github: user emails")
|
||||
log.Debug(ctx).Interface("emails", response).Msg("github: user emails")
|
||||
for _, email := range response {
|
||||
if email.Primary && email.Verified {
|
||||
out.Email = email.Email
|
||||
|
|
|
@ -88,35 +88,55 @@ func With() zerolog.Context {
|
|||
}
|
||||
|
||||
// Level creates a child logger with the minimum accepted level set to level.
|
||||
func Level(level zerolog.Level) zerolog.Logger {
|
||||
return Logger().Level(level)
|
||||
func Level(ctx context.Context, level zerolog.Level) *zerolog.Logger {
|
||||
l := contextLogger(ctx).Level(level)
|
||||
return &l
|
||||
}
|
||||
|
||||
// Debug starts a new message with debug level.
|
||||
//
|
||||
// You must call Msg on the returned event in order to send the event.
|
||||
func Debug() *zerolog.Event {
|
||||
return Logger().Debug()
|
||||
func Debug(ctx context.Context) *zerolog.Event {
|
||||
return contextLogger(ctx).Debug()
|
||||
}
|
||||
|
||||
// Info starts a new message with info level.
|
||||
//
|
||||
// You must call Msg on the returned event in order to send the event.
|
||||
func Info() *zerolog.Event {
|
||||
return Logger().Info()
|
||||
func Info(ctx context.Context) *zerolog.Event {
|
||||
return contextLogger(ctx).Info()
|
||||
}
|
||||
|
||||
// Warn starts a new message with warn level.
|
||||
//
|
||||
// You must call Msg on the returned event in order to send the event.
|
||||
func Warn() *zerolog.Event {
|
||||
return Logger().Warn()
|
||||
func Warn(ctx context.Context) *zerolog.Event {
|
||||
return contextLogger(ctx).Warn()
|
||||
}
|
||||
|
||||
func contextLogger(ctx context.Context) *zerolog.Logger {
|
||||
global := Logger()
|
||||
if global.GetLevel() == zerolog.Disabled {
|
||||
return global
|
||||
}
|
||||
l := zerolog.Ctx(ctx)
|
||||
if l.GetLevel() == zerolog.Disabled { // no logger associated with context
|
||||
return global
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// WithContext returns a context that has an associated logger and extra fields set via update
|
||||
func WithContext(ctx context.Context, update func(c zerolog.Context) zerolog.Context) context.Context {
|
||||
l := contextLogger(ctx).With().Logger()
|
||||
l.UpdateContext(update)
|
||||
return l.WithContext(ctx)
|
||||
}
|
||||
|
||||
// Error starts a new message with error level.
|
||||
//
|
||||
// You must call Msg on the returned event in order to send the event.
|
||||
func Error() *zerolog.Event {
|
||||
func Error(ctx context.Context) *zerolog.Event {
|
||||
return Logger().Error()
|
||||
}
|
||||
|
||||
|
@ -136,18 +156,11 @@ func Panic() *zerolog.Event {
|
|||
return Logger().Panic()
|
||||
}
|
||||
|
||||
// WithLevel starts a new message with level.
|
||||
//
|
||||
// You must call Msg on the returned event in order to send the event.
|
||||
func WithLevel(level zerolog.Level) *zerolog.Event {
|
||||
return Logger().WithLevel(level)
|
||||
}
|
||||
|
||||
// Log starts a new message with no level. Setting zerolog.GlobalLevel to
|
||||
// zerolog.Disabled will still disable events produced by this method.
|
||||
//
|
||||
// You must call Msg on the returned event in order to send the event.
|
||||
func Log() *zerolog.Event {
|
||||
func Log(ctx context.Context) *zerolog.Event {
|
||||
return Logger().Log()
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package log_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"time"
|
||||
|
@ -54,7 +55,7 @@ func ExamplePrintf() {
|
|||
// Example of a log with no particular "level"
|
||||
func ExampleLog() {
|
||||
setup()
|
||||
log.Log().Msg("hello world")
|
||||
log.Log(context.Background()).Msg("hello world")
|
||||
|
||||
// Output: {"time":1199811905,"message":"hello world"}
|
||||
}
|
||||
|
@ -62,7 +63,7 @@ func ExampleLog() {
|
|||
// Example of a log at a particular "level" (in this case, "debug")
|
||||
func ExampleDebug() {
|
||||
setup()
|
||||
log.Debug().Msg("hello world")
|
||||
log.Debug(context.Background()).Msg("hello world")
|
||||
|
||||
// Output: {"level":"debug","time":1199811905,"message":"hello world"}
|
||||
}
|
||||
|
@ -70,7 +71,7 @@ func ExampleDebug() {
|
|||
// Example of a log at a particular "level" (in this case, "info")
|
||||
func ExampleInfo() {
|
||||
setup()
|
||||
log.Info().Msg("hello world")
|
||||
log.Info(context.Background()).Msg("hello world")
|
||||
|
||||
// Output: {"level":"info","time":1199811905,"message":"hello world"}
|
||||
}
|
||||
|
@ -78,7 +79,7 @@ func ExampleInfo() {
|
|||
// Example of a log at a particular "level" (in this case, "warn")
|
||||
func ExampleWarn() {
|
||||
setup()
|
||||
log.Warn().Msg("hello world")
|
||||
log.Warn(context.Background()).Msg("hello world")
|
||||
|
||||
// Output: {"level":"warn","time":1199811905,"message":"hello world"}
|
||||
}
|
||||
|
@ -86,7 +87,7 @@ func ExampleWarn() {
|
|||
// Example of a log at a particular "level" (in this case, "error")
|
||||
func ExampleError() {
|
||||
setup()
|
||||
log.Error().Msg("hello world")
|
||||
log.Error(context.Background()).Msg("hello world")
|
||||
|
||||
// Output: {"level":"error","time":1199811905,"message":"hello world"}
|
||||
}
|
||||
|
@ -119,10 +120,10 @@ func Example() {
|
|||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
}
|
||||
|
||||
log.Debug().Msg("This message appears only when log level set to Debug")
|
||||
log.Info().Msg("This message appears when log level set to Debug or Info")
|
||||
log.Debug(context.Background()).Msg("This message appears only when log level set to Debug")
|
||||
log.Info(context.Background()).Msg("This message appears when log level set to Debug or Info")
|
||||
|
||||
if e := log.Debug(); e.Enabled() {
|
||||
if e := log.Debug(context.Background()); e.Enabled() {
|
||||
// Compute log output only if enabled.
|
||||
value := "bar"
|
||||
e.Str("foo", value).Msg("some debug message")
|
||||
|
@ -134,19 +135,19 @@ func Example() {
|
|||
func ExampleSetLevel() {
|
||||
setup()
|
||||
log.SetLevel("info")
|
||||
log.Debug().Msg("Debug")
|
||||
log.Info().Msg("Debug or Info")
|
||||
log.Debug(context.Background()).Msg("Debug")
|
||||
log.Info(context.Background()).Msg("Debug or Info")
|
||||
log.SetLevel("warn")
|
||||
log.Debug().Msg("Debug")
|
||||
log.Info().Msg("Debug or Info")
|
||||
log.Warn().Msg("Debug or Info or Warn")
|
||||
log.Debug(context.Background()).Msg("Debug")
|
||||
log.Info(context.Background()).Msg("Debug or Info")
|
||||
log.Warn(context.Background()).Msg("Debug or Info or Warn")
|
||||
log.SetLevel("error")
|
||||
log.Debug().Msg("Debug")
|
||||
log.Info().Msg("Debug or Info")
|
||||
log.Warn().Msg("Debug or Info or Warn")
|
||||
log.Error().Msg("Debug or Info or Warn or Error")
|
||||
log.Debug(context.Background()).Msg("Debug")
|
||||
log.Info(context.Background()).Msg("Debug or Info")
|
||||
log.Warn(context.Background()).Msg("Debug or Info or Warn")
|
||||
log.Error(context.Background()).Msg("Debug or Info or Warn or Error")
|
||||
log.SetLevel("default-fall-through")
|
||||
log.Debug().Msg("Debug")
|
||||
log.Debug(context.Background()).Msg("Debug")
|
||||
|
||||
// Output:
|
||||
// {"level":"info","time":1199811905,"message":"Debug or Info"}
|
||||
|
@ -154,3 +155,32 @@ func ExampleSetLevel() {
|
|||
// {"level":"error","time":1199811905,"message":"Debug or Info or Warn or Error"}
|
||||
// {"level":"debug","time":1199811905,"message":"Debug"}
|
||||
}
|
||||
|
||||
func ExampleContext() {
|
||||
setup()
|
||||
|
||||
bg := context.Background()
|
||||
ctx1 := log.WithContext(bg, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("param_one", "one")
|
||||
})
|
||||
ctx2 := log.WithContext(ctx1, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("param_two", "two")
|
||||
})
|
||||
|
||||
log.Warn(bg).Str("non_context_param", "value").Msg("background")
|
||||
log.Warn(ctx1).Str("non_context_param", "value").Msg("first")
|
||||
log.Warn(ctx2).Str("non_context_param", "value").Msg("second")
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
ctx1 = log.WithContext(ctx1, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Int("counter", i)
|
||||
})
|
||||
}
|
||||
log.Info(ctx1).Str("non_ctx_param", "value").Msg("after counter")
|
||||
|
||||
/*
|
||||
{"level":"warn","ctx":"one","param":"first","time":1199811905,"message":"first"}
|
||||
{"level":"warn","ctx":"two","param":"second","time":1199811905,"message":"second"}
|
||||
{"level":"warn","param":"third","time":1199811905,"message":"third"}
|
||||
*/
|
||||
}
|
||||
|
|
|
@ -51,7 +51,7 @@ func (s *inMemoryServer) periodicCheck(ctx context.Context) {
|
|||
return
|
||||
case <-time.After(after):
|
||||
if s.lockAndRmExpired() {
|
||||
s.onchange.Broadcast()
|
||||
s.onchange.Broadcast(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ func (s *inMemoryServer) Report(ctx context.Context, req *pb.RegisterRequest) (*
|
|||
}
|
||||
|
||||
if updated {
|
||||
s.onchange.Broadcast()
|
||||
s.onchange.Broadcast(ctx)
|
||||
}
|
||||
|
||||
return &pb.RegisterResponse{
|
||||
|
@ -171,7 +171,7 @@ func (s *inMemoryServer) Watch(req *pb.ListRequest, srv pb.Registry_WatchServer)
|
|||
}
|
||||
}
|
||||
|
||||
func (s *inMemoryServer) getServiceUpdates(ctx context.Context, kinds map[pb.ServiceKind]bool, updates chan struct{}) ([]*pb.Service, error) {
|
||||
func (s *inMemoryServer) getServiceUpdates(ctx context.Context, kinds map[pb.ServiceKind]bool, updates chan context.Context) ([]*pb.Service, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
|
|
|
@ -23,25 +23,25 @@ type Reporter struct {
|
|||
}
|
||||
|
||||
// OnConfigChange applies configuration changes to the reporter
|
||||
func (r *Reporter) OnConfigChange(cfg *config.Config) {
|
||||
func (r *Reporter) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||
if r.cancel != nil {
|
||||
r.cancel()
|
||||
}
|
||||
|
||||
services, err := getReportedServices(cfg)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("metrics announce to service registry is disabled")
|
||||
log.Warn(ctx).Err(err).Msg("metrics announce to service registry is disabled")
|
||||
}
|
||||
|
||||
sharedKey, err := base64.StdEncoding.DecodeString(cfg.Options.SharedKey)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("decoding shared key")
|
||||
log.Error(ctx).Err(err).Msg("decoding shared key")
|
||||
return
|
||||
}
|
||||
|
||||
urls, err := cfg.Options.GetDataBrokerURLs()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("invalid databroker urls")
|
||||
log.Error(ctx).Err(err).Msg("invalid databroker urls")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -58,7 +58,7 @@ func (r *Reporter) OnConfigChange(cfg *config.Config) {
|
|||
SignedJWTKey: sharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("connecting to registry")
|
||||
log.Error(ctx).Err(err).Msg("connecting to registry")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -145,7 +145,7 @@ func runReporter(
|
|||
after = resp.CallBackAfter.AsDuration()
|
||||
backoff.Reset()
|
||||
case <-ctx.Done():
|
||||
log.Info().Msg("service registry reporter stopping")
|
||||
log.Info(ctx).Msg("service registry reporter stopping")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,28 +2,29 @@
|
|||
package signal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// A Signal is used to let multiple listeners know when something happened.
|
||||
type Signal struct {
|
||||
mu sync.Mutex
|
||||
chs map[chan struct{}]struct{}
|
||||
chs map[chan context.Context]struct{}
|
||||
}
|
||||
|
||||
// New creates a new Signal.
|
||||
func New() *Signal {
|
||||
return &Signal{
|
||||
chs: make(map[chan struct{}]struct{}),
|
||||
chs: make(map[chan context.Context]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast signals all the listeners. Broadcast never blocks.
|
||||
func (s *Signal) Broadcast() {
|
||||
func (s *Signal) Broadcast(ctx context.Context) {
|
||||
s.mu.Lock()
|
||||
for ch := range s.chs {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
case ch <- ctx:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
@ -32,8 +33,8 @@ func (s *Signal) Broadcast() {
|
|||
|
||||
// Bind creates a new listening channel bound to the signal. The channel used has a size of 1
|
||||
// and any given broadcast will signal at least one event, but may signal more than one.
|
||||
func (s *Signal) Bind() chan struct{} {
|
||||
ch := make(chan struct{}, 1)
|
||||
func (s *Signal) Bind() chan context.Context {
|
||||
ch := make(chan context.Context, 1)
|
||||
s.mu.Lock()
|
||||
s.chs[ch] = struct{}{}
|
||||
s.mu.Unlock()
|
||||
|
@ -41,7 +42,7 @@ func (s *Signal) Bind() chan struct{} {
|
|||
}
|
||||
|
||||
// Unbind stops the listening channel bound to the signal.
|
||||
func (s *Signal) Unbind(ch chan struct{}) {
|
||||
func (s *Signal) Unbind(ch chan context.Context) {
|
||||
s.mu.Lock()
|
||||
delete(s.chs, ch)
|
||||
s.mu.Unlock()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package tcptunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cliutil"
|
||||
|
@ -19,7 +20,7 @@ func getConfig(options ...Option) *config {
|
|||
if jwtCache, err := cliutil.NewLocalJWTCache(); err == nil {
|
||||
WithJWTCache(jwtCache)(cfg)
|
||||
} else {
|
||||
log.Error().Err(err).Msg("tcptunnel: error creating local JWT cache, using in-memory JWT cache")
|
||||
log.Error(context.TODO()).Err(err).Msg("tcptunnel: error creating local JWT cache, using in-memory JWT cache")
|
||||
WithJWTCache(cliutil.NewMemoryJWTCache())(cfg)
|
||||
}
|
||||
for _, o := range options {
|
||||
|
|
|
@ -43,7 +43,7 @@ func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) erro
|
|||
return err
|
||||
}
|
||||
defer func() { _ = li.Close() }()
|
||||
log.Info().Msg("tcptunnel: listening on " + li.Addr().String())
|
||||
log.Info(ctx).Msg("tcptunnel: listening on " + li.Addr().String())
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
@ -62,7 +62,7 @@ func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) erro
|
|||
}
|
||||
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
|
||||
log.Warn().Err(err).Msg("tcptunnel: temporarily failed to accept local connection")
|
||||
log.Warn(ctx).Err(err).Msg("tcptunnel: temporarily failed to accept local connection")
|
||||
select {
|
||||
case <-time.After(bo.NextBackOff()):
|
||||
case <-ctx.Done():
|
||||
|
@ -79,7 +79,7 @@ func (tun *Tunnel) RunListener(ctx context.Context, listenerAddress string) erro
|
|||
|
||||
err := tun.Run(ctx, conn)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("tcptunnel: error serving local connection")
|
||||
log.Error(ctx).Err(err).Msg("tcptunnel: error serving local connection")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
@ -102,7 +102,7 @@ func (tun *Tunnel) Run(ctx context.Context, local io.ReadWriter) error {
|
|||
}
|
||||
|
||||
func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string, retryCount int) error {
|
||||
log.Info().
|
||||
log.Info(ctx).
|
||||
Str("dst", tun.cfg.dstHost).
|
||||
Str("proxy", tun.cfg.proxyHost).
|
||||
Bool("secure", tun.cfg.tlsConfig != nil).
|
||||
|
@ -132,7 +132,7 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string,
|
|||
}
|
||||
defer func() {
|
||||
_ = remote.Close()
|
||||
log.Info().Msg("tcptunnel: connection closed")
|
||||
log.Info(ctx).Msg("tcptunnel: connection closed")
|
||||
}()
|
||||
if done := ctx.Done(); done != nil {
|
||||
go func() {
|
||||
|
@ -189,7 +189,7 @@ func (tun *Tunnel) run(ctx context.Context, local io.ReadWriter, rawJWT string,
|
|||
return fmt.Errorf("tcptunnel: invalid http response code: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
log.Info().Msg("tcptunnel: connection established")
|
||||
log.Info(ctx).Msg("tcptunnel: connection established")
|
||||
|
||||
errc := make(chan error, 2)
|
||||
go func() {
|
||||
|
|
|
@ -142,7 +142,7 @@ func GRPCClientInterceptor(service string) grpc.UnaryClientInterceptor {
|
|||
tag.Upsert(TagKeyGRPCService, rpcService),
|
||||
)
|
||||
if tagErr != nil {
|
||||
log.Warn().Err(tagErr).Str("context", "GRPCClientInterceptor").Msg("telemetry/metrics: failed to create context")
|
||||
log.Warn(ctx).Err(tagErr).Str("context", "GRPCClientInterceptor").Msg("telemetry/metrics: failed to create context")
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}
|
||||
|
||||
|
@ -181,7 +181,7 @@ func (h *GRPCServerMetricsHandler) TagRPC(ctx context.Context, tagInfo *grpcstat
|
|||
tag.Upsert(TagKeyGRPCService, rpcService),
|
||||
)
|
||||
if tagErr != nil {
|
||||
log.Warn().Err(tagErr).Str("context", "GRPCServerStatsHandler").Msg("telemetry/metrics: failed to create context")
|
||||
log.Warn(ctx).Err(tagErr).Str("context", "GRPCServerStatsHandler").Msg("telemetry/metrics: failed to create context")
|
||||
return ctx
|
||||
|
||||
}
|
||||
|
|
|
@ -120,7 +120,7 @@ func HTTPMetricsHandler(getInstallationID func() string, service string) func(ne
|
|||
tag.Upsert(TagKeyHTTPMethod, r.Method),
|
||||
)
|
||||
if tagErr != nil {
|
||||
log.Warn().Err(tagErr).Str("context", "HTTPMetricsHandler").Msg("telemetry/metrics: failed to create metrics tag")
|
||||
log.Warn(ctx).Err(tagErr).Str("context", "HTTPMetricsHandler").Msg("telemetry/metrics: failed to create metrics tag")
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
@ -148,7 +148,7 @@ func HTTPMetricsRoundTripper(getInstallationID func() string, service string, de
|
|||
tag.Upsert(TagKeyDestination, destination),
|
||||
)
|
||||
if tagErr != nil {
|
||||
log.Warn().Err(tagErr).Str("context", "HTTPMetricsRoundTripper").Msg("telemetry/metrics: failed to create metrics tag")
|
||||
log.Warn(ctx).Err(tagErr).Str("context", "HTTPMetricsRoundTripper").Msg("telemetry/metrics: failed to create metrics tag")
|
||||
return next.RoundTrip(r)
|
||||
}
|
||||
|
||||
|
|
|
@ -104,8 +104,8 @@ func RecordIdentityManagerLastRefresh() {
|
|||
|
||||
// SetDBConfigInfo records status, databroker version and error count while parsing
|
||||
// the configuration from a databroker
|
||||
func SetDBConfigInfo(service, configID string, version uint64, errCount int64) {
|
||||
log.Info().
|
||||
func SetDBConfigInfo(ctx context.Context, service, configID string, version uint64, errCount int64) {
|
||||
log.Info(ctx).
|
||||
Str("service", service).
|
||||
Str("config_id", configID).
|
||||
Uint64("version", version).
|
||||
|
@ -113,14 +113,14 @@ func SetDBConfigInfo(service, configID string, version uint64, errCount int64) {
|
|||
Msg("set db config info")
|
||||
|
||||
if err := stats.RecordWithTags(
|
||||
context.Background(),
|
||||
ctx,
|
||||
[]tag.Mutator{
|
||||
tag.Insert(TagKeyService, service),
|
||||
tag.Insert(TagConfigID, configID),
|
||||
},
|
||||
configDBVersion.M(int64(version)),
|
||||
); err != nil {
|
||||
log.Error().Err(err).Msg("telemetry/metrics: failed to record config version number")
|
||||
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to record config version number")
|
||||
}
|
||||
|
||||
if err := stats.RecordWithTags(
|
||||
|
@ -131,20 +131,20 @@ func SetDBConfigInfo(service, configID string, version uint64, errCount int64) {
|
|||
},
|
||||
configDBErrors.M(errCount),
|
||||
); err != nil {
|
||||
log.Error().Err(err).Msg("telemetry/metrics: failed to record config error count")
|
||||
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to record config error count")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// SetDBConfigRejected records that a certain databroker config version has been rejected
|
||||
func SetDBConfigRejected(service, configID string, version uint64, err error) {
|
||||
log.Warn().Err(err).Msg("databroker: invalid config detected, ignoring")
|
||||
SetDBConfigInfo(service, configID, version, -1)
|
||||
func SetDBConfigRejected(ctx context.Context, service, configID string, version uint64, err error) {
|
||||
log.Warn(ctx).Err(err).Msg("databroker: invalid config detected, ignoring")
|
||||
SetDBConfigInfo(ctx, service, configID, version, -1)
|
||||
}
|
||||
|
||||
// SetConfigInfo records the status, checksum and timestamp of a configuration
|
||||
// reload. You must register InfoViews or the related config views before calling
|
||||
func SetConfigInfo(service, configName string, checksum uint64, success bool) {
|
||||
func SetConfigInfo(ctx context.Context, service, configName string, checksum uint64, success bool) {
|
||||
if success {
|
||||
registry.setConfigChecksum(service, configName, checksum)
|
||||
|
||||
|
@ -154,7 +154,7 @@ func SetConfigInfo(service, configName string, checksum uint64, success bool) {
|
|||
[]tag.Mutator{serviceTag},
|
||||
configLastReload.M(time.Now().Unix()),
|
||||
); err != nil {
|
||||
log.Error().Err(err).Msg("telemetry/metrics: failed to record config checksum timestamp")
|
||||
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to record config checksum timestamp")
|
||||
}
|
||||
|
||||
if err := stats.RecordWithTags(
|
||||
|
@ -162,12 +162,12 @@ func SetConfigInfo(service, configName string, checksum uint64, success bool) {
|
|||
[]tag.Mutator{serviceTag},
|
||||
configLastReloadSuccess.M(1),
|
||||
); err != nil {
|
||||
log.Error().Err(err).Msg("telemetry/metrics: failed to record config reload")
|
||||
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to record config reload")
|
||||
}
|
||||
} else {
|
||||
stats.Record(context.Background(), configLastReloadSuccess.M(0))
|
||||
}
|
||||
log.Info().
|
||||
log.Info(ctx).
|
||||
Str("service", service).
|
||||
Str("config", configName).
|
||||
Str("checksum", fmt.Sprintf("%x", checksum)).
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
@ -28,7 +29,7 @@ func Test_SetConfigInfo(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
view.Unregister(InfoViews...)
|
||||
view.Register(InfoViews...)
|
||||
SetConfigInfo("test_service", "test config", 0, tt.success)
|
||||
SetConfigInfo(context.Background(), "test_service", "test config", 0, tt.success)
|
||||
|
||||
testDataRetrieval(ConfigLastReloadView, t, tt.wantLastReload)
|
||||
testDataRetrieval(ConfigLastReloadSuccessView, t, tt.wantLastReloadSuccess)
|
||||
|
@ -55,7 +56,7 @@ func Test_SetDBConfigInfo(t *testing.T) {
|
|||
t.Run(fmt.Sprintf("version=%d errors=%d", tt.version, tt.errCount), func(t *testing.T) {
|
||||
view.Unregister(InfoViews...)
|
||||
view.Register(InfoViews...)
|
||||
SetDBConfigInfo("test_service", "test_config", tt.version, tt.errCount)
|
||||
SetDBConfigInfo(context.TODO(), "test_service", "test_config", tt.version, tt.errCount)
|
||||
|
||||
testDataRetrieval(ConfigDBVersionView, t, tt.wantVersion)
|
||||
testDataRetrieval(ConfigDBErrorsView, t, tt.wantErrors)
|
||||
|
|
|
@ -89,26 +89,26 @@ func newProxyMetricsHandler(exporter *ocprom.Exporter, envoyURL url.URL, install
|
|||
|
||||
err := writeMetricsWithInstallationID(w, rec.Body, installationID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Send()
|
||||
log.Error(r.Context()).Err(err).Send()
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(r.Context(), "GET", envoyURL.String(), nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("telemetry/metrics: failed to create request for envoy")
|
||||
log.Error(r.Context()).Err(err).Msg("telemetry/metrics: failed to create request for envoy")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("telemetry/metrics: fail to fetch proxy metrics")
|
||||
log.Error(r.Context()).Err(err).Msg("telemetry/metrics: fail to fetch proxy metrics")
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
err = writeMetricsWithInstallationID(w, resp.Body, installationID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Send()
|
||||
log.Error(r.Context()).Err(err).Send()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
|
@ -34,6 +35,7 @@ func newMetricRegistry() *metricRegistry {
|
|||
}
|
||||
|
||||
func (r *metricRegistry) init() {
|
||||
ctx := context.TODO()
|
||||
r.Do(
|
||||
func() {
|
||||
r.registry = metric.NewRegistry()
|
||||
|
@ -49,7 +51,7 @@ func (r *metricRegistry) init() {
|
|||
),
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("telemetry/metrics: failed to register build info metric")
|
||||
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to register build info metric")
|
||||
}
|
||||
|
||||
r.configChecksum, err = r.registry.AddFloat64Gauge(metrics.ConfigChecksumDecimal,
|
||||
|
@ -57,7 +59,7 @@ func (r *metricRegistry) init() {
|
|||
metric.WithLabelKeys(metrics.ServiceLabel, metrics.ConfigLabel),
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("telemetry/metrics: failed to register config checksum metric")
|
||||
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to register config checksum metric")
|
||||
}
|
||||
|
||||
r.policyCount, err = r.registry.AddInt64DerivedGauge(metrics.PolicyCountTotal,
|
||||
|
@ -65,12 +67,12 @@ func (r *metricRegistry) init() {
|
|||
metric.WithLabelKeys(metrics.ServiceLabel),
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("telemetry/metrics: failed to register policy count metric")
|
||||
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to register policy count metric")
|
||||
}
|
||||
|
||||
err = registerAutocertMetrics(r.registry)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("telemetry/metrics: failed to register autocert metrics")
|
||||
log.Error(ctx).Err(err).Msg("telemetry/metrics: failed to register autocert metrics")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -89,7 +91,7 @@ func (r *metricRegistry) setBuildInfo(service, hostname string) {
|
|||
metricdata.NewLabelValue(hostname),
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("telemetry/metrics: failed to get build info metric")
|
||||
log.Error(context.TODO()).Err(err).Msg("telemetry/metrics: failed to get build info metric")
|
||||
}
|
||||
|
||||
// This sets our build_info metric to a constant 1 per
|
||||
|
@ -103,7 +105,7 @@ func (r *metricRegistry) addPolicyCountCallback(service string, f func() int64)
|
|||
}
|
||||
err := r.policyCount.UpsertEntry(f, metricdata.NewLabelValue(service))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("telemetry/metrics: failed to get policy count metric")
|
||||
log.Error(context.TODO()).Err(err).Msg("telemetry/metrics: failed to get policy count metric")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -113,7 +115,7 @@ func (r *metricRegistry) setConfigChecksum(service string, configName string, ch
|
|||
}
|
||||
m, err := r.configChecksum.GetEntry(metricdata.NewLabelValue(service), metricdata.NewLabelValue(configName))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("telemetry/metrics: failed to get config checksum metric")
|
||||
log.Error(context.TODO()).Err(err).Msg("telemetry/metrics: failed to get config checksum metric")
|
||||
}
|
||||
m.Set(float64(checksum))
|
||||
}
|
||||
|
@ -122,13 +124,13 @@ func (r *metricRegistry) addInt64DerivedGaugeMetric(name, desc, service string,
|
|||
m, err := r.registry.AddInt64DerivedGauge(name, metric.WithDescription(desc),
|
||||
metric.WithLabelKeys(metrics.ServiceLabel))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("service", service).Msg("telemetry/metrics: failed to register metric")
|
||||
log.Error(context.TODO()).Err(err).Str("service", service).Msg("telemetry/metrics: failed to register metric")
|
||||
return
|
||||
}
|
||||
|
||||
err = m.UpsertEntry(f, metricdata.NewLabelValue(service))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("service", service).Msg("telemetry/metrics: failed to update metric")
|
||||
log.Error(context.TODO()).Err(err).Str("service", service).Msg("telemetry/metrics: failed to update metric")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -137,13 +139,13 @@ func (r *metricRegistry) addInt64DerivedCumulativeMetric(name, desc, service str
|
|||
m, err := r.registry.AddInt64DerivedCumulative(name, metric.WithDescription(desc),
|
||||
metric.WithLabelKeys(metrics.ServiceLabel))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("service", service).Msg("telemetry/metrics: failed to register metric")
|
||||
log.Error(context.TODO()).Err(err).Str("service", service).Msg("telemetry/metrics: failed to register metric")
|
||||
return
|
||||
}
|
||||
|
||||
err = m.UpsertEntry(f, metricdata.NewLabelValue(service))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("service", service).Msg("telemetry/metrics: failed to update metric")
|
||||
log.Error(context.TODO()).Err(err).Str("service", service).Msg("telemetry/metrics: failed to update metric")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -57,6 +57,6 @@ func RecordStorageOperation(ctx context.Context, tags *StorageOperationTags, dur
|
|||
storageOperationDuration.M(duration.Milliseconds()),
|
||||
)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("internal/telemetry/metrics: failed to record")
|
||||
log.Warn(ctx).Err(err).Msg("internal/telemetry/metrics: failed to record")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -77,7 +77,7 @@ func RegisterTracing(opts *TracingOptions) (trace.Exporter, error) {
|
|||
}
|
||||
trace.ApplyConfig(trace.Config{DefaultSampler: trace.ProbabilitySampler(opts.SampleRate)})
|
||||
|
||||
log.Debug().Interface("Opts", opts).Msg("telemetry/trace: exporter created")
|
||||
log.Debug(context.TODO()).Interface("Opts", opts).Msg("telemetry/trace: exporter created")
|
||||
return exporter, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package cryptutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
|
@ -14,9 +15,10 @@ import (
|
|||
|
||||
// GetCertPool gets a cert pool for the given CA or CAFile.
|
||||
func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
|
||||
ctx := context.TODO()
|
||||
rootCAs, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("pkg/cryptutil: failed getting system cert pool making new one")
|
||||
log.Error(ctx).Err(err).Msg("pkg/cryptutil: failed getting system cert pool making new one")
|
||||
rootCAs = x509.NewCertPool()
|
||||
}
|
||||
if ca == "" && caFile == "" {
|
||||
|
@ -38,7 +40,7 @@ func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
|
|||
if ok := rootCAs.AppendCertsFromPEM(data); !ok {
|
||||
return nil, fmt.Errorf("failed to append any PEM-encoded certificates")
|
||||
}
|
||||
log.Debug().Msg("pkg/cryptutil: added custom certificate authority")
|
||||
log.Debug(ctx).Msg("pkg/cryptutil: added custom certificate authority")
|
||||
return rootCAs, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -60,6 +60,7 @@ type Options struct {
|
|||
|
||||
// NewGRPCClientConn returns a new gRPC pomerium service client connection.
|
||||
func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
|
||||
ctx := context.TODO()
|
||||
if len(opts.Addrs) == 0 {
|
||||
return nil, errors.New("internal/grpc: connection address required")
|
||||
}
|
||||
|
@ -105,7 +106,7 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
|
|||
}
|
||||
|
||||
if opts.WithInsecure {
|
||||
log.Info().Str("addr", connAddr).Msg("internal/grpc: grpc with insecure")
|
||||
log.Info(ctx).Str("addr", connAddr).Msg("internal/grpc: grpc with insecure")
|
||||
dialOptions = append(dialOptions, grpc.WithInsecure())
|
||||
} else {
|
||||
rootCAs, err := cryptutil.GetCertPool(opts.CA, opts.CAFile)
|
||||
|
@ -117,7 +118,7 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
|
|||
|
||||
// override allowed certificate name string, typically used when doing behind ingress connection
|
||||
if opts.OverrideCertificateName != "" {
|
||||
log.Debug().Str("cert-override-name", opts.OverrideCertificateName).Msg("internal/grpc: grpc")
|
||||
log.Debug(ctx).Str("cert-override-name", opts.OverrideCertificateName).Msg("internal/grpc: grpc")
|
||||
err := cert.OverrideServerName(opts.OverrideCertificateName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -169,7 +170,7 @@ func GetGRPCClientConn(name string, opts *Options) (*grpc.ClientConn, error) {
|
|||
|
||||
err := current.conn.Close()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("grpc: failed to close existing connection")
|
||||
log.Error(context.TODO()).Err(err).Msg("grpc: failed to close existing connection")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -101,7 +101,7 @@ func (syncer *Syncer) Run(ctx context.Context) error {
|
|||
}
|
||||
|
||||
if err != nil {
|
||||
syncer.log().Error().Err(err).Msg("sync")
|
||||
log.Error(syncer.logCtx(ctx)).Err(err).Msg("sync")
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
|
@ -112,42 +112,42 @@ func (syncer *Syncer) Run(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (syncer *Syncer) init(ctx context.Context) error {
|
||||
syncer.log().Info().Msg("initial sync")
|
||||
records, recordVersion, serverVersion, err := InitialSync(ctx, syncer.handler.GetDataBrokerServiceClient(), &SyncLatestRequest{
|
||||
log.Info(syncer.logCtx(ctx)).Msg("initial sync")
|
||||
records, recordVersion, serverVersion, err := InitialSync(syncer.logCtx(ctx), syncer.handler.GetDataBrokerServiceClient(), &SyncLatestRequest{
|
||||
Type: syncer.cfg.typeURL,
|
||||
})
|
||||
if err != nil {
|
||||
syncer.log().Error().Err(err).Msg("error during initial sync")
|
||||
log.Error(syncer.logCtx(ctx)).Err(err).Msg("error during initial sync")
|
||||
return err
|
||||
}
|
||||
syncer.backoff.Reset()
|
||||
|
||||
// reset the records as we have to sync latest
|
||||
syncer.handler.ClearRecords(ctx)
|
||||
syncer.handler.ClearRecords(syncer.logCtx(ctx))
|
||||
|
||||
syncer.recordVersion = recordVersion
|
||||
syncer.serverVersion = serverVersion
|
||||
syncer.handler.UpdateRecords(ctx, serverVersion, records)
|
||||
syncer.handler.UpdateRecords(syncer.logCtx(ctx), serverVersion, records)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (syncer *Syncer) sync(ctx context.Context) error {
|
||||
stream, err := syncer.handler.GetDataBrokerServiceClient().Sync(ctx, &SyncRequest{
|
||||
stream, err := syncer.handler.GetDataBrokerServiceClient().Sync(syncer.logCtx(ctx), &SyncRequest{
|
||||
ServerVersion: syncer.serverVersion,
|
||||
RecordVersion: syncer.recordVersion,
|
||||
})
|
||||
if err != nil {
|
||||
syncer.log().Error().Err(err).Msg("error during sync")
|
||||
log.Error(syncer.logCtx(ctx)).Err(err).Msg("error during sync")
|
||||
return err
|
||||
}
|
||||
|
||||
syncer.log().Info().Msg("listening for updates")
|
||||
log.Info(syncer.logCtx(ctx)).Msg("listening for updates")
|
||||
|
||||
for {
|
||||
res, err := stream.Recv()
|
||||
if status.Code(err) == codes.Aborted {
|
||||
syncer.log().Error().Err(err).Msg("aborted sync due to mismatched server version")
|
||||
log.Error(syncer.logCtx(ctx)).Err(err).Msg("aborted sync due to mismatched server version")
|
||||
// server version changed, so re-init
|
||||
syncer.serverVersion = 0
|
||||
return nil
|
||||
|
@ -155,14 +155,13 @@ func (syncer *Syncer) sync(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
syncer.log().Debug().
|
||||
log.Debug(syncer.logCtx(ctx)).
|
||||
Uint("version", uint(res.Record.GetVersion())).
|
||||
Str("type", res.Record.Type).
|
||||
Str("id", res.Record.Id).
|
||||
Msg("syncer got record")
|
||||
|
||||
if syncer.recordVersion != res.GetRecord().GetVersion()-1 {
|
||||
syncer.log().Error().Err(err).
|
||||
log.Error(syncer.logCtx(ctx)).Err(err).
|
||||
Uint64("received", res.GetRecord().GetVersion()).
|
||||
Msg("aborted sync due to missing record")
|
||||
syncer.serverVersion = 0
|
||||
|
@ -170,15 +169,17 @@ func (syncer *Syncer) sync(ctx context.Context) error {
|
|||
}
|
||||
syncer.recordVersion = res.GetRecord().GetVersion()
|
||||
if syncer.cfg.typeURL == "" || syncer.cfg.typeURL == res.GetRecord().GetType() {
|
||||
syncer.handler.UpdateRecords(ctx, syncer.serverVersion, []*Record{res.GetRecord()})
|
||||
syncer.handler.UpdateRecords(syncer.logCtx(ctx), syncer.serverVersion, []*Record{res.GetRecord()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (syncer *Syncer) log() *zerolog.Logger {
|
||||
l := log.With().Str("syncer_id", syncer.id).
|
||||
Str("type", syncer.cfg.typeURL).
|
||||
Uint64("server_version", syncer.serverVersion).
|
||||
Uint64("record_version", syncer.recordVersion).Logger()
|
||||
return &l
|
||||
// logCtx adds log params to context which
|
||||
func (syncer *Syncer) logCtx(ctx context.Context) context.Context {
|
||||
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("syncer_id", syncer.id).
|
||||
Str("type", syncer.cfg.typeURL).
|
||||
Uint64("server_version", syncer.serverVersion).
|
||||
Uint64("record_version", syncer.recordVersion)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -9,9 +9,11 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/btree"
|
||||
"github.com/rs/zerolog"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
|
@ -141,14 +143,20 @@ func (backend *Backend) GetAll(_ context.Context) ([]*databroker.Record, uint64,
|
|||
}
|
||||
|
||||
// Put puts a record into the in-memory store.
|
||||
func (backend *Backend) Put(_ context.Context, record *databroker.Record) error {
|
||||
func (backend *Backend) Put(ctx context.Context, record *databroker.Record) error {
|
||||
if record == nil {
|
||||
return fmt.Errorf("records cannot be nil")
|
||||
}
|
||||
|
||||
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("db_op", "put").
|
||||
Str("db_id", record.Id).
|
||||
Str("db_type", record.Type)
|
||||
})
|
||||
|
||||
backend.mu.Lock()
|
||||
defer backend.mu.Unlock()
|
||||
defer backend.onChange.Broadcast()
|
||||
defer backend.onChange.Broadcast(ctx)
|
||||
|
||||
record.ModifiedAt = timestamppb.Now()
|
||||
record.Version = backend.nextVersion()
|
||||
|
|
|
@ -12,7 +12,7 @@ type recordStream struct {
|
|||
ctx context.Context
|
||||
backend *Backend
|
||||
|
||||
changed chan struct{}
|
||||
changed chan context.Context
|
||||
ready []*databroker.Record
|
||||
version uint64
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ type logger struct {
|
|||
}
|
||||
|
||||
func (l logger) Printf(ctx context.Context, format string, v ...interface{}) {
|
||||
log.Info().Str("service", "redis").Msgf(format, v...)
|
||||
log.Info(ctx).Str("service", "redis").Msgf(format, v...)
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
|
|
@ -51,6 +51,7 @@ type Backend struct {
|
|||
|
||||
// New creates a new redis storage backend.
|
||||
func New(rawURL string, options ...Option) (*Backend, error) {
|
||||
ctx := context.TODO()
|
||||
cfg := getConfig(options...)
|
||||
backend := &Backend{
|
||||
cfg: cfg,
|
||||
|
@ -63,7 +64,7 @@ func New(rawURL string, options ...Option) (*Backend, error) {
|
|||
return nil, err
|
||||
}
|
||||
metrics.AddRedisMetrics(backend.client.PoolStats)
|
||||
go backend.listenForVersionChanges()
|
||||
go backend.listenForVersionChanges(ctx)
|
||||
if cfg.expiry != 0 {
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
|
@ -75,7 +76,7 @@ func New(rawURL string, options ...Option) (*Backend, error) {
|
|||
case <-ticker.C:
|
||||
}
|
||||
|
||||
backend.removeChangesBefore(time.Now().Add(-cfg.expiry))
|
||||
backend.removeChangesBefore(ctx, time.Now().Add(-cfg.expiry))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
@ -146,7 +147,7 @@ func (backend *Backend) GetAll(ctx context.Context) (records []*databroker.Recor
|
|||
var record databroker.Record
|
||||
err := proto.Unmarshal([]byte(result), &record)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("redis: invalid record detected")
|
||||
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
|
||||
continue
|
||||
}
|
||||
records = append(records, &record)
|
||||
|
@ -246,8 +247,8 @@ func (backend *Backend) incrementVersion(ctx context.Context,
|
|||
return ErrExceededMaxRetries
|
||||
}
|
||||
|
||||
func (backend *Backend) listenForVersionChanges() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
func (backend *Backend) listenForVersionChanges(ctx context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
<-backend.closed
|
||||
cancel()
|
||||
|
@ -274,14 +275,13 @@ outer:
|
|||
|
||||
switch msg.(type) {
|
||||
case *redis.Message:
|
||||
backend.onChange.Broadcast()
|
||||
backend.onChange.Broadcast(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (backend *Backend) removeChangesBefore(cutoff time.Time) {
|
||||
ctx := context.Background()
|
||||
func (backend *Backend) removeChangesBefore(ctx context.Context, cutoff time.Time) {
|
||||
for {
|
||||
cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{
|
||||
Min: "-inf",
|
||||
|
@ -291,7 +291,7 @@ func (backend *Backend) removeChangesBefore(cutoff time.Time) {
|
|||
})
|
||||
results, err := cmd.Result()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("redis: error retrieving changes for expiration")
|
||||
log.Error(ctx).Err(err).Msg("redis: error retrieving changes for expiration")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -303,7 +303,7 @@ func (backend *Backend) removeChangesBefore(cutoff time.Time) {
|
|||
var record databroker.Record
|
||||
err = proto.Unmarshal([]byte(results[0]), &record)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("redis: invalid record detected")
|
||||
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
|
||||
record.ModifiedAt = timestamppb.New(cutoff.Add(-time.Second)) // set the modified so will delete it
|
||||
}
|
||||
|
||||
|
@ -315,7 +315,7 @@ func (backend *Backend) removeChangesBefore(cutoff time.Time) {
|
|||
// remove the record
|
||||
err = backend.client.ZRem(ctx, changesSetKey, results[0]).Err()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("redis: error removing member")
|
||||
log.Error(ctx).Err(err).Msg("redis: error removing member")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -181,7 +181,7 @@ func TestExpiry(t *testing.T) {
|
|||
_ = stream.Close()
|
||||
require.Len(t, records, 1000)
|
||||
|
||||
backend.removeChangesBefore(time.Now().Add(time.Second))
|
||||
backend.removeChangesBefore(ctx, time.Now().Add(time.Second))
|
||||
|
||||
stream, err = backend.Sync(ctx, 0)
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -17,7 +17,7 @@ type recordStream struct {
|
|||
ctx context.Context
|
||||
backend *Backend
|
||||
|
||||
changed chan struct{}
|
||||
changed chan context.Context
|
||||
version uint64
|
||||
record *databroker.Record
|
||||
err error
|
||||
|
@ -63,6 +63,7 @@ func (stream *recordStream) Next(block bool) bool {
|
|||
ticker := time.NewTicker(watchPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
changeCtx := context.Background()
|
||||
for {
|
||||
cmd := stream.backend.client.ZRangeByScore(stream.ctx, changesSetKey, &redis.ZRangeBy{
|
||||
Min: fmt.Sprintf("(%d", stream.version),
|
||||
|
@ -81,7 +82,7 @@ func (stream *recordStream) Next(block bool) bool {
|
|||
var record databroker.Record
|
||||
err = proto.Unmarshal([]byte(result), &record)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("redis: invalid record detected")
|
||||
log.Warn(changeCtx).Err(err).Msg("redis: invalid record detected")
|
||||
} else {
|
||||
stream.record = &record
|
||||
}
|
||||
|
@ -97,7 +98,7 @@ func (stream *recordStream) Next(block bool) bool {
|
|||
case <-stream.closed:
|
||||
return false
|
||||
case <-ticker.C: // check again
|
||||
case <-stream.changed: // check again
|
||||
case changeCtx = <-stream.changed: // check again
|
||||
}
|
||||
} else {
|
||||
return false
|
||||
|
|
|
@ -57,7 +57,7 @@ func MatchAny(any *anypb.Any, query string) bool {
|
|||
msg, err := any.UnmarshalNew()
|
||||
if err != nil {
|
||||
// ignore invalid any types
|
||||
log.Error().Err(err).Msg("storage: invalid any type")
|
||||
log.Error(context.TODO()).Err(err).Msg("storage: invalid any type")
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ func (m *mockCheckClient) Check(ctx context.Context, in *envoy_service_auth_v2.C
|
|||
}
|
||||
|
||||
func TestProxy_ForwardAuth(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
t.Parallel()
|
||||
|
||||
allowClient := &mockCheckClient{
|
||||
|
@ -85,7 +86,7 @@ func TestProxy_ForwardAuth(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.OnConfigChange(&config.Config{Options: tt.options})
|
||||
p.OnConfigChange(ctx, &config.Config{Options: tt.options})
|
||||
state := p.state.Load()
|
||||
state.sessionStore = tt.sessionStore
|
||||
signer, err := jws.NewHS256Signer(nil)
|
||||
|
|
|
@ -2,6 +2,7 @@ package proxy
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -240,7 +241,7 @@ func TestProxy_Callback(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.OnConfigChange(&config.Config{Options: tt.options})
|
||||
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
|
||||
state := p.state.Load()
|
||||
state.encoder = tt.cipher
|
||||
state.sessionStore = tt.sessionStore
|
||||
|
@ -486,7 +487,7 @@ func TestProxy_ProgrammaticCallback(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.OnConfigChange(&config.Config{Options: tt.options})
|
||||
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
|
||||
state := p.state.Load()
|
||||
state.encoder = tt.cipher
|
||||
state.sessionStore = tt.sessionStore
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
|
@ -76,17 +77,17 @@ func New(cfg *config.Config) (*Proxy, error) {
|
|||
}
|
||||
|
||||
// OnConfigChange updates internal structures based on config.Options
|
||||
func (p *Proxy) OnConfigChange(cfg *config.Config) {
|
||||
func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.currentOptions.Store(cfg.Options)
|
||||
if err := p.setHandlers(cfg.Options); err != nil {
|
||||
log.Error().Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
|
||||
log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
|
||||
}
|
||||
if state, err := newProxyStateFromConfig(cfg); err != nil {
|
||||
log.Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings")
|
||||
log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy state from configuration settings")
|
||||
} else {
|
||||
p.state.Store(state)
|
||||
}
|
||||
|
@ -94,7 +95,7 @@ func (p *Proxy) OnConfigChange(cfg *config.Config) {
|
|||
|
||||
func (p *Proxy) setHandlers(opts *config.Options) error {
|
||||
if len(opts.GetAllPolicies()) == 0 {
|
||||
log.Warn().Msg("proxy: configuration has no policies")
|
||||
log.Warn(context.TODO()).Msg("proxy: configuration has no policies")
|
||||
}
|
||||
r := httputil.NewRouter()
|
||||
r.NotFoundHandler = httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
|
@ -197,7 +198,7 @@ func Test_UpdateOptions(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p.OnConfigChange(&config.Config{Options: tt.updatedOptions})
|
||||
p.OnConfigChange(context.Background(), &config.Config{Options: tt.updatedOptions})
|
||||
r := httptest.NewRequest("GET", tt.host, nil)
|
||||
w := httptest.NewRecorder()
|
||||
p.ServeHTTP(w, r)
|
||||
|
@ -210,5 +211,5 @@ func Test_UpdateOptions(t *testing.T) {
|
|||
|
||||
// Test nil
|
||||
var p *Proxy
|
||||
p.OnConfigChange(&config.Config{})
|
||||
p.OnConfigChange(context.Background(), &config.Config{})
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue