package upstreams import ( "context" "fmt" "net" "strings" "github.com/pomerium/pomerium/internal/testenv" "github.com/pomerium/pomerium/internal/testenv/values" "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) type Options struct { serverOpts []grpc.ServerOption } type Option func(*Options) func (o *Options) apply(opts ...Option) { for _, op := range opts { op(o) } } func ServerOpts(opt ...grpc.ServerOption) Option { return func(o *Options) { o.serverOpts = append(o.serverOpts, opt...) } } // GRPCUpstream represents a GRPC server which can be used as the target for // one or more Pomerium routes in a test environment. // // This upstream implements [grpc.ServiceRegistrar], and can be used similarly // in the same way as [*grpc.Server] to register services before it is started. // // Any [testenv.Route] instances created from this upstream can be referenced // in the Dial() method to establish a connection to that route. type GRPCUpstream interface { testenv.Upstream grpc.ServiceRegistrar Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn } type grpcUpstream struct { Options testenv.Aggregate serverPort values.MutableValue[int] creds credentials.TransportCredentials services []service } var ( _ testenv.Upstream = (*grpcUpstream)(nil) _ grpc.ServiceRegistrar = (*grpcUpstream)(nil) ) // GRPC creates a new GRPC upstream server. func GRPC(creds credentials.TransportCredentials, opts ...Option) GRPCUpstream { options := Options{} options.apply(opts...) up := &grpcUpstream{ Options: options, creds: creds, serverPort: values.Deferred[int](), } up.RecordCaller() return up } type service struct { desc *grpc.ServiceDesc impl any } func (g *grpcUpstream) Port() values.Value[int] { return g.serverPort } // RegisterService implements grpc.ServiceRegistrar. func (g *grpcUpstream) RegisterService(desc *grpc.ServiceDesc, impl any) { g.services = append(g.services, service{desc, impl}) } // Route implements testenv.Upstream. func (g *grpcUpstream) Route() testenv.RouteStub { r := &testenv.PolicyRoute{} var protocol string switch g.creds.Info().SecurityProtocol { case "insecure": protocol = "h2c" default: protocol = "https" } r.To(values.Bind(g.serverPort, func(port int) string { return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port) })) g.Add(r) return r } // Start implements testenv.Upstream. func (g *grpcUpstream) Run(ctx context.Context) error { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return err } g.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port) server := grpc.NewServer(append(g.serverOpts, grpc.Creds(g.creds))...) for _, s := range g.services { server.RegisterService(s.desc, s.impl) } errC := make(chan error, 1) go func() { errC <- server.Serve(listener) }() select { case <-ctx.Done(): server.Stop() return context.Cause(ctx) case err := <-errC: return err } } func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn { dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(g.Env().ServerCAs(), "")), grpc.WithDefaultCallOptions(grpc.WaitForReady(true)), ) cc, err := grpc.NewClient(strings.TrimPrefix(r.URL().Value(), "https://"), dialOpts...) if err != nil { panic(err) } return cc }