From d6b02441b31196b481e808bcf2900afb26e92e55 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Wed, 19 Mar 2025 14:41:28 -0600 Subject: [PATCH 1/7] authorize: return 403 on invalid sessions (#5536) --- authorize/grpc.go | 9 +++++++-- config/session.go | 4 ++-- internal/sessions/errors.go | 3 +++ 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/authorize/grpc.go b/authorize/grpc.go index b85e5f3ff..e5a1980b5 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -4,6 +4,7 @@ import ( "context" "encoding/pem" "errors" + "fmt" "io" "net/http" "net/url" @@ -54,8 +55,11 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe // load the session s, err := a.loadSession(ctx, hreq, req) - if err != nil { - return nil, err + if errors.Is(err, sessions.ErrInvalidSession) { + // ENG-2172: if this is an invalid session, don't evaluate policy, return forbidden + return a.deniedResponse(ctx, in, int32(http.StatusForbidden), http.StatusText(http.StatusForbidden), nil) + } else if err != nil { + return nil, fmt.Errorf("error loading session: %w", err) } // if there's a session or service account, load the user @@ -122,6 +126,7 @@ func (a *Authorize) loadSession( Str("request-id", requestID). Err(err). Msg("error creating session for incoming idp token") + return nil, err } sessionState, _ := a.state.Load().sessionStore.LoadSessionStateAndCheckIDP(hreq) diff --git a/config/session.go b/config/session.go index d895f18eb..6270f9872 100644 --- a/config/session.go +++ b/config/session.go @@ -202,7 +202,7 @@ func (c *incomingIDPTokenSessionCreator) createSessionAccessToken( if err != nil { return nil, fmt.Errorf("error verifying access token: %w", err) } else if !res.Valid { - return nil, fmt.Errorf("invalid access token") + return nil, fmt.Errorf("%w: invalid access token", sessions.ErrInvalidSession) } s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims) @@ -265,7 +265,7 @@ func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken( if err != nil { return nil, fmt.Errorf("error verifying identity token: %w", err) } else if !res.Valid { - return nil, fmt.Errorf("invalid identity token") + return nil, fmt.Errorf("%w: invalid identity token", sessions.ErrInvalidSession) } s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims) diff --git a/internal/sessions/errors.go b/internal/sessions/errors.go index 7d43ede6a..bcac5bd97 100644 --- a/internal/sessions/errors.go +++ b/internal/sessions/errors.go @@ -8,6 +8,9 @@ var ( // ErrNoSessionFound is the error for when no session is found. ErrNoSessionFound = errors.New("internal/sessions: session is not found") + // ErrInvalidSession is the error for when a session is invalid. + ErrInvalidSession = errors.New("internal/sessions: invalid session") + // ErrMalformed is the error for when a session is found but is malformed. ErrMalformed = errors.New("internal/sessions: session is malformed") From 08623ef346f693de456139ad0a0e92d904e77362 Mon Sep 17 00:00:00 2001 From: Joe Kralicky Date: Wed, 19 Mar 2025 18:42:19 -0400 Subject: [PATCH 2/7] add tests/benchmarks for http1/http2 tcp tunnels and http1 websockets (#5471) * add tests/benchmarks for http1/http2 tcp tunnels and http1 websockets testenv: - add new TCP upstream - add websocket functions to HTTP upstream - add https support to mock idp (default on) - add new debug flags -env.bind-address and -env.use-trace-environ to allow changing the default bind address, and enabling otel environment based trace config, respectively * linter pass --------- Co-authored-by: Denis Mishin --- config/envoyconfig/protocols_int_test.go | 238 ++++++++++++- internal/testenv/environment.go | 64 +++- internal/testenv/route.go | 17 + internal/testenv/scenarios/mock_idp.go | 57 ++- internal/testenv/scenarios/trace_receiver.go | 12 +- internal/testenv/types.go | 3 +- internal/testenv/upstreams/grpc.go | 12 +- internal/testenv/upstreams/http.go | 339 +++++++++--------- internal/testenv/upstreams/options.go | 4 + internal/testenv/upstreams/tcp.go | 343 +++++++++++++++++++ internal/testenv/upstreams/util.go | 195 +++++++++++ pkg/storage/postgres/tracing_test.go | 2 +- 12 files changed, 1104 insertions(+), 182 deletions(-) create mode 100644 internal/testenv/upstreams/tcp.go create mode 100644 internal/testenv/upstreams/util.go diff --git a/config/envoyconfig/protocols_int_test.go b/config/envoyconfig/protocols_int_test.go index ef1e82785..8c332b0d9 100644 --- a/config/envoyconfig/protocols_int_test.go +++ b/config/envoyconfig/protocols_int_test.go @@ -1,11 +1,17 @@ package envoyconfig_test import ( + "context" + "errors" "fmt" "io" + "net" "net/http" + "sync/atomic" "testing" + "time" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc/credentials/insecure" @@ -28,9 +34,9 @@ func TestH2C(t *testing.T) { http := up.Route(). From(env.SubdomainURL("grpc-http")). - To(values.Bind(up.Port(), func(port int) string { + To(values.Bind(up.Addr(), func(addr string) string { // override the target protocol to use http:// - return fmt.Sprintf("http://127.0.0.1:%d", port) + return fmt.Sprintf("http://%s", addr) })). Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true }) @@ -118,6 +124,234 @@ func TestHTTP(t *testing.T) { }) } +func TestTCPTunnel(t *testing.T) { + env := testenv.New(t, testenv.Debug()) + + env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}})) + up := upstreams.TCP() + routeH1 := up.Route(). + From(env.SubdomainURL("h1")). + PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`) + routeH2 := up.Route(). + From(env.SubdomainURL("h2")). + Policy(func(p *config.Policy) { + p.AllowWebsockets = true + }). + PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`) + + up.Handle(func(_ context.Context, c net.Conn) error { + c.SetReadDeadline(time.Now().Add(1 * time.Second)) + buf := make([]byte, 8) + n, err := c.Read(buf) + require.NoError(t, err) + assert.Equal(t, string(buf[:n]), "hello") + c.SetWriteDeadline(time.Now().Add(1 * time.Second)) + _, err = c.Write([]byte("world")) + require.NoError(t, err) + + return nil + }) + + env.AddUpstream(up) + env.Start() + snippets.WaitStartupComplete(env) + + t.Run("http1", func(t *testing.T) { + assert.NoError(t, up.Dial(routeH1, func(_ context.Context, c net.Conn) error { + c.SetWriteDeadline(time.Now().Add(1 * time.Second)) + _, err := c.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 8) + c.SetReadDeadline(time.Now().Add(1 * time.Second)) + n, err := c.Read(buf) + require.NoError(t, err) + + assert.Equal(t, string(buf[:n]), "world") + return nil + }, upstreams.AuthenticateAs("test@example.com"), upstreams.DialProtocol(upstreams.DialHTTP1))) + }) + + t.Run("http2", func(t *testing.T) { + assert.NoError(t, up.Dial(routeH2, func(_ context.Context, c net.Conn) error { + c.SetWriteDeadline(time.Now().Add(1 * time.Second)) + _, err := c.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 8) + c.SetReadDeadline(time.Now().Add(1 * time.Second)) + n, err := c.Read(buf) + require.NoError(t, err) + + assert.Equal(t, string(buf[:n]), "world") + return nil + }, upstreams.AuthenticateAs("test@example.com"), upstreams.DialProtocol(upstreams.DialHTTP2))) + }) +} + +func BenchmarkHTTP1TCPTunnel(b *testing.B) { + env := testenv.New(b, testenv.Silent()) + env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}})) + up := upstreams.TCP() + h1 := up.Route(). + From(env.SubdomainURL("bench-h1")). + PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`) + + env.AddUpstream(up) + env.Start() + snippets.WaitStartupComplete(env) + + b.Run("http1", func(b *testing.B) { + benchmarkTCP(b, up, h1, tcpBenchmarkParams{ + msgLen: 512, + protocol: upstreams.DialHTTP1, + }) + }) +} + +func BenchmarkHTTP2TCPTunnel(b *testing.B) { + env := testenv.New(b, testenv.Silent()) + env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}})) + up := upstreams.TCP() + + h2 := up.Route(). + From(env.SubdomainURL("bench-h2")). + Policy(func(p *config.Policy) { + p.AllowWebsockets = true + }). + PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`) + + env.AddUpstream(up) + env.Start() + snippets.WaitStartupComplete(env) + + b.Run("http2", func(b *testing.B) { + benchmarkTCP(b, up, h2, tcpBenchmarkParams{ + msgLen: 512, + protocol: upstreams.DialHTTP2, + }) + }) +} + +type tcpBenchmarkParams struct { + msgLen int + protocol upstreams.Protocol +} + +func benchmarkTCP(b *testing.B, up upstreams.TCPUpstream, route testenv.Route, params tcpBenchmarkParams) { + sendMsg := func(c net.Conn, buf []byte) error { + c.SetWriteDeadline(time.Now().Add(1 * time.Second)) + _, err := c.Write(buf) + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + } + return err + } + recvMsg := func(c net.Conn, buf []byte) error { + c.SetReadDeadline(time.Now().Add(1 * time.Second)) + for read := 0; read != len(buf); { + n, err := c.Read(buf) + read += n + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + } + return nil + } + up.Handle(func(_ context.Context, c net.Conn) error { + for { + buf := make([]byte, params.msgLen) + if err := recvMsg(c, buf[:]); err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + if err := sendMsg(c, buf[:]); err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + } + }) + var threads atomic.Int32 + var requests atomic.Int32 + var bytes atomic.Int64 + start := time.Now() + b.RunParallel(func(p *testing.PB) { + threads.Add(1) + require.NoError(b, up.Dial(route, func(_ context.Context, c net.Conn) error { + buf := make([]byte, params.msgLen) + for p.Next() { + requests.Add(1) + bytes.Add(int64(params.msgLen)) + require.NoError(b, sendMsg(c, buf[:])) + require.NoError(b, recvMsg(c, buf[:])) + } + return nil + }, upstreams.AuthenticateAs("test@example.com"), upstreams.DialProtocol(params.protocol))) + }) + duration := time.Since(start) + b.Logf("sent %d requests over %d parallel connections in %s", requests.Load(), threads.Load(), duration) + b.Logf("throughput: %f bytes/s", float64(bytes.Load())/duration.Seconds()) +} + +func TestHttp1Websocket(t *testing.T) { + env := testenv.New(t) + + up := upstreams.HTTP(nil) + up.HandleWS("/ws", websocket.Upgrader{}, func(conn *websocket.Conn) error { + for { + mt, message, err := conn.ReadMessage() + if err != nil { + return err + } + + // echo the message back + err = conn.WriteMessage(mt, message) + if err != nil { + return err + } + } + }) + + route := up.Route(). + From(env.SubdomainURL("ws-test")). + Policy(func(p *config.Policy) { + p.AllowPublicUnauthenticatedAccess = true + p.AllowWebsockets = true + }) + + env.AddUpstream(up) + env.Start() + snippets.WaitStartupComplete(env) + + assert.NoError(t, up.DialWS(route, func(conn *websocket.Conn) error { + if err := conn.SetWriteDeadline(time.Now().Add(1 * time.Second)); err != nil { + return err + } + if err := conn.WriteMessage(websocket.TextMessage, []byte("hello world")); err != nil { + return err + } + if err := conn.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { + return err + } + mt, bytes, err := conn.ReadMessage() + if err := err; err != nil { + return err + } + assert.Equal(t, websocket.TextMessage, mt) + assert.Equal(t, "hello world", string(bytes)) + return nil + }, upstreams.Path("/ws"))) +} + func TestClientCert(t *testing.T) { env := testenv.New(t) env.Add(scenarios.DownstreamMTLS(config.MTLSEnforcementRejectConnection)) diff --git a/internal/testenv/environment.go b/internal/testenv/environment.go index c42e6c471..8b16f99ea 100644 --- a/internal/testenv/environment.go +++ b/internal/testenv/environment.go @@ -23,6 +23,7 @@ import ( "os/signal" "path" "path/filepath" + "reflect" "runtime" "strconv" "strings" @@ -34,11 +35,13 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config/envoyconfig/filemgr" + "github.com/pomerium/pomerium/config/otelconfig" databroker_service "github.com/pomerium/pomerium/databroker" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/testenv/envutil" "github.com/pomerium/pomerium/internal/testenv/values" + "github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/pkg/cmd/pomerium" "github.com/pomerium/pomerium/pkg/envoy" "github.com/pomerium/pomerium/pkg/grpc/databroker" @@ -97,6 +100,7 @@ type Environment interface { AuthenticateURL() values.Value[string] DatabrokerURL() values.Value[string] Ports() Ports + Host() string SharedSecret() []byte CookieSecret() []byte @@ -244,6 +248,8 @@ type EnvironmentOptions struct { forceSilent bool traceDebugFlags trace.DebugFlags traceClient otlptrace.Client + traceConfig *otelconfig.Config + host string } type EnvironmentOption func(*EnvironmentOptions) @@ -300,15 +306,23 @@ func WithTraceClient(traceClient otlptrace.Client) EnvironmentOption { } } +func WithTraceConfig(traceConfig *otelconfig.Config) EnvironmentOption { + return func(o *EnvironmentOptions) { + o.traceConfig = traceConfig + } +} + var setGrpcLoggerOnce sync.Once const defaultTraceDebugFlags = trace.TrackSpanCallers | trace.TrackSpanReferences var ( - flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)") - flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)") - flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)") - flagTraceDebugFlags = flag.String("env.trace-debug-flags", strconv.Itoa(defaultTraceDebugFlags), "trace debug flags (equivalent to TraceDebugFlags() option)") + flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)") + flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)") + flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)") + flagTraceDebugFlags = flag.String("env.trace-debug-flags", strconv.Itoa(defaultTraceDebugFlags), "trace debug flags (equivalent to TraceDebugFlags() option)") + flagBindAddress = flag.String("env.bind-address", "127.0.0.1", "bind address for local services") + flagTraceEnvironConfig = flag.Bool("env.use-trace-environ", false, "if true, will configure a trace client from environment variables if no trace client has been set") ) func New(t testing.TB, opts ...EnvironmentOption) Environment { @@ -323,6 +337,7 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment { forceSilent: *flagSilent, traceDebugFlags: trace.DebugFlags(defaultTraceDebugFlags), traceClient: trace.NoopClient{}, + host: *flagBindAddress, } options.apply(opts...) if testing.Short() { @@ -332,6 +347,17 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment { if addTraceDebugFlags { options.traceDebugFlags |= trace.DebugFlags(defaultTraceDebugFlags) } + if *flagTraceEnvironConfig && options.traceConfig == nil && + (reflect.TypeOf(options.traceClient) == reflect.TypeFor[trace.NoopClient]()) { + cfg := newOtelConfigFromEnv(t) + options.traceConfig = &cfg + client, err := trace.NewTraceClientFromConfig(cfg) + if err != nil { + t.Fatal(err) + } + t.Log("tracing configured from environment") + options.traceClient = client + } trace.UseGlobalPanicTracer() databroker.DebugUseFasterBackoff.Store(true) workspaceFolder, err := os.Getwd() @@ -495,7 +521,7 @@ func (e *environment) AuthenticateURL() values.Value[string] { func (e *environment) DatabrokerURL() values.Value[string] { return values.Bind(e.ports.Outbound, func(port int) string { - return fmt.Sprintf("127.0.0.1:%d", port) + return fmt.Sprintf("%s:%d", e.host, port) }) } @@ -503,6 +529,13 @@ func (e *environment) Ports() Ports { return e.ports } +func (e *environment) Host() string { + if e.host == "" { + return "127.0.0.1" + } + return e.host +} + func (e *environment) CACert() *tls.Certificate { caCert, err := tls.LoadX509KeyPair( filepath.Join(e.tempDir, "certs", "ca.pem"), @@ -571,9 +604,9 @@ func (e *environment) Start() { cfg.Options.Services = "all" cfg.Options.LogLevel = config.LogLevelDebug cfg.Options.ProxyLogLevel = config.LogLevelInfo - cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyHTTP.Value()) - cfg.Options.GRPCAddr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyGRPC.Value()) - cfg.Options.MetricsAddr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyMetrics.Value()) + cfg.Options.Addr = fmt.Sprintf("%s:%d", e.host, e.ports.ProxyHTTP.Value()) + cfg.Options.GRPCAddr = fmt.Sprintf("%s:%d", e.host, e.ports.ProxyGRPC.Value()) + cfg.Options.MetricsAddr = fmt.Sprintf("%s:%d", e.host, e.ports.ProxyMetrics.Value()) cfg.Options.CAFile = filepath.Join(e.tempDir, "certs", "ca.pem") cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem") cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem") @@ -598,6 +631,9 @@ func (e *environment) Start() { log.AccessLogFieldUserAgent, log.AccessLogFieldClientCertificate, } + if e.traceConfig != nil { + cfg.Options.Tracing = *e.traceConfig + } e.src = &configSource{cfg: cfg} e.AddTask(TaskFunc(func(ctx context.Context) error { @@ -799,6 +835,7 @@ func (e *environment) Pause() { c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT) <-c + signal.Stop(c) e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m") } @@ -816,6 +853,7 @@ func (e *environment) cleanup(cancelCause error) { c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT) <-c + signal.Stop(c) e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m") signal.Stop(c) } @@ -1043,3 +1081,13 @@ func (src *configSource) ModifyConfig(ctx context.Context, m Modifier) { li(ctx, src.cfg) } } + +func newOtelConfigFromEnv(t testing.TB) otelconfig.Config { + f, err := os.CreateTemp("", "tmp-config-*.yaml") + require.NoError(t, err) + defer os.Remove(f.Name()) + f.Close() + cfg, err := config.NewFileOrEnvironmentSource(context.Background(), f.Name(), version.FullVersion()) + require.NoError(t, err) + return cfg.GetConfig().Options.Tracing +} diff --git a/internal/testenv/route.go b/internal/testenv/route.go index 8b002f8bb..0eaea287c 100644 --- a/internal/testenv/route.go +++ b/internal/testenv/route.go @@ -1,6 +1,7 @@ package testenv import ( + "fmt" "net/url" "strings" @@ -74,3 +75,19 @@ func (b *PolicyRoute) PPL(ppl string) Route { func (b *PolicyRoute) URL() values.Value[string] { return b.from } + +type TCPRoute struct { + PolicyRoute +} + +func (b *TCPRoute) From(fromURL values.Value[string]) Route { + b.from = values.Bind(fromURL, func(urlStr string) string { + from, _ := url.Parse(urlStr) + from.Scheme = "tcp+https" + from.Host = fmt.Sprintf("%s:%s", from.Hostname(), from.Port()) + return from.String() + }) + return b +} + +var _ Route = (*TCPRoute)(nil) diff --git a/internal/testenv/scenarios/mock_idp.go b/internal/testenv/scenarios/mock_idp.go index 85babd674..9e07b70e1 100644 --- a/internal/testenv/scenarios/mock_idp.go +++ b/internal/testenv/scenarios/mock_idp.go @@ -6,6 +6,8 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "crypto/tls" + "crypto/x509" "encoding/base64" "encoding/hex" "encoding/json" @@ -32,6 +34,7 @@ import ( ) type IDP struct { + IDPOptions id values.Value[string] url values.Value[string] publicJWK jose.JSONWebKey @@ -41,18 +44,56 @@ type IDP struct { userLookup map[string]*User } +type IDPOptions struct { + enableTLS bool +} + +type IDPOption func(*IDPOptions) + +func (o *IDPOptions) apply(opts ...IDPOption) { + for _, op := range opts { + op(o) + } +} + +func WithEnableTLS(enableTLS bool) IDPOption { + return func(o *IDPOptions) { + o.enableTLS = enableTLS + } +} + // Attach implements testenv.Modifier. func (idp *IDP) Attach(ctx context.Context) { env := testenv.EnvFromContext(ctx) - router := upstreams.HTTP(nil, upstreams.WithDisplayName("IDP")) + idpURL := env.SubdomainURL("mock-idp") - idp.url = values.Bind2(env.SubdomainURL("mock-idp"), router.Port(), func(urlStr string, port int) string { + var tlsConfig values.Value[*tls.Config] + if idp.enableTLS { + tlsConfig = values.Bind(idpURL, func(urlStr string) *tls.Config { + u, _ := url.Parse(urlStr) + cert := env.NewServerCert(&x509.Certificate{ + DNSNames: []string{u.Hostname()}, + }) + return &tls.Config{ + RootCAs: env.ServerCAs(), + Certificates: []tls.Certificate{tls.Certificate(*cert)}, + NextProtos: []string{"http/1.1", "h2"}, + } + }) + } + + router := upstreams.HTTP(tlsConfig, upstreams.WithDisplayName("IDP")) + + idp.url = values.Bind2(idpURL, router.Addr(), func(urlStr string, addr string) string { u, _ := url.Parse(urlStr) host, _, _ := net.SplitHostPort(u.Host) + _, port, err := net.SplitHostPort(addr) + if err != nil { + panic("bug: " + err.Error()) + } return u.ResolveReference(&url.URL{ - Scheme: "http", - Host: fmt.Sprintf("%s:%d", host, port), + Host: fmt.Sprintf("%s:%s", host, port), }).String() }) var err error @@ -108,7 +149,12 @@ func (idp *IDP) Modify(cfg *config.Config) { var _ testenv.Modifier = (*IDP)(nil) -func NewIDP(users []*User) *IDP { +func NewIDP(users []*User, opts ...IDPOption) *IDP { + options := IDPOptions{ + enableTLS: true, + } + options.apply(opts...) + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { panic(err) @@ -136,6 +182,7 @@ func NewIDP(users []*User) *IDP { userLookup[user.ID] = user } return &IDP{ + IDPOptions: options, publicJWK: publicJWK, signingKey: signingKey, userLookup: userLookup, diff --git a/internal/testenv/scenarios/trace_receiver.go b/internal/testenv/scenarios/trace_receiver.go index 96b0c1b8b..f5d761fbb 100644 --- a/internal/testenv/scenarios/trace_receiver.go +++ b/internal/testenv/scenarios/trace_receiver.go @@ -174,16 +174,16 @@ func (rec *OTLPTraceReceiver) FlushResourceSpans() []*tracev1.ResourceSpans { // GRPCEndpointURL returns a url suitable for use with the environment variable // $OTEL_EXPORTER_OTLP_TRACES_ENDPOINT or with [otlptracegrpc.WithEndpointURL]. func (rec *OTLPTraceReceiver) GRPCEndpointURL() values.Value[string] { - return values.Chain(rec.grpcUpstream, upstreams.GRPCUpstream.Port, func(port int) string { - return fmt.Sprintf("http://127.0.0.1:%d", port) + return values.Chain(rec.grpcUpstream, upstreams.GRPCUpstream.Addr, func(addr string) string { + return fmt.Sprintf("http://%s", addr) }) } // GRPCEndpointURL returns a url suitable for use with the environment variable // $OTEL_EXPORTER_OTLP_TRACES_ENDPOINT or with [otlptracehttp.WithEndpointURL]. func (rec *OTLPTraceReceiver) HTTPEndpointURL() values.Value[string] { - return values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Port, func(port int) string { - return fmt.Sprintf("http://127.0.0.1:%d/v1/traces", port) + return values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Addr, func(addr string) string { + return fmt.Sprintf("http://%s/v1/traces", addr) }) } @@ -200,9 +200,9 @@ func (rec *OTLPTraceReceiver) NewGRPCClient(opts ...otlptracegrpc.Option) otlptr func (rec *OTLPTraceReceiver) NewHTTPClient(opts ...otlptracehttp.Option) otlptrace.Client { return &deferredClient{ - client: values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Port, func(port int) otlptrace.Client { + client: values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Addr, func(addr string) otlptrace.Client { return otlptracehttp.NewClient(append(opts, - otlptracehttp.WithEndpointURL(fmt.Sprintf("http://127.0.0.1:%d/v1/traces", port)), + otlptracehttp.WithEndpointURL(fmt.Sprintf("http://%s/v1/traces", addr)), otlptracehttp.WithTimeout(1*time.Minute), )...) }), diff --git a/internal/testenv/types.go b/internal/testenv/types.go index 8e863d3eb..959fd9a33 100644 --- a/internal/testenv/types.go +++ b/internal/testenv/types.go @@ -204,7 +204,8 @@ func (f *taskFunc) Run(ctx context.Context) error { type Upstream interface { Modifier Task - Port() values.Value[int] + + Addr() values.Value[string] Route() RouteStub } diff --git a/internal/testenv/upstreams/grpc.go b/internal/testenv/upstreams/grpc.go index 90773d8c1..c6c203586 100644 --- a/internal/testenv/upstreams/grpc.go +++ b/internal/testenv/upstreams/grpc.go @@ -97,8 +97,10 @@ type service struct { impl any } -func (g *grpcUpstream) Port() values.Value[int] { - return g.serverPort +func (g *grpcUpstream) Addr() values.Value[string] { + return values.Bind(g.serverPort, func(port int) string { + return fmt.Sprintf("%s:%d", g.Env().Host(), port) + }) } // RegisterService implements grpc.ServiceRegistrar. @@ -117,7 +119,7 @@ func (g *grpcUpstream) Route() testenv.RouteStub { protocol = "https" } r.To(values.Bind(g.serverPort, func(port int) string { - return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port) + return fmt.Sprintf("%s://%s:%d", protocol, g.Env().Host(), port) })) g.Add(r) return r @@ -125,7 +127,7 @@ func (g *grpcUpstream) Route() testenv.RouteStub { // Start implements testenv.Upstream. func (g *grpcUpstream) Run(ctx context.Context) error { - listener, err := net.Listen("tcp", "127.0.0.1:0") + listener, err := net.Listen("tcp", fmt.Sprintf("%s:0", g.Env().Host())) if err != nil { return err } @@ -187,7 +189,7 @@ func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc. } func (g *grpcUpstream) DirectConnect(dialOpts ...grpc.DialOption) *grpc.ClientConn { - cc, err := grpc.NewClient(fmt.Sprintf("127.0.0.1:%d", g.Port().Value()), + cc, err := grpc.NewClient(g.Addr().Value(), append(g.withDefaultDialOpts(dialOpts), grpc.WithTransportCredentials(insecure.NewCredentials()))...) if err != nil { panic(err) diff --git a/internal/testenv/upstreams/http.go b/internal/testenv/upstreams/http.go index fd8d53ffd..8d97d61e3 100644 --- a/internal/testenv/upstreams/http.go +++ b/internal/testenv/upstreams/http.go @@ -1,10 +1,8 @@ package upstreams import ( - "bytes" "context" "crypto/tls" - "encoding/json" "errors" "fmt" "io" @@ -13,14 +11,12 @@ import ( "net/http/cookiejar" "net/http/httptrace" "net/url" - "strconv" - "strings" + "os" "sync" "time" "github.com/gorilla/mux" - "github.com/pomerium/pomerium/integration/forms" - "github.com/pomerium/pomerium/internal/retry" + "github.com/gorilla/websocket" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/testenv" "github.com/pomerium/pomerium/internal/testenv/snippets" @@ -29,9 +25,15 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" - semconv "go.opentelemetry.io/otel/semconv/v1.26.0" oteltrace "go.opentelemetry.io/otel/trace" - "google.golang.org/protobuf/proto" +) + +type Protocol string + +const ( + DialHTTP1 Protocol = "http/1.1" + DialHTTP2 Protocol = "h2" + DialHTTP3 Protocol = "h3" ) type RequestOptions struct { @@ -42,12 +44,18 @@ type RequestOptions struct { authenticateAs string body any clientCerts []tls.Certificate - client *http.Client + clientHook func(*http.Client) *http.Client + dialerHook func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL) + dialProtocol Protocol trace *httptrace.ClientTrace } type RequestOption func(*RequestOptions) +func (ro RequestOption) Format(fmt.State, rune) { + panic("test bug: request option mistakenly passed to assert function") +} + func (o *RequestOptions) apply(opts ...RequestOption) { for _, op := range opts { op(o) @@ -82,9 +90,38 @@ func AuthenticateAs(email string) RequestOption { } } -func Client(c *http.Client) RequestOption { +// ClientHook allows editing or replacing the http client before it is used. +// When any request is about to start, this function will be called with the +// client that would be used to make the request. The returned client will +// be the actual client used for that request. It can be the same as the input +// (with or without modification), or replaced entirely. +// +// Note: the Transport of the client passed to the hook will always be a +// [*Transport]. That transport's underlying transport will always be +// a [*otelhttp.Transport]. +func ClientHook(f func(*http.Client) *http.Client) RequestOption { return func(o *RequestOptions) { - o.client = c + o.clientHook = f + } +} + +// DialerHook allows editing or replacing the websocket dialer before it is +// used. When a websocket request is about to start (using the DialWS method), +// this function will be called with the dialer that would be used, and the +// destination URL (including wss:// scheme, and path if one is present). The +// returned dialer+URL will be the actual dialer+URL used for that request. +// +// If ClientHook is also set, both will be called. The dialer passed to this +// hook will have its TLSClientConfig and Jar fields set from the client. +func DialerHook(f func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL)) RequestOption { + return func(o *RequestOptions) { + o.dialerHook = f + } +} + +func DialProtocol(protocol Protocol) RequestOption { + return func(o *RequestOptions) { + o.dialProtocol = protocol } } @@ -143,10 +180,12 @@ type HTTPUpstream interface { testenv.Upstream Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route + HandleWS(path string, upgrader websocket.Upgrader, f func(conn *websocket.Conn) error) *mux.Route Get(r testenv.Route, opts ...RequestOption) (*http.Response, error) Post(r testenv.Route, opts ...RequestOption) (*http.Response, error) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) + DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error } type httpUpstream struct { @@ -194,8 +233,10 @@ func HTTP(tlsConfig values.Value[*tls.Config], opts ...HTTPUpstreamOption) HTTPU } // Port implements HTTPUpstream. -func (h *httpUpstream) Port() values.Value[int] { - return h.serverPort +func (h *httpUpstream) Addr() values.Value[string] { + return values.Bind(h.serverPort, func(port int) string { + return fmt.Sprintf("%s:%d", h.Env().Host(), port) + }) } // Router implements HTTPUpstream. @@ -203,12 +244,37 @@ func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Req return h.router.HandleFunc(path, f) } +// Router implements HTTPUpstream. +func (h *httpUpstream) HandleWS(path string, upgrader websocket.Upgrader, f func(*websocket.Conn) error) *mux.Route { + return h.router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { + ctx, span := trace.Continue(r.Context(), "HandleWS") + defer span.End() + c, err := upgrader.Upgrade(w, r.WithContext(ctx), nil) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(err.Error())) + return + } + defer c.Close() + + err = f(c) + if err != nil { + if errors.Is(err, io.EOF) { + return + } + span.SetStatus(codes.Error, err.Error()) + fmt.Fprintf(os.Stderr, "websocket error: %s\n", err.Error()) + } + }) +} + // Route implements HTTPUpstream. func (h *httpUpstream) Route() testenv.RouteStub { r := &testenv.PolicyRoute{} protocol := "http" r.To(values.Bind(h.serverPort, func(port int) string { - return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port) + return fmt.Sprintf("%s://%s:%d", protocol, h.Env().Host(), port) })) h.Add(r) return r @@ -216,15 +282,21 @@ func (h *httpUpstream) Route() testenv.RouteStub { // Run implements HTTPUpstream. func (h *httpUpstream) Run(ctx context.Context) error { - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return err + var listener net.Listener + if h.tlsConfig != nil { + var err error + listener, err = tls.Listen("tcp", fmt.Sprintf("%s:0", h.Env().Host()), h.tlsConfig.Value()) + if err != nil { + return err + } + } else { + var err error + listener, err = net.Listen("tcp", fmt.Sprintf("%s:0", h.Env().Host())) + if err != nil { + return err + } } h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port) - var tlsConfig *tls.Config - if h.tlsConfig != nil { - tlsConfig = h.tlsConfig.Value() - } if h.serverTracerProviderOverride != nil { h.serverTracerProvider.Resolve(h.serverTracerProviderOverride) } else { @@ -238,8 +310,7 @@ func (h *httpUpstream) Run(ctx context.Context) error { h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.serverTracerProvider.Value()))) server := &http.Server{ - Handler: h.router, - TLSConfig: tlsConfig, + Handler: h.router, // BaseContext: func(net.Listener) context.Context { // return ctx // }, @@ -277,6 +348,53 @@ func (h *httpUpstream) Post(r testenv.Route, opts ...RequestOption) (*http.Respo return h.Do(http.MethodPost, r, opts...) } +type Transport struct { + *otelhttp.Transport + // The underlying http.Transport instance wrapped by the otelhttp.Transport. + Base *http.Transport +} + +var _ http.RoundTripper = Transport{} + +func (h *httpUpstream) newClient(options *RequestOptions) *http.Client { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = &tls.Config{ + RootCAs: h.Env().ServerCAs(), + Certificates: options.clientCerts, + } + transport.DialTLSContext = nil + c := http.Client{ + Transport: &Transport{ + Transport: otelhttp.NewTransport(transport, + otelhttp.WithTracerProvider(h.clientTracerProvider.Value()), + otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string { + return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path) + }), + ), + Base: transport, + }, + } + c.Jar, _ = cookiejar.New(&cookiejar.Options{}) + return &c +} + +func (h *httpUpstream) getRouteClient(r testenv.Route, options *RequestOptions) *http.Client { + span := oteltrace.SpanFromContext(options.requestCtx) + var cachedClient any + var ok bool + if cachedClient, ok = h.clientCache.Load(r); !ok { + span.AddEvent("creating new http client") + cachedClient, _ = h.clientCache.LoadOrStore(r, h.newClient(options)) + } else { + span.AddEvent("using cached http client") + } + client := cachedClient.(*http.Client) + if options.clientHook != nil { + client = options.clientHook(client) + } + return client +} + // Do implements HTTPUpstream. func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) { options := RequestOptions{ @@ -303,141 +421,54 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) options.requestCtx = ctx defer span.End() - newClient := func() *http.Client { - transport := http.DefaultTransport.(*http.Transport).Clone() - transport.TLSClientConfig = &tls.Config{ - RootCAs: h.Env().ServerCAs(), - Certificates: options.clientCerts, - } - transport.DialTLSContext = nil - c := http.Client{ - Transport: otelhttp.NewTransport(transport, - otelhttp.WithTracerProvider(h.clientTracerProvider.Value()), - otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string { - return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path) - }), - ), - } - c.Jar, _ = cookiejar.New(&cookiejar.Options{}) - return &c - } - var client *http.Client - if options.client != nil { - client = options.client - } else { - var cachedClient any - var ok bool - if cachedClient, ok = h.clientCache.Load(r); !ok { - span.AddEvent("creating new http client") - cachedClient, _ = h.clientCache.LoadOrStore(r, newClient()) - } else { - span.AddEvent("using cached http client") - } - client = cachedClient.(*http.Client) - } - - var resp *http.Response - resendCount := 0 - if err := retry.Retry(ctx, "http", func(ctx context.Context) error { - req, err := http.NewRequestWithContext(ctx, method, u.String(), nil) - if err != nil { - return retry.NewTerminalError(err) - } - switch body := options.body.(type) { - case string: - req.Body = io.NopCloser(strings.NewReader(body)) - case []byte: - req.Body = io.NopCloser(bytes.NewReader(body)) - case io.Reader: - req.Body = io.NopCloser(body) - case proto.Message: - buf, err := proto.Marshal(body) - if err != nil { - return retry.NewTerminalError(err) - } - req.Body = io.NopCloser(bytes.NewReader(buf)) - req.Header.Set("Content-Type", "application/octet-stream") - default: - buf, err := json.Marshal(body) - if err != nil { - panic(fmt.Sprintf("unsupported body type: %T", body)) - } - req.Body = io.NopCloser(bytes.NewReader(buf)) - req.Header.Set("Content-Type", "application/json") - case nil: - } - - if options.authenticateAs != "" { - resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose - } else { - resp, err = client.Do(req) //nolint:bodyclose - } - // retry on connection refused - if err != nil { - span.RecordError(err) - var opErr *net.OpError - if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" { - span.AddEvent("Retrying on dial error") - return err - } - return retry.NewTerminalError(err) - } - if resp.StatusCode/100 == 5 { - resendCount++ - _, _ = io.ReadAll(resp.Body) - _ = resp.Body.Close() - span.SetAttributes(semconv.HTTPRequestResendCount(resendCount)) - span.AddEvent("Retrying on 5xx error", oteltrace.WithAttributes( - attribute.String("status", resp.Status), - )) - return errors.New(http.StatusText(resp.StatusCode)) - } - span.SetStatus(codes.Ok, "request completed successfully") - return nil - }, - retry.WithInitialInterval(1*time.Millisecond), - retry.WithMaxInterval(100*time.Millisecond), - ); err != nil { - return nil, err - } - return resp, nil + return doAuthenticatedRequest(options.requestCtx, + func(ctx context.Context) (*http.Request, error) { + return http.NewRequestWithContext(ctx, method, u.String(), nil) + }, + func(context.Context) *http.Client { + return h.getRouteClient(r, &options) + }, + &options, + ) } -func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string) (*http.Response, error) { - span := oteltrace.SpanFromContext(ctx) - var res *http.Response - originalHostname := req.URL.Hostname() - res, err := client.Do(req) +func (h *httpUpstream) DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error { + options := RequestOptions{ + requestCtx: h.Env().Context(), + } + options.apply(opts...) + u, err := url.Parse(r.URL().Value()) if err != nil { - span.RecordError(err) - return nil, err + return err } - location := res.Request.URL - if location.Hostname() == originalHostname { - // already authenticated - span.SetStatus(codes.Ok, "already authenticated") - return res, nil + u.Scheme = "wss" + if options.path != "" || options.query != nil { + u = u.ResolveReference(&url.URL{ + Path: options.path, + RawQuery: options.query.Encode(), + }) } - fs := forms.Parse(res.Body) - _, _ = io.ReadAll(res.Body) - _ = res.Body.Close() - if len(fs) > 0 { - f := fs[0] - f.Inputs["email"] = email - f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds())) - span.AddEvent("submitting form", oteltrace.WithAttributes(attribute.String("location", location.String()))) - formReq, err := f.NewRequestWithContext(ctx, location) - if err != nil { - span.RecordError(err) - return nil, err - } - resp, err := client.Do(formReq) - if err != nil { - span.RecordError(err) - return nil, err - } - span.SetStatus(codes.Ok, "form submitted successfully") - return resp, nil + ctx, span := h.clientTracer.Value().Start(options.requestCtx, "httpUpstream.Dial", oteltrace.WithAttributes( + attribute.String("url", u.String()), + )) + options.requestCtx = ctx + defer span.End() + + client := h.getRouteClient(r, &options) + d := &websocket.Dialer{ + HandshakeTimeout: 10 * time.Second, + TLSClientConfig: client.Transport.(*Transport).Base.TLSClientConfig, + Jar: client.Jar, } - return nil, fmt.Errorf("test bug: expected IDP login form") + if options.dialerHook != nil { + d, u = options.dialerHook(d, u) + } + conn, resp, err := d.DialContext(options.requestCtx, u.String(), nil) + if err != nil { + resp.Body.Close() + return fmt.Errorf("DialContext: %w", err) + } + defer conn.Close() + + return f(conn) } diff --git a/internal/testenv/upstreams/options.go b/internal/testenv/upstreams/options.go index c0bad32b5..3d58548bf 100644 --- a/internal/testenv/upstreams/options.go +++ b/internal/testenv/upstreams/options.go @@ -15,6 +15,7 @@ type CommonUpstreamOptions struct { type CommonUpstreamOption interface { GRPCUpstreamOption HTTPUpstreamOption + TCPUpstreamOption } type commonUpstreamOption func(o *CommonUpstreamOptions) @@ -25,6 +26,9 @@ func (c commonUpstreamOption) applyGRPC(o *GRPCUpstreamOptions) { c(&o.CommonUps // applyHTTP implements CommonUpstreamOption. func (c commonUpstreamOption) applyHTTP(o *HTTPUpstreamOptions) { c(&o.CommonUpstreamOptions) } +// applyTCP implements CommonUpstreamOption. +func (c commonUpstreamOption) applyTCP(o *TCPUpstreamOptions) { c(&o.CommonUpstreamOptions) } + func WithDisplayName(displayName string) CommonUpstreamOption { return commonUpstreamOption(func(o *CommonUpstreamOptions) { o.displayName = displayName diff --git a/internal/testenv/upstreams/tcp.go b/internal/testenv/upstreams/tcp.go new file mode 100644 index 000000000..6eab5b629 --- /dev/null +++ b/internal/testenv/upstreams/tcp.go @@ -0,0 +1,343 @@ +package upstreams + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/cookiejar" + "net/http/httptrace" + "net/url" + "sync" + + "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/values" + "go.opentelemetry.io/otel/attribute" + oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/net/http2" +) + +type TCPUpstream interface { + testenv.Upstream + + Handle(fn func(context.Context, net.Conn) error) + + Dial(r testenv.Route, fn func(context.Context, net.Conn) error, opts ...RequestOption) error +} + +type TCPUpstreamOptions struct { + CommonUpstreamOptions +} + +type TCPUpstreamOption interface { + applyTCP(*TCPUpstreamOptions) +} + +type tcpUpstream struct { + TCPUpstreamOptions + testenv.Aggregate + serverPort values.MutableValue[int] + serverHandler func(context.Context, net.Conn) error + + serverTracerProvider values.MutableValue[oteltrace.TracerProvider] + clientTracerProvider values.MutableValue[oteltrace.TracerProvider] + clientTracer values.Value[oteltrace.Tracer] +} + +func TCP(opts ...TCPUpstreamOption) TCPUpstream { + options := TCPUpstreamOptions{ + CommonUpstreamOptions: CommonUpstreamOptions{ + displayName: "TCP Upstream", + }, + } + for _, op := range opts { + op.applyTCP(&options) + } + up := &tcpUpstream{ + TCPUpstreamOptions: options, + serverPort: values.Deferred[int](), + + serverTracerProvider: values.Deferred[oteltrace.TracerProvider](), + clientTracerProvider: values.Deferred[oteltrace.TracerProvider](), + } + up.clientTracer = values.Bind(up.clientTracerProvider, func(tp oteltrace.TracerProvider) oteltrace.Tracer { + return tp.Tracer(trace.PomeriumCoreTracer) + }) + up.RecordCaller() + return up +} + +// Dial implements TCPUpstream. +func (t *tcpUpstream) Dial(r testenv.Route, clientHandler func(context.Context, net.Conn) error, opts ...RequestOption) error { + options := RequestOptions{ + requestCtx: t.Env().Context(), + dialProtocol: DialHTTP1, + } + options.apply(opts...) + u, err := url.Parse(r.URL().Value()) + if err != nil { + return err + } + + ctx, span := t.clientTracer.Value().Start(options.requestCtx, "tcpUpstream.Do", oteltrace.WithAttributes( + attribute.String("protocol", string(options.dialProtocol)), + attribute.String("url", u.String()), + )) + if options.path != "" || options.query != nil { + u = u.ResolveReference(&url.URL{ + Path: options.path, + RawQuery: options.query.Encode(), + }) + } + if options.trace != nil { + ctx = httptrace.WithClientTrace(ctx, options.trace) + } + options.requestCtx = ctx + defer span.End() + + var remoteConn *tls.Conn + remoteWriter := make(chan *io.PipeWriter, 1) + + connectURL := &url.URL{Scheme: "https", Host: u.Host, Path: u.Path} + + var getClientFn func(context.Context) *http.Client + var newRequestFn func(ctx context.Context) (*http.Request, error) + switch options.dialProtocol { + case DialHTTP1: + getClientFn = t.h1Dialer(&options, connectURL, &remoteConn) + newRequestFn = func(ctx context.Context) (*http.Request, error) { + req := (&http.Request{ + Method: http.MethodConnect, + URL: connectURL, + Host: u.Host, + }).WithContext(ctx) + return req, nil + } + case DialHTTP2: + getClientFn = t.h2Dialer(&options, connectURL, &remoteConn, remoteWriter) + newRequestFn = func(ctx context.Context) (*http.Request, error) { + req := (&http.Request{ + Method: http.MethodConnect, + URL: connectURL, + Host: u.Host, + Proto: "HTTP/2", + }).WithContext(ctx) + return req, nil + } + case DialHTTP3: + panic("not implemented") + } + resp, err := doAuthenticatedRequest(options.requestCtx, newRequestFn, getClientFn, &options) + if err != nil { + return err + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return errors.New(resp.Status) + } + if resp.Request.URL.Path == "/oidc/auth" { + if options.authenticateAs == "" { + return errors.New("test bug: unexpected IDP redirect; missing AuthenticateAs option to Dial()") + } + return errors.New("internal test bug: unexpected IDP redirect") + } + + var w io.WriteCloser = remoteConn + if options.dialProtocol == DialHTTP2 { + w = <-remoteWriter + } + + conn := NewRWConn(resp.Body, w) + defer conn.Close() + return clientHandler(resp.Request.Context(), conn) +} + +func (t *tcpUpstream) h1Dialer( + options *RequestOptions, + connectURL *url.URL, + remoteConn **tls.Conn, +) func(context.Context) *http.Client { + jar, _ := cookiejar.New(nil) + return func(context.Context) *http.Client { + tlsConfig := &tls.Config{ + RootCAs: t.Env().ServerCAs(), + Certificates: options.clientCerts, + NextProtos: []string{"http/1.1"}, + } + client := &http.Client{ + Transport: &http.Transport{ + DisableKeepAlives: true, + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if *remoteConn != nil { + (*remoteConn).Close() + *remoteConn = nil + } + dialer := &tls.Dialer{ + Config: tlsConfig, + } + cc, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrRetry, err) + } + protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol + if protocol != "http/1.1" { + cc.Close() + return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol) + } + *remoteConn = cc.(*tls.Conn) + return cc, nil + }, + TLSClientConfig: tlsConfig, // important + }, + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.String() == connectURL.String() && req.Method == http.MethodGet { + req.Method = http.MethodConnect + } + return nil + }, + Jar: jar, + } + return client + } +} + +func (t *tcpUpstream) h2Dialer( + options *RequestOptions, + connectURL *url.URL, + remoteConn **tls.Conn, + writer chan<- *io.PipeWriter, +) func(context.Context) *http.Client { + jar, _ := cookiejar.New(nil) + return func(context.Context) *http.Client { + h1 := &http.Transport{ + ForceAttemptHTTP2: true, + DisableKeepAlives: true, + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if *remoteConn != nil { + (*remoteConn).Close() + *remoteConn = nil + } + dialer := &tls.Dialer{ + Config: &tls.Config{ + RootCAs: t.Env().ServerCAs(), + Certificates: options.clientCerts, + NextProtos: []string{"h2"}, + }, + } + cc, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrRetry, err) + } + protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol + if protocol != "h2" { + cc.Close() + return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol) + } + *remoteConn = cc.(*tls.Conn) + + return cc, nil + }, + TLSClientConfig: &tls.Config{ + RootCAs: t.Env().ServerCAs(), + Certificates: options.clientCerts, + NextProtos: []string{"h2"}, + }, + } + if err := http2.ConfigureTransport(h1); err != nil { + panic(err) + } + client := &http.Client{ + Transport: h1, + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.String() == connectURL.String() && req.Method == http.MethodGet { + pr, pw := io.Pipe() + req.Method = http.MethodConnect + req.Body = pr + req.ContentLength = -1 + writer <- pw + } + return nil + }, + Jar: jar, + } + return client + } +} + +// Handle implements TCPUpstream. +func (t *tcpUpstream) Handle(fn func(context.Context, net.Conn) error) { + t.serverHandler = fn +} + +// Port implements TCPUpstream. +func (t *tcpUpstream) Addr() values.Value[string] { + return values.Bind(t.serverPort, func(port int) string { + return fmt.Sprintf("%s:%d", t.Env().Host(), port) + }) +} + +// Route implements TCPUpstream. +func (t *tcpUpstream) Route() testenv.RouteStub { + r := &testenv.TCPRoute{} + r.To(values.Bind(t.serverPort, func(port int) string { + return fmt.Sprintf("tcp://%s:%d", t.Env().Host(), port) + })) + t.Add(r) + return r +} + +// Run implements TCPUpstream. +func (t *tcpUpstream) Run(ctx context.Context) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", fmt.Sprintf("%s:0", t.Env().Host())) + if err != nil { + return err + } + context.AfterFunc(ctx, func() { + listener.Close() + }) + t.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port) + if t.serverTracerProviderOverride != nil { + t.serverTracerProvider.Resolve(t.serverTracerProviderOverride) + } else { + t.serverTracerProvider.Resolve(trace.NewTracerProvider(ctx, t.displayName)) + } + if t.clientTracerProviderOverride != nil { + t.clientTracerProvider.Resolve(t.clientTracerProviderOverride) + } else { + t.clientTracerProvider.Resolve(trace.NewTracerProvider(ctx, "TCP Client")) + } + var wg sync.WaitGroup + defer wg.Wait() + + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + cancel() + return nil + } + continue + } + wg.Add(1) + go func() { + defer wg.Done() + if err := t.serverHandler(ctx, conn); err != nil { + if errors.Is(err, io.EOF) { + return + } + panic("server handler error: " + err.Error()) + } + }() + } +} + +var ( + _ testenv.Upstream = (*tcpUpstream)(nil) + _ TCPUpstream = (*tcpUpstream)(nil) +) diff --git a/internal/testenv/upstreams/util.go b/internal/testenv/upstreams/util.go new file mode 100644 index 000000000..cc5b3cc2f --- /dev/null +++ b/internal/testenv/upstreams/util.go @@ -0,0 +1,195 @@ +package upstreams + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/pomerium/pomerium/integration/forms" + "github.com/pomerium/pomerium/internal/retry" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" + oteltrace "go.opentelemetry.io/otel/trace" + "google.golang.org/protobuf/proto" +) + +var ErrRetry = errors.New("error") + +func doAuthenticatedRequest( + ctx context.Context, + newRequest func(context.Context) (*http.Request, error), + getClient func(context.Context) *http.Client, + options *RequestOptions, +) (*http.Response, error) { + var resp *http.Response + resendCount := 0 + client := getClient(ctx) + + if err := retry.Retry(ctx, "http", func(ctx context.Context) error { + req, err := newRequest(ctx) + if err != nil { + return retry.NewTerminalError(err) + } + + switch body := options.body.(type) { + case string: + req.Body = io.NopCloser(strings.NewReader(body)) + case []byte: + req.Body = io.NopCloser(bytes.NewReader(body)) + case io.Reader: + req.Body = io.NopCloser(body) + case proto.Message: + buf, err := proto.Marshal(body) + if err != nil { + return retry.NewTerminalError(err) + } + req.Body = io.NopCloser(bytes.NewReader(buf)) + req.Header.Set("Content-Type", "application/octet-stream") + default: + buf, err := json.Marshal(body) + if err != nil { + panic(fmt.Sprintf("unsupported body type: %T", body)) + } + req.Body = io.NopCloser(bytes.NewReader(buf)) + req.Header.Set("Content-Type", "application/json") + case nil: + } + + if options.headers != nil && req.Header == nil { + req.Header = http.Header{} + } + for k, v := range options.headers { + req.Header.Add(k, v) + } + + if options.authenticateAs != "" { + resp, err = authenticateFlow(ctx, client, req, options.authenticateAs, true) //nolint:bodyclose + } else { + resp, err = client.Do(req) //nolint:bodyclose + } + // retry on connection refused + span := oteltrace.SpanFromContext(ctx) + if err != nil { + span.RecordError(err) + var opErr *net.OpError + if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" { + span.AddEvent("Retrying on dial error") + return err + } + return retry.NewTerminalError(err) + } + if resp.StatusCode/100 == 5 { + resendCount++ + _, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + span.SetAttributes(semconv.HTTPRequestResendCount(resendCount)) + span.AddEvent("Retrying on 5xx error", oteltrace.WithAttributes( + attribute.String("status", resp.Status), + )) + return errors.New(http.StatusText(resp.StatusCode)) + } + span.SetStatus(codes.Ok, "request completed successfully") + return nil + }, + retry.WithInitialInterval(1*time.Millisecond), + retry.WithMaxInterval(100*time.Millisecond), + ); err != nil { + return nil, err + } + return resp, nil +} + +func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string, checkLocation bool) (*http.Response, error) { + span := oteltrace.SpanFromContext(ctx) + var res *http.Response + originalHostname := req.URL.Hostname() + res, err := client.Do(req) + if err != nil { + span.RecordError(err) + return nil, err + } + location := res.Request.URL + if checkLocation && location.Hostname() == originalHostname { + // already authenticated + span.SetStatus(codes.Ok, "already authenticated") + return res, nil + } + fs := forms.Parse(res.Body) + _, _ = io.ReadAll(res.Body) + _ = res.Body.Close() + if len(fs) > 0 { + f := fs[0] + f.Inputs["email"] = email + f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds())) + span.AddEvent("submitting form", oteltrace.WithAttributes(attribute.String("location", location.String()))) + formReq, err := f.NewRequestWithContext(ctx, location) + if err != nil { + span.RecordError(err) + return nil, err + } + resp, err := client.Do(formReq) + if err != nil { + span.RecordError(err) + return nil, err + } + span.SetStatus(codes.Ok, "form submitted successfully") + return resp, nil + } + return nil, fmt.Errorf("test bug: expected IDP login form") +} + +type rwConn struct { + serverReader io.ReadCloser + serverWriter io.WriteCloser + + net.Conn + remote net.Conn + + closeOnce sync.Once + wg *sync.WaitGroup +} + +func NewRWConn(reader io.ReadCloser, writer io.WriteCloser) net.Conn { + rwc := &rwConn{ + serverReader: reader, + serverWriter: writer, + wg: &sync.WaitGroup{}, + } + + rwc.Conn, rwc.remote = net.Pipe() + rwc.wg.Add(2) + go func() { + defer rwc.wg.Done() + _, _ = io.Copy(rwc.remote, rwc.serverReader) + rwc.remote.Close() + }() + go func() { + defer rwc.wg.Done() + _, _ = io.Copy(rwc.serverWriter, rwc.remote) + rwc.serverWriter.Close() + }() + return rwc +} + +func (rwc *rwConn) Close() error { + var err error + rwc.closeOnce.Do(func() { + readerErr := rwc.serverReader.Close() + localErr := rwc.Conn.Close() + rwc.wg.Wait() + err = errors.Join(localErr, readerErr) + }) + return err +} + +var _ net.Conn = (*rwConn)(nil) diff --git a/pkg/storage/postgres/tracing_test.go b/pkg/storage/postgres/tracing_test.go index 3f00eb9f6..7e966ea30 100644 --- a/pkg/storage/postgres/tracing_test.go +++ b/pkg/storage/postgres/tracing_test.go @@ -48,7 +48,7 @@ func TestQueryTracing(t *testing.T) { snippets.WaitStartupComplete(env) resp, err := up.Get(route, upstreams.AuthenticateAs("user@example.com"), upstreams.Path("/foo")) - assert.NoError(t, err) + require.NoError(t, err) io.ReadAll(resp.Body) resp.Body.Close() From bc263e3ee58705b1bc7b3c6cab2370d719d12b45 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Thu, 20 Mar 2025 09:50:22 -0600 Subject: [PATCH 3/7] proxy: use querier cache for user info (#5532) --- authorize/authorize.go | 5 +- authorize/databroker.go | 50 +------------ authorize/databroker_test.go | 37 ---------- authorize/grpc.go | 14 +--- pkg/grpc/databroker/databroker.go | 30 -------- pkg/storage/cache.go | 3 + pkg/storage/querier.go | 112 ++++++++++++++++++++++++++++++ pkg/storage/querier_test.go | 101 +++++++++++++++++++++++++++ proxy/data.go | 25 +++---- proxy/data_test.go | 8 ++- proxy/proxy.go | 16 +++++ proxy/state.go | 14 ++-- 12 files changed, 259 insertions(+), 156 deletions(-) create mode 100644 pkg/storage/querier_test.go diff --git a/authorize/authorize.go b/authorize/authorize.go index 82ab173a3..f6ed77525 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -6,7 +6,6 @@ import ( "context" "fmt" "slices" - "time" "github.com/rs/zerolog" oteltrace "go.opentelemetry.io/otel/trace" @@ -31,7 +30,6 @@ type Authorize struct { store *store.Store currentConfig *atomicutil.Value[*config.Config] accessTracker *AccessTracker - globalCache storage.Cache groupsCacheWarmer *cacheWarmer tracerProvider oteltrace.TracerProvider @@ -45,7 +43,6 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) { a := &Authorize{ currentConfig: atomicutil.NewValue(&config.Config{Options: new(config.Options)}), store: store.New(), - globalCache: storage.NewGlobalCache(time.Minute), tracerProvider: tracerProvider, tracer: tracer, } @@ -57,7 +54,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) { } a.state = atomicutil.NewValue(state) - a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, a.globalCache, directory.GroupRecordType) + a.groupsCacheWarmer = newCacheWarmer(state.dataBrokerClientConnection, storage.GlobalCache, directory.GroupRecordType) return a, nil } diff --git a/authorize/databroker.go b/authorize/databroker.go index 2c59e4c30..1e474e792 100644 --- a/authorize/databroker.go +++ b/authorize/databroker.go @@ -3,9 +3,6 @@ package authorize import ( "context" - "google.golang.org/grpc" - - "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/grpcutil" @@ -18,47 +15,6 @@ type sessionOrServiceAccount interface { Validate() error } -func getDataBrokerRecord( - ctx context.Context, - recordType string, - recordID string, - lowestRecordVersion uint64, -) (*databroker.Record, error) { - q := storage.GetQuerier(ctx) - - req := &databroker.QueryRequest{ - Type: recordType, - Limit: 1, - } - req.SetFilterByIDOrIndex(recordID) - - res, err := q.Query(ctx, req, grpc.WaitForReady(true)) - if err != nil { - return nil, err - } - if len(res.GetRecords()) == 0 { - return nil, storage.ErrNotFound - } - - // if the current record version is less than the lowest we'll accept, invalidate the cache - if res.GetRecords()[0].GetVersion() < lowestRecordVersion { - q.InvalidateCache(ctx, req) - } else { - return res.GetRecords()[0], nil - } - - // retry with an up to date cache - res, err = q.Query(ctx, req) - if err != nil { - return nil, err - } - if len(res.GetRecords()) == 0 { - return nil, storage.ErrNotFound - } - - return res.GetRecords()[0], nil -} - func (a *Authorize) getDataBrokerSessionOrServiceAccount( ctx context.Context, sessionID string, @@ -67,9 +23,9 @@ func (a *Authorize) getDataBrokerSessionOrServiceAccount( ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerSessionOrServiceAccount") defer span.End() - record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion) + record, err := storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, dataBrokerRecordVersion) if storage.IsNotFound(err) { - record, err = getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion) + record, err = storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.ServiceAccount)), sessionID, dataBrokerRecordVersion) } if err != nil { return nil, err @@ -100,7 +56,7 @@ func (a *Authorize) getDataBrokerUser( ctx, span := a.tracer.Start(ctx, "authorize.getDataBrokerUser") defer span.End() - record, err := getDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0) + record, err := storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(user.User)), userID, 0) if err != nil { return nil, err } diff --git a/authorize/databroker_test.go b/authorize/databroker_test.go index eefb5987b..f8f47bef7 100644 --- a/authorize/databroker_test.go +++ b/authorize/databroker_test.go @@ -2,7 +2,6 @@ package authorize import ( "context" - "fmt" "testing" "time" @@ -12,45 +11,9 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/pkg/grpc/session" - "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/storage" ) -func Test_getDataBrokerRecord(t *testing.T) { - t.Parallel() - - ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) - t.Cleanup(clearTimeout) - - for _, tc := range []struct { - name string - recordVersion, queryVersion uint64 - underlyingQueryCount, cachedQueryCount int - }{ - {"cached", 1, 1, 1, 2}, - {"invalidated", 1, 2, 3, 4}, - } { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)} - - sq := storage.NewStaticQuerier(s1) - cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute)) - qctx := storage.WithQuerier(ctx, cq) - - s, err := getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) - assert.NoError(t, err) - assert.NotNil(t, s) - - s, err = getDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) - assert.NoError(t, err) - assert.NotNil(t, s) - }) - } -} - func TestAuthorize_getDataBrokerSessionOrServiceAccount(t *testing.T) { t.Parallel() diff --git a/authorize/grpc.go b/authorize/grpc.go index e5a1980b5..94c172b30 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -36,7 +36,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe querier := storage.NewCachingQuerier( storage.NewQuerier(a.state.Load().dataBrokerClient), - a.globalCache, + storage.GlobalCache, ) ctx = storage.WithQuerier(ctx, querier) @@ -98,7 +98,7 @@ func (a *Authorize) loadSession( // attempt to create a session from an incoming idp token s, err = config.NewIncomingIDPTokenSessionCreator( func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) { - return getDataBrokerRecord(ctx, recordType, recordID, 0) + return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0) }, func(ctx context.Context, records []*databroker.Record) error { _, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{ @@ -107,15 +107,7 @@ func (a *Authorize) loadSession( if err != nil { return err } - // invalidate cache - for _, record := range records { - q := &databroker.QueryRequest{ - Type: record.GetType(), - Limit: 1, - } - q.SetFilterByIDOrIndex(record.GetId()) - storage.GetQuerier(ctx).InvalidateCache(ctx, q) - } + storage.InvalidateCacheForDataBrokerRecords(ctx, records...) return nil }, ).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq) diff --git a/pkg/grpc/databroker/databroker.go b/pkg/grpc/databroker/databroker.go index 26841ca7f..d3933dbfa 100644 --- a/pkg/grpc/databroker/databroker.go +++ b/pkg/grpc/databroker/databroker.go @@ -3,14 +3,12 @@ package databroker import ( "context" - "encoding/json" "errors" "fmt" "io" "google.golang.org/grpc/codes" status "google.golang.org/grpc/status" - "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" structpb "google.golang.org/protobuf/types/known/structpb" @@ -53,34 +51,6 @@ func Get(ctx context.Context, client DataBrokerServiceClient, object recordObjec return res.GetRecord().GetData().UnmarshalTo(object) } -// GetViaJSON gets a record from the databroker, marshals it to JSON, and then unmarshals it to the given type. -func GetViaJSON[T any](ctx context.Context, client DataBrokerServiceClient, recordType, recordID string) (*T, error) { - res, err := client.Get(ctx, &GetRequest{ - Type: recordType, - Id: recordID, - }) - if err != nil { - return nil, err - } - - msg, err := res.GetRecord().GetData().UnmarshalNew() - if err != nil { - return nil, err - } - - bs, err := protojson.Marshal(msg) - if err != nil { - return nil, err - } - - var obj T - err = json.Unmarshal(bs, &obj) - if err != nil { - return nil, err - } - return &obj, nil -} - // Put puts a record into the databroker. func Put(ctx context.Context, client DataBrokerServiceClient, objects ...recordObject) (*PutResponse, error) { records := make([]*Record, len(objects)) diff --git a/pkg/storage/cache.go b/pkg/storage/cache.go index 1f0e391b4..3d8d563e6 100644 --- a/pkg/storage/cache.go +++ b/pkg/storage/cache.go @@ -107,3 +107,6 @@ func (cache *globalCache) set(expiry time.Time, key, value []byte) { cache.fastcache.Set(key, item) cache.mu.Unlock() } + +// GlobalCache is a global cache with a TTL of one minute. +var GlobalCache = NewGlobalCache(time.Minute) diff --git a/pkg/storage/querier.go b/pkg/storage/querier.go index c70247e5c..ddda78769 100644 --- a/pkg/storage/querier.go +++ b/pkg/storage/querier.go @@ -15,6 +15,7 @@ import ( "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/protoutil" ) @@ -222,3 +223,114 @@ func MarshalQueryResponse(res *databroker.QueryResponse) ([]byte, error) { Deterministic: true, }).Marshal(res) } + +// GetDataBrokerRecord uses a querier to get a databroker record. +func GetDataBrokerRecord( + ctx context.Context, + recordType string, + recordID string, + lowestRecordVersion uint64, +) (*databroker.Record, error) { + q := GetQuerier(ctx) + + req := &databroker.QueryRequest{ + Type: recordType, + Limit: 1, + } + req.SetFilterByIDOrIndex(recordID) + + res, err := q.Query(ctx, req, grpc.WaitForReady(true)) + if err != nil { + return nil, err + } + if len(res.GetRecords()) == 0 { + return nil, ErrNotFound + } + + // if the current record version is less than the lowest we'll accept, invalidate the cache + if res.GetRecords()[0].GetVersion() < lowestRecordVersion { + q.InvalidateCache(ctx, req) + } else { + return res.GetRecords()[0], nil + } + + // retry with an up to date cache + res, err = q.Query(ctx, req) + if err != nil { + return nil, err + } + if len(res.GetRecords()) == 0 { + return nil, ErrNotFound + } + + return res.GetRecords()[0], nil +} + +// GetDataBrokerMessage gets a databroker record and converts it into the message type. +func GetDataBrokerMessage[T any, TMessage interface { + *T + proto.Message +}]( + ctx context.Context, + recordID string, + lowestRecordVersion uint64, +) (TMessage, error) { + var msg T + + record, err := GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(TMessage(&msg)), recordID, lowestRecordVersion) + if err != nil { + return nil, err + } + + err = record.GetData().UnmarshalTo(TMessage(&msg)) + if err != nil { + return nil, err + } + + return TMessage(&msg), nil +} + +// GetDataBrokerObjectViaJSON gets a databroker record and converts it into the object type by going through protojson. +func GetDataBrokerObjectViaJSON[T any]( + ctx context.Context, + recordType string, + recordID string, + lowestRecordVersion uint64, +) (*T, error) { + record, err := GetDataBrokerRecord(ctx, recordType, recordID, lowestRecordVersion) + if err != nil { + return nil, err + } + + msg, err := record.GetData().UnmarshalNew() + if err != nil { + return nil, err + } + + bs, err := protojson.Marshal(msg) + if err != nil { + return nil, err + } + + var obj T + err = json.Unmarshal(bs, &obj) + if err != nil { + return nil, err + } + return &obj, nil +} + +// InvalidateCacheForDataBrokerRecords invalidates the cache of the querier for the databroker records. +func InvalidateCacheForDataBrokerRecords( + ctx context.Context, + records ...*databroker.Record, +) { + for _, record := range records { + q := &databroker.QueryRequest{ + Type: record.GetType(), + Limit: 1, + } + q.SetFilterByIDOrIndex(record.GetId()) + GetQuerier(ctx).InvalidateCache(ctx, q) + } +} diff --git a/pkg/storage/querier_test.go b/pkg/storage/querier_test.go new file mode 100644 index 000000000..904f261cd --- /dev/null +++ b/pkg/storage/querier_test.go @@ -0,0 +1,101 @@ +package storage_test + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/pomerium/datasource/pkg/directory" + "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpcutil" + "github.com/pomerium/pomerium/pkg/storage" +) + +func TestGetDataBrokerRecord(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + t.Cleanup(clearTimeout) + + for _, tc := range []struct { + name string + recordVersion, queryVersion uint64 + underlyingQueryCount, cachedQueryCount int + }{ + {"cached", 1, 1, 1, 2}, + {"invalidated", 1, 2, 3, 4}, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + s1 := &session.Session{Id: "s1", Version: fmt.Sprint(tc.recordVersion)} + + sq := storage.NewStaticQuerier(s1) + cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute)) + qctx := storage.WithQuerier(ctx, cq) + + s, err := storage.GetDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) + assert.NoError(t, err) + assert.NotNil(t, s) + + s, err = storage.GetDataBrokerRecord(qctx, grpcutil.GetTypeURL(s1), s1.GetId(), tc.queryVersion) + assert.NoError(t, err) + assert.NotNil(t, s) + }) + } +} + +func TestGetDataBrokerMessage(t *testing.T) { + t.Parallel() + + ctx := testutil.GetContext(t, time.Minute) + + s1 := &session.Session{Id: "s1"} + sq := storage.NewStaticQuerier(s1) + cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute)) + qctx := storage.WithQuerier(ctx, cq) + + s2, err := storage.GetDataBrokerMessage[session.Session](qctx, "s1", 0) + assert.NoError(t, err) + assert.Empty(t, cmp.Diff(s1, s2, protocmp.Transform())) + + _, err = storage.GetDataBrokerMessage[session.Session](qctx, "s2", 0) + assert.ErrorIs(t, err, storage.ErrNotFound) +} + +func TestGetDataBrokerObjectViaJSON(t *testing.T) { + t.Parallel() + + ctx := testutil.GetContext(t, time.Minute) + + du1 := &directory.User{ + ID: "u1", + Email: "u1@example.com", + DisplayName: "User 1!", + } + sq := storage.NewStaticQuerier(newDirectoryUserRecord(du1)) + cq := storage.NewCachingQuerier(sq, storage.NewGlobalCache(time.Minute)) + qctx := storage.WithQuerier(ctx, cq) + + du2, err := storage.GetDataBrokerObjectViaJSON[directory.User](qctx, directory.UserRecordType, "u1", 0) + assert.NoError(t, err) + assert.Empty(t, cmp.Diff(du1, du2, protocmp.Transform())) +} + +func newDirectoryUserRecord(directoryUser *directory.User) *databroker.Record { + m := map[string]any{} + bs, _ := json.Marshal(directoryUser) + _ = json.Unmarshal(bs, &m) + s, _ := structpb.NewStruct(m) + return storage.NewStaticRecord(directory.UserRecordType, s) +} diff --git a/proxy/data.go b/proxy/data.go index 3a63c3d3d..3fad0f412 100644 --- a/proxy/data.go +++ b/proxy/data.go @@ -9,28 +9,24 @@ import ( "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/handlers/webauthn" "github.com/pomerium/pomerium/internal/urlutil" - "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/webauthnutil" ) func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Session, isImpersonated bool, err error) { - client := p.state.Load().dataBrokerClient - isImpersonated = false - s, err = session.Get(ctx, client, sessionID) + s, err = storage.GetDataBrokerMessage[session.Session](ctx, sessionID, 0) if s.GetImpersonateSessionId() != "" { - s, err = session.Get(ctx, client, s.GetImpersonateSessionId()) + s, err = storage.GetDataBrokerMessage[session.Session](ctx, s.GetImpersonateSessionId(), 0) isImpersonated = true } - return s, isImpersonated, err } func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error) { - client := p.state.Load().dataBrokerClient - return user.Get(ctx, client, userID) + return storage.GetDataBrokerMessage[user.User](ctx, userID, 0) } func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData { @@ -72,21 +68,16 @@ func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData { } func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.UserInfoData) { - client := p.state.Load().dataBrokerClient - - res, _ := client.Get(ctx, &databroker.GetRequest{ - Type: "type.googleapis.com/pomerium.config.Config", - Id: "dashboard-settings", - }) - data.IsEnterprise = res.GetRecord() != nil + record, _ := storage.GetDataBrokerRecord(ctx, "type.googleapis.com/pomerium.config.Config", "dashboard-settings", 0) + data.IsEnterprise = record != nil if !data.IsEnterprise { return } - data.DirectoryUser, _ = databroker.GetViaJSON[directory.User](ctx, client, directory.UserRecordType, data.Session.GetUserId()) + data.DirectoryUser, _ = storage.GetDataBrokerObjectViaJSON[directory.User](ctx, directory.UserRecordType, data.Session.GetUserId(), 0) if data.DirectoryUser != nil { for _, groupID := range data.DirectoryUser.GroupIDs { - directoryGroup, _ := databroker.GetViaJSON[directory.Group](ctx, client, directory.GroupRecordType, groupID) + directoryGroup, _ := storage.GetDataBrokerObjectViaJSON[directory.Group](ctx, directory.GroupRecordType, groupID, 0) if directoryGroup != nil { data.DirectoryGroups = append(data.DirectoryGroups, directoryGroup) } diff --git a/proxy/data_test.go b/proxy/data_test.go index 8f9042c12..2ff111ab0 100644 --- a/proxy/data_test.go +++ b/proxy/data_test.go @@ -25,6 +25,7 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/protoutil" + "github.com/pomerium/pomerium/pkg/storage" ) func Test_getUserInfoData(t *testing.T) { @@ -65,6 +66,7 @@ func Test_getUserInfoData(t *testing.T) { proxy, err := New(ctx, &config.Config{Options: opts}) require.NoError(t, err) proxy.state.Load().dataBrokerClient = client + ctx = storage.WithQuerier(ctx, storage.NewQuerier(client)) require.NoError(t, databrokerpb.PutMulti(ctx, client, makeRecord(&session.Session{ @@ -81,7 +83,7 @@ func Test_getUserInfoData(t *testing.T) { "group_ids": []any{"G1", "G2", "G3"}, }))) - r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) + r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/.pomerium/", nil) r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{ ID: "S1", })) @@ -89,7 +91,9 @@ func Test_getUserInfoData(t *testing.T) { assert.Equal(t, "S1", data.Session.Id) assert.Equal(t, "U1", data.User.Id) assert.True(t, data.IsEnterprise) - assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs) + if assert.NotNil(t, data.DirectoryUser) { + assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs) + } }) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 3926d3589..d48977e25 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -21,6 +21,7 @@ import ( "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/proxy/portal" ) @@ -124,6 +125,8 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error { r.StrictSlash(true) // dashboard handlers are registered to all routes r = p.registerDashboardHandlers(r, opts) + // attach the querier to the context + r.Use(p.querierMiddleware) r.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(p.tracerProvider))) p.currentRouter.Store(r) @@ -133,3 +136,16 @@ func (p *Proxy) setHandlers(ctx context.Context, opts *config.Options) error { func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { p.currentRouter.Load().ServeHTTP(w, r) } + +func (p *Proxy) querierMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx = storage.WithQuerier(ctx, storage.NewCachingQuerier( + storage.NewQuerier(p.state.Load().dataBrokerClient), + storage.GlobalCache, + )) + r = r.WithContext(ctx) + + h.ServeHTTP(w, r) + }) +} diff --git a/proxy/state.go b/proxy/state.go index dab2563d4..110459293 100644 --- a/proxy/state.go +++ b/proxy/state.go @@ -13,6 +13,7 @@ import ( "github.com/pomerium/pomerium/internal/authenticateflow" "github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/storage" ) var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) @@ -87,19 +88,16 @@ func newProxyStateFromConfig(ctx context.Context, tracerProvider oteltrace.Trace state.incomingIDPTokenSessionCreator = config.NewIncomingIDPTokenSessionCreator( func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) { - res, err := state.dataBrokerClient.Get(ctx, &databroker.GetRequest{ - Type: recordType, - Id: recordID, - }) - if err != nil { - return nil, err - } - return res.GetRecord(), nil + return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0) }, func(ctx context.Context, records []*databroker.Record) error { _, err := state.dataBrokerClient.Put(ctx, &databroker.PutRequest{ Records: records, }) + if err != nil { + return err + } + storage.InvalidateCacheForDataBrokerRecords(ctx, records...) return err }, ) From ab5f3ac7f3ba16bd3f4a8513803bc8e837387e63 Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Fri, 21 Mar 2025 11:14:50 -0400 Subject: [PATCH 4/7] core/envoyconfig: make adding ipv6 addresses to internal cidr list conditional on ipv6 support on the system (#5538) --- config/envoyconfig/acmetlsalpn_test.go | 4 +-- config/envoyconfig/bootstrap_test.go | 12 +++---- config/envoyconfig/builder.go | 23 +++++++------ config/envoyconfig/clusters_test.go | 8 ++--- config/envoyconfig/filters.go | 4 +-- config/envoyconfig/http_connection_manager.go | 32 +++++++++++-------- config/envoyconfig/listeners_envoy_admin.go | 2 +- config/envoyconfig/listeners_grpc.go | 2 +- config/envoyconfig/listeners_main.go | 2 +- config/envoyconfig/listeners_main_test.go | 2 +- config/envoyconfig/listeners_metrics.go | 2 +- config/envoyconfig/listeners_test.go | 6 ++-- config/envoyconfig/outbound.go | 2 +- config/envoyconfig/outbound_test.go | 2 +- .../envoyconfig/route_configurations_test.go | 2 +- .../main_http_connection_manager_filter.json | 8 ++--- .../metrics_http_connection_manager.json | 8 ++--- config/envoyconfig/tls_test.go | 2 +- go.mod | 10 +++--- go.sum | 24 +++++++------- internal/controlplane/server.go | 2 ++ pkg/envoy/resource_monitor_test.go | 2 +- 22 files changed, 86 insertions(+), 75 deletions(-) diff --git a/config/envoyconfig/acmetlsalpn_test.go b/config/envoyconfig/acmetlsalpn_test.go index 9af05afc3..33a3455e8 100644 --- a/config/envoyconfig/acmetlsalpn_test.go +++ b/config/envoyconfig/acmetlsalpn_test.go @@ -8,7 +8,7 @@ import ( ) func TestBuilder_buildACMETLSALPNCluster(t *testing.T) { - b := New("local-grpc", "local-http", "local-metrics", nil, nil) + b := New("local-grpc", "local-http", "local-metrics", nil, nil, true) testutil.AssertProtoJSONEqual(t, `{ "name": "pomerium-acme-tls-alpn", @@ -34,7 +34,7 @@ func TestBuilder_buildACMETLSALPNCluster(t *testing.T) { } func TestBuilder_buildACMETLSALPNFilterChain(t *testing.T) { - b := New("local-grpc", "local-http", "local-metrics", nil, nil) + b := New("local-grpc", "local-http", "local-metrics", nil, nil, true) testutil.AssertProtoJSONEqual(t, `{ "filterChainMatch": { diff --git a/config/envoyconfig/bootstrap_test.go b/config/envoyconfig/bootstrap_test.go index b90539c93..ab06b65a1 100644 --- a/config/envoyconfig/bootstrap_test.go +++ b/config/envoyconfig/bootstrap_test.go @@ -13,7 +13,7 @@ import ( func TestBuilder_BuildBootstrapAdmin(t *testing.T) { t.Setenv("TMPDIR", "/tmp") - b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) + b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true) t.Run("valid", func(t *testing.T) { adminCfg, err := b.BuildBootstrapAdmin(&config.Config{ Options: &config.Options{ @@ -35,7 +35,7 @@ func TestBuilder_BuildBootstrapAdmin(t *testing.T) { } func TestBuilder_BuildBootstrapLayeredRuntime(t *testing.T) { - b := New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil) + b := New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil, true) staticCfg, err := b.BuildBootstrapLayeredRuntime(context.Background(), &config.Config{}) assert.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` @@ -61,7 +61,7 @@ func TestBuilder_BuildBootstrapLayeredRuntime(t *testing.T) { func TestBuilder_BuildBootstrapStaticResources(t *testing.T) { t.Run("valid", func(t *testing.T) { - b := New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil) + b := New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil, true) staticCfg, err := b.BuildBootstrapStaticResources(context.Background(), &config.Config{}, false) assert.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` @@ -105,14 +105,14 @@ func TestBuilder_BuildBootstrapStaticResources(t *testing.T) { `, staticCfg) }) t.Run("bad gRPC address", func(t *testing.T) { - b := New("xyz:zyx", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil) + b := New("xyz:zyx", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil, true) _, err := b.BuildBootstrapStaticResources(context.Background(), &config.Config{}, false) assert.Error(t, err) }) } func TestBuilder_BuildBootstrapStatsConfig(t *testing.T) { - b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) + b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true) t.Run("valid", func(t *testing.T) { statsCfg, err := b.BuildBootstrapStatsConfig(&config.Config{ Options: &config.Options{ @@ -132,7 +132,7 @@ func TestBuilder_BuildBootstrapStatsConfig(t *testing.T) { } func TestBuilder_BuildBootstrap(t *testing.T) { - b := New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil) + b := New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil, true) t.Run("OverloadManager", func(t *testing.T) { bootstrap, err := b.BuildBootstrap(context.Background(), &config.Config{ Options: &config.Options{ diff --git a/config/envoyconfig/builder.go b/config/envoyconfig/builder.go index 5fbb30551..621d78c56 100644 --- a/config/envoyconfig/builder.go +++ b/config/envoyconfig/builder.go @@ -7,11 +7,12 @@ import ( // A Builder builds envoy config from pomerium config. type Builder struct { - localGRPCAddress string - localHTTPAddress string - localMetricsAddress string - filemgr *filemgr.Manager - reproxy *reproxy.Handler + localGRPCAddress string + localHTTPAddress string + localMetricsAddress string + filemgr *filemgr.Manager + reproxy *reproxy.Handler + addIPV6InternalRanges bool } // New creates a new Builder. @@ -21,15 +22,17 @@ func New( localMetricsAddress string, fileManager *filemgr.Manager, reproxyHandler *reproxy.Handler, + addIPV6InternalRanges bool, ) *Builder { if reproxyHandler == nil { reproxyHandler = reproxy.New() } return &Builder{ - localGRPCAddress: localGRPCAddress, - localHTTPAddress: localHTTPAddress, - localMetricsAddress: localMetricsAddress, - filemgr: fileManager, - reproxy: reproxyHandler, + localGRPCAddress: localGRPCAddress, + localHTTPAddress: localHTTPAddress, + localMetricsAddress: localMetricsAddress, + filemgr: fileManager, + reproxy: reproxyHandler, + addIPV6InternalRanges: addIPV6InternalRanges, } } diff --git a/config/envoyconfig/clusters_test.go b/config/envoyconfig/clusters_test.go index 91ed0828a..0bd18472c 100644 --- a/config/envoyconfig/clusters_test.go +++ b/config/envoyconfig/clusters_test.go @@ -27,7 +27,7 @@ func Test_BuildClusters(t *testing.T) { opts := config.NewDefaultOptions() ctx := context.Background() - b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) + b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true) clusters, err := b.BuildClusters(ctx, &config.Config{Options: opts}) require.NoError(t, err) testutil.AssertProtoJSONFileEqual(t, "testdata/clusters.json", clusters) @@ -38,7 +38,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { cacheDir, _ := os.UserCacheDir() customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-3133535332543131503345494c.pem") - b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) + b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true) rootCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{}}) rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename() @@ -517,7 +517,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { func Test_buildCluster(t *testing.T) { ctx := context.Background() - b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) + b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true) rootCABytes, _ := getCombinedCertificateAuthority(ctx, &config.Config{Options: &config.Options{}}) rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename() o1 := config.NewDefaultOptions() @@ -1012,7 +1012,7 @@ func Test_bindConfig(t *testing.T) { ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) defer clearTimeout() - b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) + b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true) t.Run("no bind config", func(t *testing.T) { cluster, err := b.buildPolicyCluster(ctx, &config.Config{Options: &config.Options{}}, &config.Policy{ From: "https://from.example.com", diff --git a/config/envoyconfig/filters.go b/config/envoyconfig/filters.go index e4c593b5e..594c99512 100644 --- a/config/envoyconfig/filters.go +++ b/config/envoyconfig/filters.go @@ -44,10 +44,10 @@ func ExtAuthzFilter(grpcClientTimeout *durationpb.Duration) *envoy_extensions_fi } // HTTPConnectionManagerFilter creates a new HTTP connection manager filter. -func HTTPConnectionManagerFilter( +func (b *Builder) HTTPConnectionManagerFilter( httpConnectionManager *envoy_extensions_filters_network_http_connection_manager.HttpConnectionManager, ) *envoy_config_listener_v3.Filter { - applyGlobalHTTPConnectionManagerOptions(httpConnectionManager) + b.applyGlobalHTTPConnectionManagerOptions(httpConnectionManager) return &envoy_config_listener_v3.Filter{ Name: "envoy.filters.network.http_connection_manager", ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{ diff --git a/config/envoyconfig/http_connection_manager.go b/config/envoyconfig/http_connection_manager.go index 78d37ae5e..4f9f34ba7 100644 --- a/config/envoyconfig/http_connection_manager.go +++ b/config/envoyconfig/http_connection_manager.go @@ -128,23 +128,29 @@ func (b *Builder) buildLocalReplyConfig( }, nil } -func applyGlobalHTTPConnectionManagerOptions(hcm *envoy_http_connection_manager.HttpConnectionManager) { +func (b *Builder) applyGlobalHTTPConnectionManagerOptions(hcm *envoy_http_connection_manager.HttpConnectionManager) { if hcm.InternalAddressConfig == nil { - // see doc comment on InternalAddressConfig for details - hcm.InternalAddressConfig = &envoy_http_connection_manager.HttpConnectionManager_InternalAddressConfig{ - CidrRanges: []*envoy_config_core_v3.CidrRange{ - // localhost - {AddressPrefix: "127.0.0.1", PrefixLen: wrapperspb.UInt32(32)}, + ranges := []*envoy_config_core_v3.CidrRange{ + // localhost + {AddressPrefix: "127.0.0.1", PrefixLen: wrapperspb.UInt32(32)}, + + // RFC1918 + {AddressPrefix: "10.0.0.0", PrefixLen: wrapperspb.UInt32(8)}, + {AddressPrefix: "192.168.0.0", PrefixLen: wrapperspb.UInt32(16)}, + {AddressPrefix: "172.16.0.0", PrefixLen: wrapperspb.UInt32(12)}, + } + if b.addIPV6InternalRanges { + ranges = append(ranges, []*envoy_config_core_v3.CidrRange{ + // Localhost IPv6 {AddressPrefix: "::1", PrefixLen: wrapperspb.UInt32(128)}, - - // RFC1918 - {AddressPrefix: "10.0.0.0", PrefixLen: wrapperspb.UInt32(8)}, - {AddressPrefix: "192.168.0.0", PrefixLen: wrapperspb.UInt32(16)}, - {AddressPrefix: "172.16.0.0", PrefixLen: wrapperspb.UInt32(12)}, - // RFC4193 {AddressPrefix: "fd00::", PrefixLen: wrapperspb.UInt32(8)}, - }, + }...) + } + + // see doc comment on InternalAddressConfig for details + hcm.InternalAddressConfig = &envoy_http_connection_manager.HttpConnectionManager_InternalAddressConfig{ + CidrRanges: ranges, } } } diff --git a/config/envoyconfig/listeners_envoy_admin.go b/config/envoyconfig/listeners_envoy_admin.go index f2bffb611..96f95bfe2 100644 --- a/config/envoyconfig/listeners_envoy_admin.go +++ b/config/envoyconfig/listeners_envoy_admin.go @@ -51,7 +51,7 @@ func (b *Builder) buildEnvoyAdminHTTPConnectionManagerFilter() *envoy_config_lis }, }}) - return HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{ + return b.HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{ CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO, StatPrefix: "envoy-admin", RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{ diff --git a/config/envoyconfig/listeners_grpc.go b/config/envoyconfig/listeners_grpc.go index 2f92b3de1..5a8295cdd 100644 --- a/config/envoyconfig/listeners_grpc.go +++ b/config/envoyconfig/listeners_grpc.go @@ -98,7 +98,7 @@ func (b *Builder) buildGRPCHTTPConnectionManagerFilter() *envoy_config_listener_ Routes: routes, }}) - return HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{ + return b.HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{ CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO, StatPrefix: "grpc_ingress", // limit request first byte to last byte time diff --git a/config/envoyconfig/listeners_main.go b/config/envoyconfig/listeners_main.go index 04d2b0ca2..875cc3357 100644 --- a/config/envoyconfig/listeners_main.go +++ b/config/envoyconfig/listeners_main.go @@ -233,7 +233,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter( } } - return HTTPConnectionManagerFilter(mgr), nil + return b.HTTPConnectionManagerFilter(mgr), nil } func newListenerAccessLog() *envoy_config_accesslog_v3.AccessLog { diff --git a/config/envoyconfig/listeners_main_test.go b/config/envoyconfig/listeners_main_test.go index 0a77de4a3..f39d98415 100644 --- a/config/envoyconfig/listeners_main_test.go +++ b/config/envoyconfig/listeners_main_test.go @@ -12,7 +12,7 @@ import ( ) func Test_requireProxyProtocol(t *testing.T) { - b := New("local-grpc", "local-http", "local-metrics", nil, nil) + b := New("local-grpc", "local-http", "local-metrics", nil, nil, true) t.Run("required", func(t *testing.T) { li, err := b.buildMainListener(context.Background(), &config.Config{Options: &config.Options{ UseProxyProtocol: true, diff --git a/config/envoyconfig/listeners_metrics.go b/config/envoyconfig/listeners_metrics.go index 5ceff5af1..2488dbc7a 100644 --- a/config/envoyconfig/listeners_metrics.go +++ b/config/envoyconfig/listeners_metrics.go @@ -121,7 +121,7 @@ func (b *Builder) buildMetricsHTTPConnectionManagerFilter() *envoy_config_listen }, }}) - return HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{ + return b.HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{ CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO, StatPrefix: "metrics", RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{ diff --git a/config/envoyconfig/listeners_test.go b/config/envoyconfig/listeners_test.go index a5a40c8bb..03b740f07 100644 --- a/config/envoyconfig/listeners_test.go +++ b/config/envoyconfig/listeners_test.go @@ -51,7 +51,7 @@ func TestBuildListeners(t *testing.T) { OutboundPort: "10003", MetricsPort: "10004", } - b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) + b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true) t.Run("enable grpc by default", func(t *testing.T) { cfg := cfg.Clone() lis, err := b.BuildListeners(ctx, cfg, false) @@ -125,7 +125,7 @@ func Test_buildMetricsHTTPConnectionManagerFilter(t *testing.T) { certFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-5a353247453159375849565a.pem") keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3159554e32473758435257364b.pem") - b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) + b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true) li, err := b.buildMetricsListener(&config.Config{ Options: &config.Options{ MetricsAddr: "127.0.0.1:9902", @@ -143,7 +143,7 @@ func Test_buildMetricsHTTPConnectionManagerFilter(t *testing.T) { } func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) { - b := New("local-grpc", "local-http", "local-metrics", nil, nil) + b := New("local-grpc", "local-http", "local-metrics", nil, nil, true) options := config.NewDefaultOptions() options.SkipXffAppend = true diff --git a/config/envoyconfig/outbound.go b/config/envoyconfig/outbound.go index fbab8c0d5..56d238e0a 100644 --- a/config/envoyconfig/outbound.go +++ b/config/envoyconfig/outbound.go @@ -42,7 +42,7 @@ func (b *Builder) buildOutboundListener(cfg *config.Config) (*envoy_config_liste func (b *Builder) buildOutboundHTTPConnectionManager() *envoy_config_listener_v3.Filter { rc := b.buildOutboundRouteConfiguration() - return HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{ + return b.HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{ CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO, StatPrefix: "grpc_egress", // limit request first byte to last byte time diff --git a/config/envoyconfig/outbound_test.go b/config/envoyconfig/outbound_test.go index 3c7233f7a..36a6c40b7 100644 --- a/config/envoyconfig/outbound_test.go +++ b/config/envoyconfig/outbound_test.go @@ -7,7 +7,7 @@ import ( ) func Test_buildOutboundRoutes(t *testing.T) { - b := New("local-grpc", "local-http", "local-metrics", nil, nil) + b := New("local-grpc", "local-http", "local-metrics", nil, nil, true) routes := b.buildOutboundRoutes() testutil.AssertProtoJSONEqual(t, `[ { diff --git a/config/envoyconfig/route_configurations_test.go b/config/envoyconfig/route_configurations_test.go index a0cbef8f6..b5f61709e 100644 --- a/config/envoyconfig/route_configurations_test.go +++ b/config/envoyconfig/route_configurations_test.go @@ -32,7 +32,7 @@ func TestBuilder_buildMainRouteConfiguration(t *testing.T) { }, }, }} - b := New("grpc", "http", "metrics", filemgr.NewManager(), nil) + b := New("grpc", "http", "metrics", filemgr.NewManager(), nil, true) routeConfiguration, err := b.buildMainRouteConfiguration(ctx, cfg) assert.NoError(t, err) testutil.AssertProtoJSONEqual(t, `{ diff --git a/config/envoyconfig/testdata/main_http_connection_manager_filter.json b/config/envoyconfig/testdata/main_http_connection_manager_filter.json index baa435d2b..e95e56803 100644 --- a/config/envoyconfig/testdata/main_http_connection_manager_filter.json +++ b/config/envoyconfig/testdata/main_http_connection_manager_filter.json @@ -231,10 +231,6 @@ "addressPrefix": "127.0.0.1", "prefixLen": 32 }, - { - "addressPrefix": "::1", - "prefixLen": 128 - }, { "addressPrefix": "10.0.0.0", "prefixLen": 8 @@ -247,6 +243,10 @@ "addressPrefix": "172.16.0.0", "prefixLen": 12 }, + { + "addressPrefix": "::1", + "prefixLen": 128 + }, { "addressPrefix": "fd00::", "prefixLen": 8 diff --git a/config/envoyconfig/testdata/metrics_http_connection_manager.json b/config/envoyconfig/testdata/metrics_http_connection_manager.json index f1123d42c..bdd5e7a3b 100644 --- a/config/envoyconfig/testdata/metrics_http_connection_manager.json +++ b/config/envoyconfig/testdata/metrics_http_connection_manager.json @@ -61,10 +61,6 @@ "addressPrefix": "127.0.0.1", "prefixLen": 32 }, - { - "addressPrefix": "::1", - "prefixLen": 128 - }, { "addressPrefix": "10.0.0.0", "prefixLen": 8 @@ -77,6 +73,10 @@ "addressPrefix": "172.16.0.0", "prefixLen": 12 }, + { + "addressPrefix": "::1", + "prefixLen": 128 + }, { "addressPrefix": "fd00::", "prefixLen": 8 diff --git a/config/envoyconfig/tls_test.go b/config/envoyconfig/tls_test.go index f5652a218..0abb09d33 100644 --- a/config/envoyconfig/tls_test.go +++ b/config/envoyconfig/tls_test.go @@ -82,7 +82,7 @@ func TestValidateCertificate(t *testing.T) { } func Test_buildDownstreamTLSContext(t *testing.T) { - b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) + b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil, true) cacheDir, _ := os.UserCacheDir() clientCAFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "client-ca-4e4c564e5a36544a4a33385a.pem") diff --git a/go.mod b/go.mod index 2e029b421..94464482b 100644 --- a/go.mod +++ b/go.mod @@ -88,11 +88,11 @@ require ( go.uber.org/automaxprocs v1.6.0 go.uber.org/mock v0.5.0 go.uber.org/zap v1.27.0 - golang.org/x/crypto v0.35.0 - golang.org/x/net v0.36.0 + golang.org/x/crypto v0.36.0 + golang.org/x/net v0.37.0 golang.org/x/oauth2 v0.27.0 - golang.org/x/sync v0.11.0 - golang.org/x/sys v0.30.0 + golang.org/x/sync v0.12.0 + golang.org/x/sys v0.31.0 golang.org/x/time v0.10.0 google.golang.org/api v0.223.0 google.golang.org/genproto/googleapis/rpc v0.0.0-20250219182151-9fdb1cabc7b2 @@ -239,7 +239,7 @@ require ( go.uber.org/zap/exp v0.3.0 // indirect golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect golang.org/x/mod v0.20.0 // indirect - golang.org/x/text v0.22.0 // indirect + golang.org/x/text v0.23.0 // indirect golang.org/x/tools v0.24.0 // indirect google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f // indirect diff --git a/go.sum b/go.sum index af70a918a..7595b3557 100644 --- a/go.sum +++ b/go.sum @@ -750,8 +750,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= -golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -827,8 +827,8 @@ golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.36.0 h1:vWF2fRbw4qslQsQzgFqZff+BItCvGFQqKzKIzx1rmoA= -golang.org/x/net v0.36.0/go.mod h1:bFmbeoIPfrw4sMHNhb4J9f6+tPziuGjq7Jk/38fxi1I= +golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= +golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -851,8 +851,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= -golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -912,15 +912,15 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= -golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= -golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU= -golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s= +golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= +golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -931,8 +931,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= -golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index 96bf1f3d4..9de4e29f4 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -13,6 +13,7 @@ import ( "github.com/rs/zerolog" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1" + "golang.org/x/net/nettest" "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/health/grpc_health_v1" @@ -177,6 +178,7 @@ func NewServer( srv.MetricsListener.Addr().String(), srv.filemgr, srv.reproxy, + nettest.SupportsIPv6(), ) res, err := srv.buildDiscoveryResources(ctx) diff --git a/pkg/envoy/resource_monitor_test.go b/pkg/envoy/resource_monitor_test.go index 8210f84e3..16581d63f 100644 --- a/pkg/envoy/resource_monitor_test.go +++ b/pkg/envoy/resource_monitor_test.go @@ -713,7 +713,7 @@ func TestSharedResourceMonitor(t *testing.T) { } func TestBootstrapConfig(t *testing.T) { - b := envoyconfig.New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil) + b := envoyconfig.New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil, true) testEnvoyPid := 99 tempDir := t.TempDir() monitor, err := NewSharedResourceMonitor(context.Background(), config.NewStaticSource(nil), tempDir, WithCgroupDriver(&cgroupV2Driver{ From a96ab2fe93ccc304cfb5f56944dc7c461aa24738 Mon Sep 17 00:00:00 2001 From: Joe Kralicky Date: Tue, 25 Mar 2025 10:43:04 -0400 Subject: [PATCH 5/7] move internal/telemetry/trace => pkg/telemetry/trace (#5541) --- authenticate/authenticate.go | 2 +- authenticate/handlers.go | 2 +- authorize/authorize.go | 2 +- authorize/evaluator/evaluator.go | 2 +- authorize/evaluator/headers_evaluator.go | 2 +- authorize/evaluator/policy_evaluator.go | 2 +- authorize/internal/store/store.go | 2 +- cmd/pomerium/main.go | 2 +- config/envoyconfig/bootstrap.go | 2 +- config/envoyconfig/clusters.go | 2 +- config/envoyconfig/listeners.go | 2 +- config/envoyconfig/route_configurations.go | 2 +- config/envoyconfig/tracing.go | 2 +- databroker/cache.go | 2 +- internal/authenticateflow/authenticateflow.go | 2 +- internal/authenticateflow/stateful.go | 2 +- internal/authenticateflow/stateless.go | 2 +- internal/controlplane/http.go | 2 +- internal/controlplane/server.go | 2 +- internal/databroker/config_source.go | 2 +- internal/databroker/server.go | 2 +- internal/testenv/environment.go | 2 +- internal/testenv/selftests/tracing_test.go | 2 +- internal/testenv/snippets/wait.go | 2 +- internal/testenv/upstreams/grpc.go | 2 +- internal/testenv/upstreams/http.go | 2 +- internal/testenv/upstreams/tcp.go | 2 +- internal/testutil/minio.go | 2 +- internal/testutil/postgres.go | 2 +- internal/testutil/tracetest/tracing.go | 2 +- pkg/cmd/pomerium/pomerium.go | 2 +- pkg/identity/oidc/oidc.go | 2 +- {internal => pkg}/telemetry/trace/carriers.go | 0 {internal => pkg}/telemetry/trace/carriers_test.go | 2 +- {internal => pkg}/telemetry/trace/client.go | 0 {internal => pkg}/telemetry/trace/client_test.go | 2 +- {internal => pkg}/telemetry/trace/debug.go | 0 {internal => pkg}/telemetry/trace/debug_test.go | 2 +- {internal => pkg}/telemetry/trace/global.go | 0 {internal => pkg}/telemetry/trace/global_test.go | 2 +- {internal => pkg}/telemetry/trace/main_test.go | 2 +- {internal => pkg}/telemetry/trace/middleware.go | 0 {internal => pkg}/telemetry/trace/middleware_test.go | 2 +- {internal => pkg}/telemetry/trace/server.go | 0 {internal => pkg}/telemetry/trace/trace.go | 0 {internal => pkg}/telemetry/trace/trace_export_test.go | 0 {internal => pkg}/telemetry/trace/util.go | 0 proxy/handlers.go | 2 +- proxy/proxy.go | 2 +- 49 files changed, 40 insertions(+), 40 deletions(-) rename {internal => pkg}/telemetry/trace/carriers.go (100%) rename {internal => pkg}/telemetry/trace/carriers_test.go (91%) rename {internal => pkg}/telemetry/trace/client.go (100%) rename {internal => pkg}/telemetry/trace/client_test.go (99%) rename {internal => pkg}/telemetry/trace/debug.go (100%) rename {internal => pkg}/telemetry/trace/debug_test.go (99%) rename {internal => pkg}/telemetry/trace/global.go (100%) rename {internal => pkg}/telemetry/trace/global_test.go (88%) rename {internal => pkg}/telemetry/trace/main_test.go (69%) rename {internal => pkg}/telemetry/trace/middleware.go (100%) rename {internal => pkg}/telemetry/trace/middleware_test.go (99%) rename {internal => pkg}/telemetry/trace/server.go (100%) rename {internal => pkg}/telemetry/trace/trace.go (100%) rename {internal => pkg}/telemetry/trace/trace_export_test.go (100%) rename {internal => pkg}/telemetry/trace/util.go (100%) diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 7cf18f0ab..9f56c6470 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -10,8 +10,8 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/telemetry/trace" oteltrace "go.opentelemetry.io/otel/trace" ) diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 5c54a806e..fe99d5c29 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -25,11 +25,11 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/sessions" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/identity" "github.com/pomerium/pomerium/pkg/identity/oidc" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) // Handler returns the authenticate service's handler chain. diff --git a/authorize/authorize.go b/authorize/authorize.go index f6ed77525..701113fed 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -18,10 +18,10 @@ import ( "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) // Authorize struct holds diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index e5fd9501d..0d13f166c 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -19,10 +19,10 @@ import ( "github.com/pomerium/pomerium/internal/errgrouputil" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/contextutil" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/policy/criteria" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) // Request contains the inputs needed for evaluation. diff --git a/authorize/evaluator/headers_evaluator.go b/authorize/evaluator/headers_evaluator.go index 4eaf32a06..0ed71863a 100644 --- a/authorize/evaluator/headers_evaluator.go +++ b/authorize/evaluator/headers_evaluator.go @@ -9,7 +9,7 @@ import ( "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) // HeadersResponse is the output from the headers.rego script. diff --git a/authorize/evaluator/policy_evaluator.go b/authorize/evaluator/policy_evaluator.go index 56a4f228d..f13c38481 100644 --- a/authorize/evaluator/policy_evaluator.go +++ b/authorize/evaluator/policy_evaluator.go @@ -11,11 +11,11 @@ import ( "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/contextutil" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/policy" "github.com/pomerium/pomerium/pkg/policy/criteria" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) // PolicyRequest is the input to policy evaluation. diff --git a/authorize/internal/store/store.go b/authorize/internal/store/store.go index 34f2c3676..6ea2426cb 100644 --- a/authorize/internal/store/store.go +++ b/authorize/internal/store/store.go @@ -21,9 +21,9 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) // A Store stores data for the OPA rego policy evaluation. diff --git a/cmd/pomerium/main.go b/cmd/pomerium/main.go index 3289a4391..f277ed5be 100644 --- a/cmd/pomerium/main.go +++ b/cmd/pomerium/main.go @@ -12,13 +12,13 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/version" _ "github.com/pomerium/pomerium/internal/zero/bootstrap/writers/filesystem" _ "github.com/pomerium/pomerium/internal/zero/bootstrap/writers/k8s" zero_cmd "github.com/pomerium/pomerium/internal/zero/cmd" "github.com/pomerium/pomerium/pkg/cmd/pomerium" "github.com/pomerium/pomerium/pkg/envoy/files" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) func main() { diff --git a/config/envoyconfig/bootstrap.go b/config/envoyconfig/bootstrap.go index 4e1e5228c..6e2401415 100644 --- a/config/envoyconfig/bootstrap.go +++ b/config/envoyconfig/bootstrap.go @@ -19,7 +19,7 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config/otelconfig" "github.com/pomerium/pomerium/internal/telemetry" - "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/structpb" ) diff --git a/config/envoyconfig/clusters.go b/config/envoyconfig/clusters.go index bfb083cc5..be095e3fe 100644 --- a/config/envoyconfig/clusters.go +++ b/config/envoyconfig/clusters.go @@ -20,8 +20,8 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) // BuildClusters builds envoy clusters from the given config. diff --git a/config/envoyconfig/listeners.go b/config/envoyconfig/listeners.go index e02b39e3e..c8676ee38 100644 --- a/config/envoyconfig/listeners.go +++ b/config/envoyconfig/listeners.go @@ -9,7 +9,7 @@ import ( "google.golang.org/protobuf/types/known/wrapperspb" "github.com/pomerium/pomerium/config" - "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) const listenerBufferLimit uint32 = 32 * 1024 diff --git a/config/envoyconfig/route_configurations.go b/config/envoyconfig/route_configurations.go index 9881e7322..dbf67a8fa 100644 --- a/config/envoyconfig/route_configurations.go +++ b/config/envoyconfig/route_configurations.go @@ -11,8 +11,8 @@ import ( "google.golang.org/protobuf/types/known/wrapperspb" "github.com/pomerium/pomerium/config" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) // BuildRouteConfigurations builds the route configurations for the RDS service. diff --git a/config/envoyconfig/tracing.go b/config/envoyconfig/tracing.go index 042108d93..a86b20788 100644 --- a/config/envoyconfig/tracing.go +++ b/config/envoyconfig/tracing.go @@ -17,7 +17,7 @@ import ( extensions_uuidx "github.com/pomerium/envoy-custom/api/extensions/request_id/uuidx" extensions_pomerium_otel "github.com/pomerium/envoy-custom/api/extensions/tracers/pomerium_otel" "github.com/pomerium/pomerium/config/otelconfig" - "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/wrapperspb" ) diff --git a/databroker/cache.go b/databroker/cache.go index a0f55c2c0..522291822 100644 --- a/databroker/cache.go +++ b/databroker/cache.go @@ -19,7 +19,6 @@ import ( "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/envoy/files" @@ -28,6 +27,7 @@ import ( "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/identity" "github.com/pomerium/pomerium/pkg/identity/manager" + "github.com/pomerium/pomerium/pkg/telemetry/trace" oteltrace "go.opentelemetry.io/otel/trace" ) diff --git a/internal/authenticateflow/authenticateflow.go b/internal/authenticateflow/authenticateflow.go index 2fd7edafe..3e6dcc6d8 100644 --- a/internal/authenticateflow/authenticateflow.go +++ b/internal/authenticateflow/authenticateflow.go @@ -14,10 +14,10 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/structpb" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/identity" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) // timeNow is time.Now but pulled out as a variable for tests. diff --git a/internal/authenticateflow/stateful.go b/internal/authenticateflow/stateful.go index 58d62d516..cf7c4bb2c 100644 --- a/internal/authenticateflow/stateful.go +++ b/internal/authenticateflow/stateful.go @@ -23,7 +23,6 @@ import ( "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/sessions" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc" @@ -33,6 +32,7 @@ import ( "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/identity" "github.com/pomerium/pomerium/pkg/identity/manager" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) // Stateful implements the stateful authentication flow. In this flow, the diff --git a/internal/authenticateflow/stateless.go b/internal/authenticateflow/stateless.go index ef188eae9..dcb22cdca 100644 --- a/internal/authenticateflow/stateless.go +++ b/internal/authenticateflow/stateless.go @@ -21,7 +21,6 @@ import ( "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/sessions" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc" @@ -31,6 +30,7 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/hpke" "github.com/pomerium/pomerium/pkg/identity" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.opentelemetry.io/otel" oteltrace "go.opentelemetry.io/otel/trace" diff --git a/internal/controlplane/http.go b/internal/controlplane/http.go index 8f01e5881..e6091cf97 100644 --- a/internal/controlplane/http.go +++ b/internal/controlplane/http.go @@ -17,10 +17,10 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/telemetry" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" hpke_handlers "github.com/pomerium/pomerium/pkg/hpke/handlers" "github.com/pomerium/pomerium/pkg/telemetry/requestid" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) func (srv *Server) addHTTPMiddleware(ctx context.Context, root *mux.Router, _ *config.Config) { diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index 9de4e29f4..ba58b5711 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -28,7 +28,6 @@ import ( "github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/internal/httputil/reproxy" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/pkg/envoy/files" @@ -36,6 +35,7 @@ import ( "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/httputil" "github.com/pomerium/pomerium/pkg/telemetry/requestid" + "github.com/pomerium/pomerium/pkg/telemetry/trace" oteltrace "go.opentelemetry.io/otel/trace" ) diff --git a/internal/databroker/config_source.go b/internal/databroker/config_source.go index 403bcf481..fbaabbcbe 100644 --- a/internal/databroker/config_source.go +++ b/internal/databroker/config_source.go @@ -15,13 +15,13 @@ import ( "github.com/pomerium/pomerium/internal/hashutil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc" configpb "github.com/pomerium/pomerium/pkg/grpc/config" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/health" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" oteltrace "go.opentelemetry.io/otel/trace" googlegrpc "google.golang.org/grpc" diff --git a/internal/databroker/server.go b/internal/databroker/server.go index 863a0f9d3..61b72c66b 100644 --- a/internal/databroker/server.go +++ b/internal/databroker/server.go @@ -17,11 +17,11 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/registry" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage/inmemory" "github.com/pomerium/pomerium/pkg/storage/postgres" + "github.com/pomerium/pomerium/pkg/telemetry/trace" oteltrace "go.opentelemetry.io/otel/trace" ) diff --git a/internal/testenv/environment.go b/internal/testenv/environment.go index 8b16f99ea..dce233b2a 100644 --- a/internal/testenv/environment.go +++ b/internal/testenv/environment.go @@ -38,7 +38,6 @@ import ( "github.com/pomerium/pomerium/config/otelconfig" databroker_service "github.com/pomerium/pomerium/databroker" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/testenv/envutil" "github.com/pomerium/pomerium/internal/testenv/values" "github.com/pomerium/pomerium/internal/version" @@ -49,6 +48,7 @@ import ( "github.com/pomerium/pomerium/pkg/identity/manager" "github.com/pomerium/pomerium/pkg/netutil" "github.com/pomerium/pomerium/pkg/slices" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/internal/testenv/selftests/tracing_test.go b/internal/testenv/selftests/tracing_test.go index 8e29bb1e4..3fdd9e515 100644 --- a/internal/testenv/selftests/tracing_test.go +++ b/internal/testenv/selftests/tracing_test.go @@ -21,12 +21,12 @@ import ( oteltrace "go.opentelemetry.io/otel/trace" "github.com/pomerium/pomerium/config" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/testenv" "github.com/pomerium/pomerium/internal/testenv/scenarios" "github.com/pomerium/pomerium/internal/testenv/snippets" "github.com/pomerium/pomerium/internal/testenv/upstreams" . "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) var allServices = []string{ diff --git a/internal/testenv/snippets/wait.go b/internal/testenv/snippets/wait.go index eb44ea36c..77204fc30 100644 --- a/internal/testenv/snippets/wait.go +++ b/internal/testenv/snippets/wait.go @@ -4,9 +4,9 @@ import ( "context" "time" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/testenv" "github.com/pomerium/pomerium/pkg/grpcutil" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" diff --git a/internal/testenv/upstreams/grpc.go b/internal/testenv/upstreams/grpc.go index c6c203586..115718982 100644 --- a/internal/testenv/upstreams/grpc.go +++ b/internal/testenv/upstreams/grpc.go @@ -6,10 +6,10 @@ import ( "net" "strings" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/testenv" "github.com/pomerium/pomerium/internal/testenv/snippets" "github.com/pomerium/pomerium/internal/testenv/values" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" oteltrace "go.opentelemetry.io/otel/trace" "google.golang.org/grpc" diff --git a/internal/testenv/upstreams/http.go b/internal/testenv/upstreams/http.go index 8d97d61e3..36763ee40 100644 --- a/internal/testenv/upstreams/http.go +++ b/internal/testenv/upstreams/http.go @@ -17,10 +17,10 @@ import ( "github.com/gorilla/mux" "github.com/gorilla/websocket" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/testenv" "github.com/pomerium/pomerium/internal/testenv/snippets" "github.com/pomerium/pomerium/internal/testenv/values" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel/attribute" diff --git a/internal/testenv/upstreams/tcp.go b/internal/testenv/upstreams/tcp.go index 6eab5b629..7d00feee0 100644 --- a/internal/testenv/upstreams/tcp.go +++ b/internal/testenv/upstreams/tcp.go @@ -13,9 +13,9 @@ import ( "net/url" "sync" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/testenv" "github.com/pomerium/pomerium/internal/testenv/values" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "go.opentelemetry.io/otel/attribute" oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/net/http2" diff --git a/internal/testutil/minio.go b/internal/testutil/minio.go index 2b0d7e57e..559e35776 100644 --- a/internal/testutil/minio.go +++ b/internal/testutil/minio.go @@ -6,7 +6,7 @@ import ( "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" - "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" oteltrace "go.opentelemetry.io/otel/trace" diff --git a/internal/testutil/postgres.go b/internal/testutil/postgres.go index 15bb70826..77d8d8045 100644 --- a/internal/testutil/postgres.go +++ b/internal/testutil/postgres.go @@ -8,7 +8,7 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5" - "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" oteltrace "go.opentelemetry.io/otel/trace" diff --git a/internal/testutil/tracetest/tracing.go b/internal/testutil/tracetest/tracing.go index 06db43f62..3ccb29119 100644 --- a/internal/testutil/tracetest/tracing.go +++ b/internal/testutil/tracetest/tracing.go @@ -16,7 +16,7 @@ import ( "unique" gocmp "github.com/google/go-cmp/cmp" - "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" oteltrace "go.opentelemetry.io/otel/trace" diff --git a/pkg/cmd/pomerium/pomerium.go b/pkg/cmd/pomerium/pomerium.go index e87a2fb6a..b26a41285 100644 --- a/pkg/cmd/pomerium/pomerium.go +++ b/pkg/cmd/pomerium/pomerium.go @@ -23,11 +23,11 @@ import ( "github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/registry" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/version" derivecert_config "github.com/pomerium/pomerium/pkg/derivecert/config" "github.com/pomerium/pomerium/pkg/envoy" "github.com/pomerium/pomerium/pkg/envoy/files" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "github.com/pomerium/pomerium/proxy" oteltrace "go.opentelemetry.io/otel/trace" ) diff --git a/pkg/identity/oidc/oidc.go b/pkg/identity/oidc/oidc.go index 41148834a..75950d013 100644 --- a/pkg/identity/oidc/oidc.go +++ b/pkg/identity/oidc/oidc.go @@ -16,11 +16,11 @@ import ( "golang.org/x/oauth2" "github.com/pomerium/pomerium/internal/httputil" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/pkg/identity/identity" "github.com/pomerium/pomerium/pkg/identity/oauth" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) // Name identifies the generic OpenID Connect provider. diff --git a/internal/telemetry/trace/carriers.go b/pkg/telemetry/trace/carriers.go similarity index 100% rename from internal/telemetry/trace/carriers.go rename to pkg/telemetry/trace/carriers.go diff --git a/internal/telemetry/trace/carriers_test.go b/pkg/telemetry/trace/carriers_test.go similarity index 91% rename from internal/telemetry/trace/carriers_test.go rename to pkg/telemetry/trace/carriers_test.go index 1e63df4b6..ff5d491fb 100644 --- a/internal/telemetry/trace/carriers_test.go +++ b/pkg/telemetry/trace/carriers_test.go @@ -4,7 +4,7 @@ import ( "net/url" "testing" - "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "github.com/stretchr/testify/assert" ) diff --git a/internal/telemetry/trace/client.go b/pkg/telemetry/trace/client.go similarity index 100% rename from internal/telemetry/trace/client.go rename to pkg/telemetry/trace/client.go diff --git a/internal/telemetry/trace/client_test.go b/pkg/telemetry/trace/client_test.go similarity index 99% rename from internal/telemetry/trace/client_test.go rename to pkg/telemetry/trace/client_test.go index ec4771048..026eb16ee 100644 --- a/internal/telemetry/trace/client_test.go +++ b/pkg/telemetry/trace/client_test.go @@ -13,13 +13,13 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/testenv" "github.com/pomerium/pomerium/internal/testenv/scenarios" "github.com/pomerium/pomerium/internal/testenv/snippets" . "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive "github.com/pomerium/pomerium/internal/testutil/tracetest/mock_otlptrace" "github.com/pomerium/pomerium/internal/version" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" diff --git a/internal/telemetry/trace/debug.go b/pkg/telemetry/trace/debug.go similarity index 100% rename from internal/telemetry/trace/debug.go rename to pkg/telemetry/trace/debug.go diff --git a/internal/telemetry/trace/debug_test.go b/pkg/telemetry/trace/debug_test.go similarity index 99% rename from internal/telemetry/trace/debug_test.go rename to pkg/telemetry/trace/debug_test.go index 05c937c43..1bf62a7ad 100644 --- a/internal/telemetry/trace/debug_test.go +++ b/pkg/telemetry/trace/debug_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" - "github.com/pomerium/pomerium/internal/telemetry/trace" . "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive + "github.com/pomerium/pomerium/pkg/telemetry/trace" "github.com/stretchr/testify/assert" sdktrace "go.opentelemetry.io/otel/sdk/trace" oteltrace "go.opentelemetry.io/otel/trace" diff --git a/internal/telemetry/trace/global.go b/pkg/telemetry/trace/global.go similarity index 100% rename from internal/telemetry/trace/global.go rename to pkg/telemetry/trace/global.go diff --git a/internal/telemetry/trace/global_test.go b/pkg/telemetry/trace/global_test.go similarity index 88% rename from internal/telemetry/trace/global_test.go rename to pkg/telemetry/trace/global_test.go index 852d2c82a..00f08338d 100644 --- a/internal/telemetry/trace/global_test.go +++ b/pkg/telemetry/trace/global_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "github.com/stretchr/testify/assert" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace/noop" diff --git a/internal/telemetry/trace/main_test.go b/pkg/telemetry/trace/main_test.go similarity index 69% rename from internal/telemetry/trace/main_test.go rename to pkg/telemetry/trace/main_test.go index b546c8e45..dac0af9c0 100644 --- a/internal/telemetry/trace/main_test.go +++ b/pkg/telemetry/trace/main_test.go @@ -4,7 +4,7 @@ import ( "os" "testing" - "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) func TestMain(m *testing.M) { diff --git a/internal/telemetry/trace/middleware.go b/pkg/telemetry/trace/middleware.go similarity index 100% rename from internal/telemetry/trace/middleware.go rename to pkg/telemetry/trace/middleware.go diff --git a/internal/telemetry/trace/middleware_test.go b/pkg/telemetry/trace/middleware_test.go similarity index 99% rename from internal/telemetry/trace/middleware_test.go rename to pkg/telemetry/trace/middleware_test.go index ff01cc4d2..7a1a66537 100644 --- a/internal/telemetry/trace/middleware_test.go +++ b/pkg/telemetry/trace/middleware_test.go @@ -9,7 +9,7 @@ import ( "time" "github.com/gorilla/mux" - "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "github.com/stretchr/testify/assert" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" sdktrace "go.opentelemetry.io/otel/sdk/trace" diff --git a/internal/telemetry/trace/server.go b/pkg/telemetry/trace/server.go similarity index 100% rename from internal/telemetry/trace/server.go rename to pkg/telemetry/trace/server.go diff --git a/internal/telemetry/trace/trace.go b/pkg/telemetry/trace/trace.go similarity index 100% rename from internal/telemetry/trace/trace.go rename to pkg/telemetry/trace/trace.go diff --git a/internal/telemetry/trace/trace_export_test.go b/pkg/telemetry/trace/trace_export_test.go similarity index 100% rename from internal/telemetry/trace/trace_export_test.go rename to pkg/telemetry/trace/trace_export_test.go diff --git a/internal/telemetry/trace/util.go b/pkg/telemetry/trace/util.go similarity index 100% rename from internal/telemetry/trace/util.go rename to pkg/telemetry/trace/util.go diff --git a/proxy/handlers.go b/proxy/handlers.go index dda079a88..2013df2a4 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -15,8 +15,8 @@ import ( "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/middleware" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/telemetry/trace" ) // registerDashboardHandlers returns the proxy service's ServeMux diff --git a/proxy/proxy.go b/proxy/proxy.go index d48977e25..a6681da9f 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -19,9 +19,9 @@ import ( "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" - "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/storage" + "github.com/pomerium/pomerium/pkg/telemetry/trace" "github.com/pomerium/pomerium/proxy/portal" ) From e7675a5b2abf3da0e99c0efe9192c58ea519ea3b Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Tue, 25 Mar 2025 10:11:36 -0600 Subject: [PATCH 6/7] databroker: preserve data type when deleting changeset (#5540) * databroker: preserve data type when deleting changeset * use cs.now --- pkg/grpc/databroker/changeset.go | 15 ++++---- pkg/grpc/databroker/changeset_test.go | 52 +++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 9 deletions(-) create mode 100644 pkg/grpc/databroker/changeset_test.go diff --git a/pkg/grpc/databroker/changeset.go b/pkg/grpc/databroker/changeset.go index cc72f0e2b..0d0d3ba81 100644 --- a/pkg/grpc/databroker/changeset.go +++ b/pkg/grpc/databroker/changeset.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -14,7 +14,7 @@ func GetChangeSet(current, target RecordSetBundle, cmpFn RecordCompareFn) []*Rec cs := &changeSet{now: timestamppb.Now()} for _, rec := range current.GetRemoved(target).Flatten() { - cs.Remove(rec.GetType(), rec.GetId()) + cs.Remove(rec) } for _, rec := range current.GetModified(target, cmpFn).Flatten() { cs.Upsert(rec) @@ -33,13 +33,10 @@ type changeSet struct { } // Remove adds a record to the change set. -func (cs *changeSet) Remove(typ string, id string) { - cs.updates = append(cs.updates, &Record{ - Type: typ, - Id: id, - DeletedAt: cs.now, - Data: &anypb.Any{TypeUrl: typ}, - }) +func (cs *changeSet) Remove(record *Record) { + record = proto.Clone(record).(*Record) + record.DeletedAt = cs.now + cs.updates = append(cs.updates, record) } // Upsert adds a record to the change set. diff --git a/pkg/grpc/databroker/changeset_test.go b/pkg/grpc/databroker/changeset_test.go new file mode 100644 index 000000000..14c101055 --- /dev/null +++ b/pkg/grpc/databroker/changeset_test.go @@ -0,0 +1,52 @@ +package databroker_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/pomerium/datasource/pkg/directory" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +func TestGetChangeset(t *testing.T) { + t.Parallel() + + rsb1 := databroker.RecordSetBundle{} + rsb2 := databroker.RecordSetBundle{} + updates := databroker.GetChangeSet(rsb1, rsb2, func(record1, record2 *databroker.Record) bool { + return cmp.Equal(record1, record2, protocmp.Transform()) + }) + assert.Len(t, updates, 0) + + rsb1 = databroker.RecordSetBundle{} + rsb1.Add(&databroker.Record{ + Type: directory.UserRecordType, + Id: "user-1", + Data: protoutil.NewAny(mustNewStruct(map[string]any{ + "email": "user-1@example.com", + })), + }) + rsb2 = databroker.RecordSetBundle{} + updates = databroker.GetChangeSet(rsb1, rsb2, func(record1, record2 *databroker.Record) bool { + return cmp.Equal(record1, record2, protocmp.Transform()) + }) + if assert.Len(t, updates, 1) { + assert.Equal(t, directory.UserRecordType, updates[0].GetType()) + assert.Equal(t, "type.googleapis.com/google.protobuf.Struct", updates[0].GetData().GetTypeUrl(), + "should preserve data type") + assert.NotNil(t, updates[0].GetDeletedAt()) + } +} + +func mustNewStruct(m map[string]any) *structpb.Struct { + s, err := structpb.NewStruct(m) + if err != nil { + panic(err) + } + return s +} From b188a168af1850ac9c2385c7d7cd432ccb7b8099 Mon Sep 17 00:00:00 2001 From: Kenneth Jenkins <51246568+kenjenkins@users.noreply.github.com> Date: Tue, 25 Mar 2025 14:48:07 -0700 Subject: [PATCH 7/7] metrics: fix an apparent metric setup error (#5543) The IdentityManagerLastSessionRefreshErrorView appears to be a duplicate of IdentityManagerLastUserRefreshErrorView. Adjust it to use the matching identityManagerLastSessionRefreshError instead. --- internal/telemetry/metrics/info.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/telemetry/metrics/info.go b/internal/telemetry/metrics/info.go index 902cc404b..9b18d0eba 100644 --- a/internal/telemetry/metrics/info.go +++ b/internal/telemetry/metrics/info.go @@ -240,11 +240,11 @@ var ( Measure: identityManagerLastSessionRefreshSuccess, Aggregation: view.Count(), } - // IdentityManagerLastSessionRefreshErrorView contains user refresh errors counter + // IdentityManagerLastSessionRefreshErrorView contains session refresh errors counter IdentityManagerLastSessionRefreshErrorView = &view.View{ - Name: identityManagerLastUserRefreshError.Name(), - Description: identityManagerLastUserRefreshError.Description(), - Measure: identityManagerLastUserRefreshError, + Name: identityManagerLastSessionRefreshError.Name(), + Description: identityManagerLastSessionRefreshError.Description(), + Measure: identityManagerLastSessionRefreshError, Aggregation: view.Count(), } // IdentityManagerLastSessionRefreshSuccessTimestampView contains successful session refresh counter