pomerium/internal/testenv/upstreams/grpc.go
Joe Kralicky 08623ef346
add tests/benchmarks for http1/http2 tcp tunnels and http1 websockets (#5471)
* add tests/benchmarks for http1/http2 tcp tunnels and http1 websockets

testenv:
- add new TCP upstream
- add websocket functions to HTTP upstream
- add https support to mock idp (default on)
- add new debug flags -env.bind-address and -env.use-trace-environ to
  allow changing the default bind address, and enabling otel environment
  based trace config, respectively

* linter pass

---------

Co-authored-by: Denis Mishin <dmishin@pomerium.com>
2025-03-19 18:42:19 -04:00

198 lines
5.7 KiB
Go

package upstreams
import (
"context"
"fmt"
"net"
"strings"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/values"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
oteltrace "go.opentelemetry.io/otel/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
)
type GRPCUpstreamOptions struct {
CommonUpstreamOptions
serverOpts []grpc.ServerOption
}
type GRPCUpstreamOption interface {
applyGRPC(*GRPCUpstreamOptions)
}
type GRPCUpstreamOptionFunc func(*GRPCUpstreamOptions)
func (f GRPCUpstreamOptionFunc) applyGRPC(o *GRPCUpstreamOptions) {
f(o)
}
func ServerOpts(opt ...grpc.ServerOption) GRPCUpstreamOption {
return GRPCUpstreamOptionFunc(func(o *GRPCUpstreamOptions) {
o.serverOpts = append(o.serverOpts, opt...)
})
}
// GRPCUpstream represents a GRPC server which can be used as the target for
// one or more Pomerium routes in a test environment.
//
// This upstream implements [grpc.ServiceRegistrar], and can be used similarly
// in the same way as [*grpc.Server] to register services before it is started.
//
// Any [testenv.Route] instances created from this upstream can be referenced
// in the Dial() method to establish a connection to that route.
type GRPCUpstream interface {
testenv.Upstream
grpc.ServiceRegistrar
Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn
// Dials the server directly instead of going through a Pomerium route.
DirectConnect(dialOpts ...grpc.DialOption) *grpc.ClientConn
}
type grpcUpstream struct {
GRPCUpstreamOptions
testenv.Aggregate
serverPort values.MutableValue[int]
creds credentials.TransportCredentials
serverTracerProvider values.MutableValue[oteltrace.TracerProvider]
clientTracerProvider values.MutableValue[oteltrace.TracerProvider]
services []service
}
var (
_ testenv.Upstream = (*grpcUpstream)(nil)
_ grpc.ServiceRegistrar = (*grpcUpstream)(nil)
)
// GRPC creates a new GRPC upstream server.
func GRPC(creds credentials.TransportCredentials, opts ...GRPCUpstreamOption) GRPCUpstream {
options := GRPCUpstreamOptions{
CommonUpstreamOptions: CommonUpstreamOptions{
displayName: "GRPC Upstream",
},
}
for _, op := range opts {
op.applyGRPC(&options)
}
up := &grpcUpstream{
GRPCUpstreamOptions: options,
creds: creds,
serverPort: values.Deferred[int](),
serverTracerProvider: values.Deferred[oteltrace.TracerProvider](),
clientTracerProvider: values.Deferred[oteltrace.TracerProvider](),
}
up.RecordCaller()
return up
}
type service struct {
desc *grpc.ServiceDesc
impl any
}
func (g *grpcUpstream) Addr() values.Value[string] {
return values.Bind(g.serverPort, func(port int) string {
return fmt.Sprintf("%s:%d", g.Env().Host(), port)
})
}
// RegisterService implements grpc.ServiceRegistrar.
func (g *grpcUpstream) RegisterService(desc *grpc.ServiceDesc, impl any) {
g.services = append(g.services, service{desc, impl})
}
// Route implements testenv.Upstream.
func (g *grpcUpstream) Route() testenv.RouteStub {
r := &testenv.PolicyRoute{}
var protocol string
switch g.creds.Info().SecurityProtocol {
case "insecure":
protocol = "h2c"
default:
protocol = "https"
}
r.To(values.Bind(g.serverPort, func(port int) string {
return fmt.Sprintf("%s://%s:%d", protocol, g.Env().Host(), port)
}))
g.Add(r)
return r
}
// Start implements testenv.Upstream.
func (g *grpcUpstream) Run(ctx context.Context) error {
listener, err := net.Listen("tcp", fmt.Sprintf("%s:0", g.Env().Host()))
if err != nil {
return err
}
g.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
if g.serverTracerProviderOverride != nil {
g.serverTracerProvider.Resolve(g.serverTracerProviderOverride)
} else {
g.serverTracerProvider.Resolve(trace.NewTracerProvider(ctx, g.displayName))
}
if g.clientTracerProviderOverride != nil {
g.clientTracerProvider.Resolve(g.clientTracerProviderOverride)
} else {
g.clientTracerProvider.Resolve(trace.NewTracerProvider(ctx, "GRPC Client"))
}
server := grpc.NewServer(append(g.serverOpts,
grpc.Creds(g.creds),
grpc.StatsHandler(otelgrpc.NewServerHandler(
otelgrpc.WithTracerProvider(g.serverTracerProvider.Value()),
)),
)...)
for _, s := range g.services {
server.RegisterService(s.desc, s.impl)
}
if g.delayShutdown {
return snippets.RunWithDelayedShutdown(ctx,
func() error {
return server.Serve(listener)
},
server.GracefulStop,
)()
}
errC := make(chan error, 1)
go func() {
errC <- server.Serve(listener)
}()
select {
case <-ctx.Done():
server.GracefulStop()
return context.Cause(ctx)
case err := <-errC:
return err
}
}
func (g *grpcUpstream) withDefaultDialOpts(extraDialOpts []grpc.DialOption) []grpc.DialOption {
return append(extraDialOpts,
grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(g.Env().ServerCAs(), "")),
grpc.WithDefaultCallOptions(grpc.WaitForReady(true)),
grpc.WithStatsHandler(otelgrpc.NewClientHandler(otelgrpc.WithTracerProvider(g.clientTracerProvider.Value()))),
)
}
func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn {
cc, err := grpc.NewClient(strings.TrimPrefix(r.URL().Value(), "https://"), g.withDefaultDialOpts(dialOpts)...)
if err != nil {
panic(err)
}
return cc
}
func (g *grpcUpstream) DirectConnect(dialOpts ...grpc.DialOption) *grpc.ClientConn {
cc, err := grpc.NewClient(g.Addr().Value(),
append(g.withDefaultDialOpts(dialOpts), grpc.WithTransportCredentials(insecure.NewCredentials()))...)
if err != nil {
panic(err)
}
return cc
}