grpc: send client traffic through envoy (#2469)

* wip

* wip

* handle wildcards in override name

* remove wait for ready, add comment about sync, force initial sync complete in test

* address comments
This commit is contained in:
Caleb Doxsey 2021-08-16 16:12:22 -06:00 committed by GitHub
parent 87c3c675d2
commit bbec2cae9f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 391 additions and 480 deletions

View file

@ -110,7 +110,6 @@ func TestNew(t *testing.T) {
{"empty opts", &config.Options{}, true}, {"empty opts", &config.Options{}, true},
{"fails to validate", badRedirectURL, true}, {"fails to validate", badRedirectURL, true},
{"bad provider", badProvider, true}, {"bad provider", badProvider, true},
{"bad databroker url", badGRPCConn, true},
{"empty provider url", emptyProviderURL, true}, {"empty provider url", emptyProviderURL, true},
{"good signing key", goodSigningKey, false}, {"good signing key", goodSigningKey, false},
{"bad signing key", badSigningKey, true}, {"bad signing key", badSigningKey, true},

View file

@ -146,19 +146,8 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
return nil, err return nil, err
} }
urls, err := cfg.Options.GetDataBrokerURLs() dataBrokerConn, err := grpc.GetOutboundGRPCClientConn(context.Background(), &grpc.OutboundOptions{
if err != nil { OutboundPort: cfg.OutboundPort,
return nil, err
}
dataBrokerConn, err := grpc.GetGRPCClientConn(context.Background(), "databroker", &grpc.Options{
Addrs: urls,
OverrideCertificateName: cfg.Options.OverrideCertificateName,
CA: cfg.Options.CA,
CAFile: cfg.Options.CAFile,
RequestTimeout: cfg.Options.GRPCClientTimeout,
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
WithInsecure: cfg.Options.GetGRPCInsecure(),
InstallationID: cfg.Options.InstallationID, InstallationID: cfg.Options.InstallationID,
ServiceName: cfg.Options.Services, ServiceName: cfg.Options.Services,
SignedJWTKey: sharedKey, SignedJWTKey: sharedKey,

View file

@ -23,6 +23,11 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
ctx, span := trace.StartSpan(ctx, "authorize.grpc.Check") ctx, span := trace.StartSpan(ctx, "authorize.grpc.Check")
defer span.End() defer span.End()
// wait for the initial sync to complete so that data is available for evaluation
if err := a.WaitForInitialSync(ctx); err != nil {
return nil, err
}
state := a.state.Load() state := a.state.Load()
// convert the incoming envoy-style http request into a go-style http request // convert the incoming envoy-style http request into a go-style http request

View file

@ -330,6 +330,8 @@ func TestAuthorize_Check(t *testing.T) {
} }
a.currentOptions.Store(&config.Options{ForwardAuthURLString: "https://forward-auth.example.com"}) a.currentOptions.Store(&config.Options{ForwardAuthURLString: "https://forward-auth.example.com"})
close(a.dataBrokerInitialSync)
cmpOpts := []cmp.Option{ cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(envoy_service_auth_v3.CheckResponse{}), cmpopts.IgnoreUnexported(envoy_service_auth_v3.CheckResponse{}),
cmpopts.IgnoreUnexported(status.Status{}), cmpopts.IgnoreUnexported(status.Status{}),

View file

@ -51,19 +51,8 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*a
return nil, err return nil, err
} }
urls, err := cfg.Options.GetDataBrokerURLs() cc, err := grpc.GetOutboundGRPCClientConn(context.Background(), &grpc.OutboundOptions{
if err != nil { OutboundPort: cfg.OutboundPort,
return nil, err
}
cc, err := grpc.GetGRPCClientConn(context.Background(), "databroker", &grpc.Options{
Addrs: urls,
OverrideCertificateName: cfg.Options.OverrideCertificateName,
CA: cfg.Options.CA,
CAFile: cfg.Options.CAFile,
RequestTimeout: cfg.Options.GRPCClientTimeout,
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
WithInsecure: cfg.Options.GetGRPCInsecure(),
InstallationID: cfg.Options.InstallationID, InstallationID: cfg.Options.InstallationID,
ServiceName: cfg.Options.Services, ServiceName: cfg.Options.Services,
SignedJWTKey: sharedKey, SignedJWTKey: sharedKey,

View file

@ -11,6 +11,13 @@ type Config struct {
Options *Options Options *Options
AutoCertificates []tls.Certificate AutoCertificates []tls.Certificate
EnvoyVersion string EnvoyVersion string
// GRPCPort is the port the gRPC server is running on.
GRPCPort string
// HTTPPort is the port the HTTP server is running on.
HTTPPort string
// OutboundPort is the port the outbound gRPC listener is running on.
OutboundPort string
} }
// Clone creates a clone of the config. // Clone creates a clone of the config.
@ -21,6 +28,10 @@ func (cfg *Config) Clone() *Config {
Options: newOptions, Options: newOptions,
AutoCertificates: cfg.AutoCertificates, AutoCertificates: cfg.AutoCertificates,
EnvoyVersion: cfg.EnvoyVersion, EnvoyVersion: cfg.EnvoyVersion,
GRPCPort: cfg.GRPCPort,
HTTPPort: cfg.HTTPPort,
OutboundPort: cfg.OutboundPort,
} }
} }

View file

@ -12,6 +12,7 @@ import (
"github.com/pomerium/pomerium/internal/fileutil" "github.com/pomerium/pomerium/internal/fileutil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/netutil"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
) )
@ -97,7 +98,9 @@ type FileOrEnvironmentSource struct {
} }
// NewFileOrEnvironmentSource creates a new FileOrEnvironmentSource. // NewFileOrEnvironmentSource creates a new FileOrEnvironmentSource.
func NewFileOrEnvironmentSource(configFile, envoyVersion string) (*FileOrEnvironmentSource, error) { func NewFileOrEnvironmentSource(
configFile, envoyVersion string,
) (*FileOrEnvironmentSource, error) {
ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context { ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context {
return c.Str("config_file_source", configFile) return c.Str("config_file_source", configFile)
}) })
@ -107,9 +110,21 @@ func NewFileOrEnvironmentSource(configFile, envoyVersion string) (*FileOrEnviron
return nil, err return nil, err
} }
ports, err := netutil.AllocatePorts(3)
if err != nil {
return nil, err
}
grpcPort := ports[0]
httpPort := ports[1]
outboundPort := ports[2]
cfg := &Config{ cfg := &Config{
Options: options, Options: options,
EnvoyVersion: envoyVersion, EnvoyVersion: envoyVersion,
GRPCPort: grpcPort,
HTTPPort: httpPort,
OutboundPort: outboundPort,
} }
metrics.SetConfigInfo(ctx, cfg.Options.Services, "local", cfg.Checksum(), true) metrics.SetConfigInfo(ctx, cfg.Options.Services, "local", cfg.Checksum(), true)

View file

@ -35,7 +35,11 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
Scheme: "http", Scheme: "http",
Host: b.localHTTPAddress, Host: b.localHTTPAddress,
} }
authzURLs, err := cfg.Options.GetAuthorizeURLs() authorizeURLs, err := cfg.Options.GetAuthorizeURLs()
if err != nil {
return nil, err
}
databrokerURLs, err := cfg.Options.GetDataBrokerURLs()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -44,24 +48,35 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
if err != nil { if err != nil {
return nil, err return nil, err
} }
controlHTTP, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto) controlHTTP, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto)
if err != nil { if err != nil {
return nil, err return nil, err
} }
authZ, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-authorize", authzURLs, upstreamProtocolHTTP2)
authorizeCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(authorizeURLs) > 1 {
authorizeCluster.HealthChecks = grpcHealthChecks("pomerium-authorize")
authorizeCluster.OutlierDetection = grpcAuthorizeOutlierDetection()
}
if len(authzURLs) > 1 { databrokerCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2)
authZ.HealthChecks = grpcHealthChecks("pomerium-authorize") if err != nil {
authZ.OutlierDetection = grpcAuthorizeOutlierDetection() return nil, err
}
if len(databrokerURLs) > 1 {
authorizeCluster.HealthChecks = grpcHealthChecks("pomerium-databroker")
authorizeCluster.OutlierDetection = grpcAuthorizeOutlierDetection()
} }
clusters := []*envoy_config_cluster_v3.Cluster{ clusters := []*envoy_config_cluster_v3.Cluster{
controlGRPC, controlGRPC,
controlHTTP, controlHTTP,
authZ, authorizeCluster,
databrokerCluster,
} }
tracingCluster, err := buildTracingCluster(cfg.Options) tracingCluster, err := buildTracingCluster(cfg.Options)
@ -170,16 +185,11 @@ func (b *Builder) buildInternalTransportSocket(
if endpoint.Scheme != "https" { if endpoint.Scheme != "https" {
return nil, nil return nil, nil
} }
sni := endpoint.Hostname()
if options.OverrideCertificateName != "" {
sni = options.OverrideCertificateName
}
validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{ validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{
MatchSubjectAltNames: []*envoy_type_matcher_v3.StringMatcher{{ MatchSubjectAltNames: []*envoy_type_matcher_v3.StringMatcher{
MatchPattern: &envoy_type_matcher_v3.StringMatcher_Exact{ b.buildSubjectAlternativeNameMatcher(endpoint, options.OverrideCertificateName),
Exact: sni,
}, },
}},
} }
bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile) bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile)
if err != nil { if err != nil {
@ -194,7 +204,7 @@ func (b *Builder) buildInternalTransportSocket(
ValidationContext: validationContext, ValidationContext: validationContext,
}, },
}, },
Sni: sni, Sni: b.buildSubjectNameIndication(endpoint, options.OverrideCertificateName),
} }
tlsConfig := marshalAny(tlsContext) tlsConfig := marshalAny(tlsContext)
return &envoy_config_core_v3.TransportSocket{ return &envoy_config_core_v3.TransportSocket{
@ -279,16 +289,10 @@ func (b *Builder) buildPolicyValidationContext(
policy *config.Policy, policy *config.Policy,
dst url.URL, dst url.URL,
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) { ) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
sni := dst.Hostname()
if policy.TLSServerName != "" {
sni = policy.TLSServerName
}
validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{ validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{
MatchSubjectAltNames: []*envoy_type_matcher_v3.StringMatcher{{ MatchSubjectAltNames: []*envoy_type_matcher_v3.StringMatcher{
MatchPattern: &envoy_type_matcher_v3.StringMatcher_Exact{ b.buildSubjectAlternativeNameMatcher(&dst, policy.TLSServerName),
Exact: sni,
}, },
}},
} }
if policy.TLSCustomCAFile != "" { if policy.TLSCustomCAFile != "" {
validationContext.TrustedCa = b.filemgr.FileDataSource(policy.TLSCustomCAFile) validationContext.TrustedCa = b.filemgr.FileDataSource(policy.TLSCustomCAFile)

View file

@ -85,6 +85,12 @@ func (b *Builder) BuildListeners(ctx context.Context, cfg *config.Config) ([]*en
listeners = append(listeners, li) listeners = append(listeners, li)
} }
li, err := b.buildOutboundListener(cfg)
if err != nil {
return nil, err
}
listeners = append(listeners, li)
return listeners, nil return listeners, nil
} }

View file

@ -0,0 +1,134 @@
package envoyconfig
import (
"fmt"
"strconv"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
envoy_config_listener_v3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
envoy_config_route_v3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3"
envoy_http_connection_manager "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/pomerium/pomerium/config"
)
func (b *Builder) buildOutboundListener(cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
outboundPort, err := strconv.Atoi(cfg.OutboundPort)
if err != nil {
return nil, fmt.Errorf("invalid outbound port: %w", err)
}
filter, err := b.buildOutboundHTTPConnectionManager()
if err != nil {
return nil, fmt.Errorf("error building outbound http connection manager filter: %w", err)
}
li := &envoy_config_listener_v3.Listener{
Name: "outbound-ingress",
Address: &envoy_config_core_v3.Address{
Address: &envoy_config_core_v3.Address_SocketAddress{
SocketAddress: &envoy_config_core_v3.SocketAddress{
Address: "127.0.0.1",
PortSpecifier: &envoy_config_core_v3.SocketAddress_PortValue{
PortValue: uint32(outboundPort),
},
},
},
},
FilterChains: []*envoy_config_listener_v3.FilterChain{{
Name: "outbound-ingress",
Filters: []*envoy_config_listener_v3.Filter{filter},
}},
}
return li, nil
}
func (b *Builder) buildOutboundHTTPConnectionManager() (*envoy_config_listener_v3.Filter, error) {
rc, err := b.buildOutboundRouteConfiguration()
if err != nil {
return nil, err
}
tc := marshalAny(&envoy_http_connection_manager.HttpConnectionManager{
CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO,
StatPrefix: "grpc_egress",
// limit request first byte to last byte time
RequestTimeout: &durationpb.Duration{
Seconds: 15,
},
RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{
RouteConfig: rc,
},
HttpFilters: []*envoy_http_connection_manager.HttpFilter{{
Name: "envoy.filters.http.router",
}},
})
return &envoy_config_listener_v3.Filter{
Name: "envoy.filters.network.http_connection_manager",
ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{
TypedConfig: tc,
},
}, nil
}
func (b *Builder) buildOutboundRouteConfiguration() (*envoy_config_route_v3.RouteConfiguration, error) {
return b.buildRouteConfiguration("grpc", []*envoy_config_route_v3.VirtualHost{{
Name: "grpc",
Domains: []string{"*"},
Routes: b.buildOutboundRoutes(),
}})
}
func (b *Builder) buildOutboundRoutes() []*envoy_config_route_v3.Route {
type Def struct {
Cluster string
Prefixes []string
}
defs := []Def{
{
Cluster: "pomerium-authorize",
Prefixes: []string{
"/envoy.service.auth.v3.Authorization/",
},
},
{
Cluster: "pomerium-databroker",
Prefixes: []string{
"/databroker.DataBrokerService/",
"/directory.DirectoryService/",
"/registry.Registry/",
},
},
{
Cluster: "pomerium-control-plane-grpc",
Prefixes: []string{
"/",
},
},
}
var routes []*envoy_config_route_v3.Route
for _, def := range defs {
for _, prefix := range def.Prefixes {
routes = append(routes, &envoy_config_route_v3.Route{
Name: def.Cluster,
Match: &envoy_config_route_v3.RouteMatch{
PathSpecifier: &envoy_config_route_v3.RouteMatch_Prefix{Prefix: prefix},
Grpc: &envoy_config_route_v3.RouteMatch_GrpcRouteMatchOptions{},
},
Action: &envoy_config_route_v3.Route_Route{
Route: &envoy_config_route_v3.RouteAction{
ClusterSpecifier: &envoy_config_route_v3.RouteAction_Cluster{
Cluster: def.Cluster,
},
// disable the timeout to support grpc streaming
Timeout: durationpb.New(0),
IdleTimeout: durationpb.New(0),
},
},
})
}
}
return routes
}

52
config/envoyconfig/tls.go Normal file
View file

@ -0,0 +1,52 @@
package envoyconfig
import (
"net/url"
"regexp"
"strings"
envoy_type_matcher_v3 "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3"
)
func (b *Builder) buildSubjectAlternativeNameMatcher(
dst *url.URL,
overrideName string,
) *envoy_type_matcher_v3.StringMatcher {
sni := dst.Hostname()
if overrideName != "" {
sni = overrideName
}
if strings.Contains(sni, "*") {
pattern := regexp.QuoteMeta(sni)
pattern = strings.Replace(pattern, "\\*", ".*", -1)
return &envoy_type_matcher_v3.StringMatcher{
MatchPattern: &envoy_type_matcher_v3.StringMatcher_SafeRegex{
SafeRegex: &envoy_type_matcher_v3.RegexMatcher{
EngineType: &envoy_type_matcher_v3.RegexMatcher_GoogleRe2{
GoogleRe2: &envoy_type_matcher_v3.RegexMatcher_GoogleRE2{},
},
Regex: pattern,
},
},
}
}
return &envoy_type_matcher_v3.StringMatcher{
MatchPattern: &envoy_type_matcher_v3.StringMatcher_Exact{
Exact: sni,
},
}
}
func (b *Builder) buildSubjectNameIndication(
dst *url.URL,
overrideName string,
) string {
sni := dst.Hostname()
if overrideName != "" {
sni = overrideName
}
sni = strings.Replace(sni, "*", "example", -1)
return sni
}

View file

@ -0,0 +1,33 @@
package envoyconfig
import (
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/testutil"
)
func TestBuildSubjectAlternativeNameMatcher(t *testing.T) {
b := new(Builder)
testutil.AssertProtoJSONEqual(t, `
{ "exact": "example.com" }
`, b.buildSubjectAlternativeNameMatcher(&url.URL{Host: "example.com:1234"}, ""))
testutil.AssertProtoJSONEqual(t, `
{ "exact": "example.org" }
`, b.buildSubjectAlternativeNameMatcher(&url.URL{Host: "example.com:1234"}, "example.org"))
testutil.AssertProtoJSONEqual(t, `
{ "safeRegex": {
"googleRe2": {},
"regex": ".*\\.example\\.org"
} }
`, b.buildSubjectAlternativeNameMatcher(&url.URL{Host: "example.com:1234"}, "*.example.org"))
}
func TestBuildSubjectNameIndication(t *testing.T) {
b := new(Builder)
assert.Equal(t, "example.com", b.buildSubjectNameIndication(&url.URL{Host: "example.com:1234"}, ""))
assert.Equal(t, "example.org", b.buildSubjectNameIndication(&url.URL{Host: "example.com:1234"}, "example.org"))
assert.Equal(t, "example.example.org", b.buildSubjectNameIndication(&url.URL{Host: "example.com:1234"}, "*.example.org"))
}

View file

@ -5,7 +5,6 @@ package pomerium
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
@ -66,7 +65,7 @@ func Run(ctx context.Context, configFile string) error {
defer traceMgr.Close() defer traceMgr.Close()
// setup the control plane // setup the control plane
controlPlane, err := controlplane.NewServer(src.GetConfig().Options.Services, metricsMgr) controlPlane, err := controlplane.NewServer(src.GetConfig(), metricsMgr)
if err != nil { if err != nil {
return fmt.Errorf("error creating control plane: %w", err) return fmt.Errorf("error creating control plane: %w", err)
} }
@ -83,14 +82,14 @@ func Run(ctx context.Context, configFile string) error {
return fmt.Errorf("applying config: %w", err) return fmt.Errorf("applying config: %w", err)
} }
_, grpcPort, _ := net.SplitHostPort(controlPlane.GRPCListener.Addr().String()) log.Info(ctx).
_, httpPort, _ := net.SplitHostPort(controlPlane.HTTPListener.Addr().String()) Str("grpc-port", src.GetConfig().GRPCPort).
Str("http-port", src.GetConfig().HTTPPort).
log.Info(ctx).Str("port", grpcPort).Msg("gRPC server started") Str("outbound-port", src.GetConfig().OutboundPort).
log.Info(ctx).Str("port", httpPort).Msg("HTTP server started") Msg("server started")
// create envoy server // create envoy server
envoyServer, err := envoy.NewServer(ctx, src, grpcPort, httpPort, controlPlane.Builder) envoyServer, err := envoy.NewServer(ctx, src, controlPlane.Builder)
if err != nil { if err != nil {
return fmt.Errorf("error creating envoy server: %w", err) return fmt.Errorf("error creating envoy server: %w", err)
} }
@ -143,13 +142,6 @@ func Run(ctx context.Context, configFile string) error {
eg.Go(func() error { eg.Go(func() error {
return authorizeServer.Run(ctx) return authorizeServer.Run(ctx)
}) })
// in non-all-in-one mode we will wait for the initial sync to complete before starting
// the control plane
if dataBrokerServer == nil {
if err := authorizeServer.WaitForInitialSync(ctx); err != nil {
return err
}
}
} }
eg.Go(func() error { eg.Go(func() error {
return controlPlane.Run(ctx) return controlPlane.Run(ctx)

View file

@ -81,28 +81,17 @@ func (srv *Server) storeEnvoyConfigurationEvent(ctx context.Context, evt *events
} }
func (srv *Server) getDataBrokerClient(ctx context.Context) (databrokerpb.DataBrokerServiceClient, error) { func (srv *Server) getDataBrokerClient(ctx context.Context) (databrokerpb.DataBrokerServiceClient, error) {
options := srv.currentConfig.Load().Options cfg := srv.currentConfig.Load()
sharedKey, err := options.GetSharedKey() sharedKey, err := cfg.Options.GetSharedKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
urls, err := options.GetDataBrokerURLs() cc, err := grpc.GetOutboundGRPCClientConn(context.Background(), &grpc.OutboundOptions{
if err != nil { OutboundPort: cfg.OutboundPort,
return nil, err InstallationID: cfg.Options.InstallationID,
} ServiceName: cfg.Options.Services,
cc, err := grpc.GetGRPCClientConn(ctx, "databroker", &grpc.Options{
Addrs: urls,
OverrideCertificateName: options.OverrideCertificateName,
CA: options.CA,
CAFile: options.CAFile,
RequestTimeout: options.GRPCClientTimeout,
ClientDNSRoundRobin: options.GRPCClientDNSRoundRobin,
WithInsecure: options.GetGRPCInsecure(),
InstallationID: options.InstallationID,
ServiceName: options.Services,
SignedJWTKey: sharedKey, SignedJWTKey: sharedKey,
}) })
if err != nil { if err != nil {

View file

@ -69,6 +69,7 @@ func TestEvents(t *testing.T) {
li, err := net.Listen("tcp", "127.0.0.1:0") li, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err) require.NoError(t, err)
defer li.Close() defer li.Close()
_, outboundPort, _ := net.SplitHostPort(li.Addr().String())
var putRequest *databrokerpb.PutRequest var putRequest *databrokerpb.PutRequest
var setOptionsRequest *databrokerpb.SetOptionsRequest var setOptionsRequest *databrokerpb.SetOptionsRequest
@ -100,6 +101,7 @@ func TestEvents(t *testing.T) {
srv := &Server{} srv := &Server{}
srv.currentConfig.Store(versionedConfig{ srv.currentConfig.Store(versionedConfig{
Config: &config.Config{ Config: &config.Config{
OutboundPort: outboundPort,
Options: &config.Options{ Options: &config.Options{
SharedKey: cryptutil.NewBase64Key(), SharedKey: cryptutil.NewBase64Key(),
DataBrokerURLString: "http://" + li.Addr().String(), DataBrokerURLString: "http://" + li.Addr().String(),

View file

@ -68,20 +68,20 @@ type Server struct {
} }
// NewServer creates a new Server. Listener ports are chosen by the OS. // NewServer creates a new Server. Listener ports are chosen by the OS.
func NewServer(name string, metricsMgr *config.MetricsManager) (*Server, error) { func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager) (*Server, error) {
srv := &Server{ srv := &Server{
metricsMgr: metricsMgr, metricsMgr: metricsMgr,
reproxy: reproxy.New(), reproxy: reproxy.New(),
envoyConfigurationEvents: make(chan *events.EnvoyConfigurationEvent, 10), envoyConfigurationEvents: make(chan *events.EnvoyConfigurationEvent, 10),
} }
srv.currentConfig.Store(versionedConfig{ srv.currentConfig.Store(versionedConfig{
Config: &config.Config{Options: &config.Options{}}, Config: cfg,
}) })
var err error var err error
// setup gRPC // setup gRPC
srv.GRPCListener, err = net.Listen("tcp4", "127.0.0.1:0") srv.GRPCListener, err = net.Listen("tcp4", net.JoinHostPort("127.0.0.1", cfg.GRPCPort))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -92,7 +92,7 @@ func NewServer(name string, metricsMgr *config.MetricsManager) (*Server, error)
), ),
) )
srv.GRPCServer = grpc.NewServer( srv.GRPCServer = grpc.NewServer(
grpc.StatsHandler(telemetry.NewGRPCServerStatsHandler(name)), grpc.StatsHandler(telemetry.NewGRPCServerStatsHandler(cfg.Options.Services)),
grpc.ChainUnaryInterceptor(requestid.UnaryServerInterceptor(), ui), grpc.ChainUnaryInterceptor(requestid.UnaryServerInterceptor(), ui),
grpc.ChainStreamInterceptor(requestid.StreamServerInterceptor(), si), grpc.ChainStreamInterceptor(requestid.StreamServerInterceptor(), si),
) )
@ -102,7 +102,7 @@ func NewServer(name string, metricsMgr *config.MetricsManager) (*Server, error)
grpc_health_v1.RegisterHealthServer(srv.GRPCServer, pom_grpc.NewHealthCheckServer()) grpc_health_v1.RegisterHealthServer(srv.GRPCServer, pom_grpc.NewHealthCheckServer())
// setup HTTP // setup HTTP
srv.HTTPListener, err = net.Listen("tcp4", "127.0.0.1:0") srv.HTTPListener, err = net.Listen("tcp4", net.JoinHostPort("127.0.0.1", cfg.HTTPPort))
if err != nil { if err != nil {
_ = srv.GRPCListener.Close() _ = srv.GRPCListener.Close()
return nil, err return nil, err
@ -121,7 +121,7 @@ func NewServer(name string, metricsMgr *config.MetricsManager) (*Server, error)
) )
ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context { ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context {
return c.Str("server_name", name) return c.Str("server_name", cfg.Options.Services)
}) })
res, err := srv.buildDiscoveryResources(ctx) res, err := srv.buildDiscoveryResources(ctx)

View file

@ -157,21 +157,10 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
} }
func (src *ConfigSource) runUpdater(cfg *config.Config) { func (src *ConfigSource) runUpdater(cfg *config.Config) {
urls, err := cfg.Options.GetDataBrokerURLs()
if err != nil {
log.Fatal().Err(err).Send()
return
}
sharedKey, _ := cfg.Options.GetSharedKey() sharedKey, _ := cfg.Options.GetSharedKey()
connectionOptions := &grpc.Options{ connectionOptions := &grpc.OutboundOptions{
Addrs: urls, OutboundPort: cfg.OutboundPort,
OverrideCertificateName: cfg.Options.OverrideCertificateName, InstallationID: cfg.Options.InstallationID,
CA: cfg.Options.CA,
CAFile: cfg.Options.CAFile,
RequestTimeout: cfg.Options.GRPCClientTimeout,
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
WithInsecure: cfg.Options.GetGRPCInsecure(),
ServiceName: cfg.Options.Services, ServiceName: cfg.Options.Services,
SignedJWTKey: sharedKey, SignedJWTKey: sharedKey,
} }
@ -193,7 +182,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
ctx := context.Background() ctx := context.Background()
ctx, src.cancel = context.WithCancel(ctx) ctx, src.cancel = context.WithCancel(ctx)
cc, err := grpc.NewGRPCClientConn(ctx, connectionOptions) cc, err := grpc.GetOutboundGRPCClientConn(ctx, connectionOptions)
if err != nil { if err != nil {
log.Error(ctx).Err(err).Msg("databroker: failed to create gRPC connection to data broker") log.Error(ctx).Err(err).Msg("databroker: failed to create gRPC connection to data broker")
return return

View file

@ -25,6 +25,7 @@ func TestConfigSource(t *testing.T) {
return return
} }
defer func() { _ = li.Close() }() defer func() { _ = li.Close() }()
_, outboundPort, _ := net.SplitHostPort(li.Addr().String())
dataBrokerServer := New() dataBrokerServer := New()
srv := grpc.NewServer() srv := grpc.NewServer()
@ -45,6 +46,7 @@ func TestConfigSource(t *testing.T) {
}) })
baseSource := config.NewStaticSource(&config.Config{ baseSource := config.NewStaticSource(&config.Config{
OutboundPort: outboundPort,
Options: base, Options: base,
}) })
src := NewConfigSource(ctx, baseSource, func(_ context.Context, cfg *config.Config) { src := NewConfigSource(ctx, baseSource, func(_ context.Context, cfg *config.Config) {
@ -86,6 +88,7 @@ func TestConfigSource(t *testing.T) {
} }
baseSource.SetConfig(ctx, &config.Config{ baseSource.SetConfig(ctx, &config.Config{
OutboundPort: outboundPort,
Options: base, Options: base,
}) })
} }

View file

@ -63,7 +63,7 @@ type Server struct {
} }
// NewServer creates a new server with traffic routed by envoy. // NewServer creates a new server with traffic routed by envoy.
func NewServer(ctx context.Context, src config.Source, grpcPort, httpPort string, builder *envoyconfig.Builder) (*Server, error) { func NewServer(ctx context.Context, src config.Source, builder *envoyconfig.Builder) (*Server, error) {
wd := filepath.Join(os.TempDir(), workingDirectoryName) wd := filepath.Join(os.TempDir(), workingDirectoryName)
err := os.MkdirAll(wd, embeddedEnvoyPermissions) err := os.MkdirAll(wd, embeddedEnvoyPermissions)
if err != nil { if err != nil {
@ -97,8 +97,8 @@ func NewServer(ctx context.Context, src config.Source, grpcPort, httpPort string
srv := &Server{ srv := &Server{
wd: wd, wd: wd,
builder: builder, builder: builder,
grpcPort: grpcPort, grpcPort: src.GetConfig().GRPCPort,
httpPort: httpPort, httpPort: src.GetConfig().HTTPPort,
envoyPath: envoyPath, envoyPath: envoyPath,
monitorProcessCancel: func() {}, monitorProcessCancel: func() {},

View file

@ -0,0 +1,22 @@
// Package netutil contains various functions that help with networking.
package netutil
import "net"
// AllocatePorts allocates random ports suitable for listening.
func AllocatePorts(count int) ([]string, error) {
var ports []string
for i := 0; i < count; i++ {
li, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, err
}
_, port, _ := net.SplitHostPort(li.Addr().String())
err = li.Close()
if err != nil {
return nil, err
}
ports = append(ports, port)
}
return ports, nil
}

View file

@ -39,20 +39,8 @@ func (r *Reporter) OnConfigChange(ctx context.Context, cfg *config.Config) {
return return
} }
urls, err := cfg.Options.GetDataBrokerURLs() registryConn, err := grpc.GetOutboundGRPCClientConn(ctx, &grpc.OutboundOptions{
if err != nil { OutboundPort: cfg.OutboundPort,
log.Error(ctx).Err(err).Msg("invalid databroker urls")
return
}
registryConn, err := grpc.GetGRPCClientConn(ctx, "databroker", &grpc.Options{
Addrs: urls,
OverrideCertificateName: cfg.Options.OverrideCertificateName,
CA: cfg.Options.CA,
CAFile: cfg.Options.CAFile,
RequestTimeout: cfg.Options.GRPCClientTimeout,
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
WithInsecure: cfg.Options.GetGRPCInsecure(),
InstallationID: cfg.Options.InstallationID, InstallationID: cfg.Options.InstallationID,
ServiceName: cfg.Options.Services, ServiceName: cfg.Options.Services,
SignedJWTKey: sharedKey, SignedJWTKey: sharedKey,

View file

@ -2,51 +2,23 @@ package grpc
import ( import (
"context" "context"
"crypto/tls"
"errors"
"net" "net"
"net/url"
"strconv"
"strings"
"sync" "sync"
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry" "github.com/pomerium/pomerium/internal/telemetry"
"github.com/pomerium/pomerium/internal/telemetry/requestid" "github.com/pomerium/pomerium/internal/telemetry/requestid"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/grpcutil"
) )
const (
defaultGRPCSecurePort = 443
defaultGRPCInsecurePort = 80
)
// Options contains options for connecting to a pomerium rpc service. // Options contains options for connecting to a pomerium rpc service.
type Options struct { type Options struct {
// Addrs is the location of the service. e.g. "service.corp.example:8443" // Address is the location of the service. e.g. "service.corp.example:8443"
Addrs []*url.URL Address string
// OverrideCertificateName overrides the server name used to verify the hostname on the
// returned certificates from the server. gRPC internals also use it to override the virtual
// hosting name if it is set.
OverrideCertificateName string
// CA specifies the base64 encoded TLS certificate authority to use.
CA string
// CAFile specifies the TLS certificate authority file to use.
CAFile string
// RequestTimeout specifies the timeout for individual RPC calls
RequestTimeout time.Duration
// ClientDNSRoundRobin enables or disables DNS resolver based load balancing
ClientDNSRoundRobin bool
// WithInsecure disables transport security for this ClientConn.
// Note that transport security is required unless WithInsecure is set.
WithInsecure bool
// InstallationID specifies the installation id for telemetry exposition. // InstallationID specifies the installation id for telemetry exposition.
InstallationID string InstallationID string
@ -60,31 +32,10 @@ type Options struct {
// NewGRPCClientConn returns a new gRPC pomerium service client connection. // NewGRPCClientConn returns a new gRPC pomerium service client connection.
func NewGRPCClientConn(ctx context.Context, opts *Options, other ...grpc.DialOption) (*grpc.ClientConn, error) { func NewGRPCClientConn(ctx context.Context, opts *Options, other ...grpc.DialOption) (*grpc.ClientConn, error) {
if len(opts.Addrs) == 0 {
return nil, errors.New("internal/grpc: connection address required")
}
var addrs []string
for _, u := range opts.Addrs {
hostport := u.Host
// no colon exists in the connection string, assume one must be added manually
if _, _, err := net.SplitHostPort(hostport); err != nil {
if u.Scheme == "https" {
hostport = net.JoinHostPort(hostport, strconv.Itoa(defaultGRPCSecurePort))
} else {
hostport = net.JoinHostPort(hostport, strconv.Itoa(defaultGRPCInsecurePort))
}
}
addrs = append(addrs, hostport)
}
connAddr := "pomerium:///" + strings.Join(addrs, ",")
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(opts.ServiceName) clientStatsHandler := telemetry.NewGRPCClientStatsHandler(opts.ServiceName)
unaryClientInterceptors := []grpc.UnaryClientInterceptor{ unaryClientInterceptors := []grpc.UnaryClientInterceptor{
requestid.UnaryClientInterceptor(), requestid.UnaryClientInterceptor(),
grpcTimeoutInterceptor(opts.RequestTimeout),
clientStatsHandler.UnaryInterceptor, clientStatsHandler.UnaryInterceptor,
} }
streamClientInterceptors := []grpc.StreamClientInterceptor{ streamClientInterceptors := []grpc.StreamClientInterceptor{
@ -98,38 +49,13 @@ func NewGRPCClientConn(ctx context.Context, opts *Options, other ...grpc.DialOpt
dialOptions := []grpc.DialOption{ dialOptions := []grpc.DialOption{
grpc.WithChainUnaryInterceptor(unaryClientInterceptors...), grpc.WithChainUnaryInterceptor(unaryClientInterceptors...),
grpc.WithChainStreamInterceptor(streamClientInterceptors...), grpc.WithChainStreamInterceptor(streamClientInterceptors...),
grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...),
grpc.WithStatsHandler(clientStatsHandler.Handler), grpc.WithStatsHandler(clientStatsHandler.Handler),
grpc.WithDefaultServiceConfig(roundRobinServiceConfig),
grpc.WithDisableServiceConfig(), grpc.WithDisableServiceConfig(),
grpc.WithInsecure(),
} }
dialOptions = append(dialOptions, other...) dialOptions = append(dialOptions, other...)
log.Info(ctx).Str("address", opts.Address).Msg("dialing")
if opts.WithInsecure { return grpc.DialContext(ctx, opts.Address, dialOptions...)
log.Info(ctx).Str("addr", connAddr).Msg("internal/grpc: grpc with insecure")
dialOptions = append(dialOptions, grpc.WithInsecure())
} else {
rootCAs, err := cryptutil.GetCertPool(opts.CA, opts.CAFile)
if err != nil {
return nil, err
}
cert := credentials.NewTLS(&tls.Config{RootCAs: rootCAs, MinVersion: tls.VersionTLS12})
// override allowed certificate name string, typically used when doing behind ingress connection
if opts.OverrideCertificateName != "" {
log.Debug(ctx).Str("cert-override-name", opts.OverrideCertificateName).Msg("internal/grpc: grpc")
err := cert.OverrideServerName(opts.OverrideCertificateName)
if err != nil {
return nil, err
}
}
// finally add our credential
dialOptions = append(dialOptions, grpc.WithTransportCredentials(cert))
}
return grpc.DialContext(ctx, connAddr, dialOptions...)
} }
// grpcTimeoutInterceptor enforces per-RPC request timeouts // grpcTimeoutInterceptor enforces per-RPC request timeouts
@ -186,3 +112,28 @@ func GetGRPCClientConn(ctx context.Context, name string, opts *Options) (*grpc.C
} }
return cc, nil return cc, nil
} }
// OutboundOptions are the options for the outbound gRPC client.
type OutboundOptions struct {
// OutboundPort is the port for the outbound gRPC listener.
OutboundPort string
// InstallationID specifies the installation id for telemetry exposition.
InstallationID string
// ServiceName specifies the service name for telemetry exposition
ServiceName string
// SignedJWTKey is the JWT key to use for signing a JWT attached to metadata.
SignedJWTKey []byte
}
// GetOutboundGRPCClientConn gets the outbound gRPC client.
func GetOutboundGRPCClientConn(ctx context.Context, opts *OutboundOptions) (*grpc.ClientConn, error) {
return GetGRPCClientConn(ctx, "outbound", &Options{
Address: net.JoinHostPort("127.0.0.1", opts.OutboundPort),
InstallationID: opts.InstallationID,
ServiceName: opts.ServiceName,
SignedJWTKey: opts.SignedJWTKey,
})
}

View file

@ -2,12 +2,9 @@ package grpc
import ( import (
"context" "context"
"net/url"
"strings"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
@ -37,81 +34,3 @@ func Test_grpcTimeoutInterceptor(t *testing.T) {
to(context.Background(), "test", nil, nil, nil, mockInvoker(timeOut*2, true)) to(context.Background(), "test", nil, nil, nil, mockInvoker(timeOut*2, true))
to(context.Background(), "test", nil, nil, nil, mockInvoker(timeOut/2, false)) to(context.Background(), "test", nil, nil, nil, mockInvoker(timeOut/2, false))
} }
func TestNewGRPC(t *testing.T) {
t.Parallel()
tests := []struct {
name string
opts *Options
wantErr bool
wantErrStr string
wantTarget string
}{
{"empty connection", &Options{Addrs: nil}, true, "proxy/authenticator: connection address required", ""},
{"both internal and addr empty", &Options{Addrs: nil}, true, "proxy/authenticator: connection address required", ""},
{"addr with port", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:8443"}}}, false, "", "pomerium:///localhost.example:8443"},
{"secure addr without port", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example"}}}, false, "", "pomerium:///localhost.example:443"},
{"insecure addr without port", &Options{Addrs: []*url.URL{{Scheme: "http", Host: "localhost.example"}}}, false, "", "pomerium:///localhost.example:80"},
{"cert override", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:443"}}, OverrideCertificateName: "*.local"}, false, "", "pomerium:///localhost.example:443"},
{"custom ca", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:443"}}, OverrideCertificateName: "*.local", CA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURFVENDQWZrQ0ZBWHhneFg5K0hjWlBVVVBEK0laV0NGNUEvVTdNQTBHQ1NxR1NJYjNEUUVCQ3dVQU1FVXgKQ3pBSkJnTlZCQVlUQWtGVk1STXdFUVlEVlFRSURBcFRiMjFsTFZOMFlYUmxNU0V3SHdZRFZRUUtEQmhKYm5SbApjbTVsZENCWGFXUm5hWFJ6SUZCMGVTQk1kR1F3SGhjTk1Ua3dNakk0TVRnMU1EQTNXaGNOTWprd01qSTFNVGcxCk1EQTNXakJGTVFzd0NRWURWUVFHRXdKQlZURVRNQkVHQTFVRUNBd0tVMjl0WlMxVGRHRjBaVEVoTUI4R0ExVUUKQ2d3WVNXNTBaWEp1WlhRZ1YybGtaMmwwY3lCUWRIa2dUSFJrTUlJQklqQU5CZ2txaGtpRzl3MEJBUUVGQUFPQwpBUThBTUlJQkNnS0NBUUVBOVRFMEFiaTdnMHhYeURkVUtEbDViNTBCT05ZVVVSc3F2THQrSWkwdlpjMzRRTHhOClJrT0hrOFZEVUgzcUt1N2UrNGVubUdLVVNUdzRPNFlkQktiSWRJTFpnb3o0YitNL3FVOG5adVpiN2pBVTdOYWkKajMzVDVrbXB3L2d4WHNNUzNzdUpXUE1EUDB3Z1BUZUVRK2J1bUxVWmpLdUVIaWNTL0l5dmtaVlBzRlE4NWlaUwpkNXE2a0ZGUUdjWnFXeFg0dlhDV25Sd3E3cHY3TThJd1RYc1pYSVRuNXB5Z3VTczNKb29GQkg5U3ZNTjRKU25GCmJMK0t6ekduMy9ScXFrTXpMN3FUdkMrNWxVT3UxUmNES21mZXBuVGVaN1IyVnJUQm42NndWMjVHRnBkSDIzN00KOXhJVkJrWEd1U2NvWHVPN1lDcWFrZkt6aXdoRTV4UmRaa3gweXdJREFRQUJNQTBHQ1NxR1NJYjNEUUVCQ3dVQQpBNElCQVFCaHRWUEI0OCs4eFZyVmRxM1BIY3k5QkxtVEtrRFl6N2Q0ODJzTG1HczBuVUdGSTFZUDdmaFJPV3ZxCktCTlpkNEI5MUpwU1NoRGUrMHpoNno4WG5Ha01mYnRSYWx0NHEwZ3lKdk9hUWhqQ3ZCcSswTFk5d2NLbXpFdnMKcTRiNUZ5NXNpRUZSekJLTmZtTGwxTTF2cW1hNmFCVnNYUUhPREdzYS83dE5MalZ2ay9PYm52cFg3UFhLa0E3cQpLMTQvV0tBRFBJWm9mb00xMzB4Q1RTYXVpeXROajlnWkx1WU9leEZhblVwNCt2MHBYWS81OFFSNTk2U0ROVTlKClJaeDhwTzBTaUYvZXkxVUZXbmpzdHBjbTQzTFVQKzFwU1hFeVhZOFJrRTI2QzNvdjNaTFNKc2pMbC90aXVqUlgKZUJPOWorWDdzS0R4amdtajBPbWdpVkpIM0YrUAotLS0tLUVORCBDRVJUSUZJQ0FURS0tLS0tCg=="}, false, "", "pomerium:///localhost.example:443"},
{"bad ca encoding", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:443"}}, OverrideCertificateName: "*.local", CA: "^"}, true, "", "pomerium:///localhost.example:443"},
{"custom ca file", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:443"}}, OverrideCertificateName: "*.local", CAFile: "testdata/example.crt"}, false, "", "pomerium:///localhost.example:443"},
{"bad custom ca file", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:443"}}, OverrideCertificateName: "*.local", CAFile: "testdata/example.crt2"}, true, "", "pomerium:///localhost.example:443"},
{"valid with insecure", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:8443"}}, WithInsecure: true}, false, "", "pomerium:///localhost.example:8443"},
{"valid client round robin", &Options{Addrs: []*url.URL{{Scheme: "https", Host: "localhost.example:8443"}}, ClientDNSRoundRobin: true}, false, "", "pomerium:///localhost.example:8443"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewGRPCClientConn(context.Background(), tt.opts)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
if !strings.EqualFold(err.Error(), tt.wantErrStr) {
t.Errorf("New() error = %v did not contain wantErr %v", err, tt.wantErrStr)
}
}
if got != nil && got.Target() != tt.wantTarget {
t.Errorf("New() target = %v expected %v", got.Target(), tt.wantTarget)
}
})
}
}
func TestGetGRPC(t *testing.T) {
cc1, err := GetGRPCClientConn(context.Background(), "example", &Options{
Addrs: mustParseURLs("https://localhost.example"),
})
if !assert.NoError(t, err) {
return
}
cc2, err := GetGRPCClientConn(context.Background(), "example", &Options{
Addrs: mustParseURLs("https://localhost.example"),
})
if !assert.NoError(t, err) {
return
}
assert.Same(t, cc1, cc2, "GetGRPCClientConn should return the same connection when there are no changes")
cc3, err := GetGRPCClientConn(context.Background(), "example", &Options{
Addrs: mustParseURLs("http://localhost.example"),
WithInsecure: true,
})
if !assert.NoError(t, err) {
return
}
assert.NotSame(t, cc1, cc3, "GetGRPCClientConn should return a new connection when there are changes")
}
func mustParseURLs(rawurls ...string) []*url.URL {
var urls []*url.URL
for _, rawurl := range rawurls {
u, err := url.Parse(rawurl)
if err != nil {
panic(err)
}
urls = append(urls, u)
}
return urls
}

View file

@ -1,11 +0,0 @@
package grpc
//go:generate ./protoc.bash
const roundRobinServiceConfig = `{
"loadBalancingConfig": [
{
"round_robin": {}
}
]
}`

View file

@ -1,104 +0,0 @@
package grpc
import (
"strings"
"sync"
"google.golang.org/grpc/resolver"
)
func init() {
resolver.Register(&pomeriumBuilder{})
}
type pomeriumBuilder struct {
}
func (*pomeriumBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
endpoints := strings.Split(target.Endpoint, ",")
pccd := &pomeriumClientConnData{
states: make([]resolver.State, len(endpoints)),
}
pr := &pomeriumResolver{}
for i, endpoint := range endpoints {
subTarget := parseTarget(endpoint)
b := resolver.Get(subTarget.Scheme)
pcc := &pomeriumClientConn{
data: pccd,
idx: i,
ClientConn: cc,
}
r, err := b.Build(subTarget, pcc, opts)
if err != nil {
return nil, err
}
pr.resolvers = append(pr.resolvers, r)
}
return pr, nil
}
func (*pomeriumBuilder) Scheme() string {
return "pomerium"
}
type pomeriumResolver struct {
resolvers []resolver.Resolver
}
func (pr *pomeriumResolver) ResolveNow(options resolver.ResolveNowOptions) {
for _, r := range pr.resolvers {
r.ResolveNow(options)
}
}
func (pr *pomeriumResolver) Close() {
for _, r := range pr.resolvers {
r.Close()
}
}
type pomeriumClientConn struct {
data *pomeriumClientConnData
idx int
resolver.ClientConn
}
func (pcc *pomeriumClientConn) UpdateState(state resolver.State) error {
return pcc.ClientConn.UpdateState(pcc.data.updateState(pcc.idx, state))
}
type pomeriumClientConnData struct {
mu sync.Mutex
states []resolver.State
}
func (pccd *pomeriumClientConnData) updateState(idx int, state resolver.State) resolver.State {
pccd.mu.Lock()
defer pccd.mu.Unlock()
pccd.states[idx] = state
merged := resolver.State{}
for _, s := range pccd.states {
merged.Addresses = append(merged.Addresses, s.Addresses...)
merged.ServiceConfig = s.ServiceConfig
merged.Attributes = s.Attributes
}
return merged
}
func parseTarget(raw string) resolver.Target {
target := resolver.Target{
Scheme: resolver.GetDefaultScheme(),
}
if idx := strings.Index(raw, "://"); idx >= 0 {
target.Scheme = raw[:idx]
raw = raw[idx+3:]
}
if idx := strings.Index(raw, "/"); idx >= 0 {
target.Authority = raw[:idx]
raw = raw[idx+1:]
}
target.Endpoint = raw
return target
}

View file

@ -1,68 +0,0 @@
package grpc
import (
"context"
"net"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/test/grpc_testing"
)
type resolverTestServer struct {
grpc_testing.UnimplementedTestServiceServer
username string
}
func (srv *resolverTestServer) UnaryCall(context.Context, *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error) {
return &grpc_testing.SimpleResponse{
Username: srv.username,
}, nil
}
func TestResolver(t *testing.T) {
li1, err := net.Listen("tcp", "127.0.0.1:0")
if !assert.NoError(t, err) {
return
}
defer func() { _ = li1.Close() }()
srv1 := grpc.NewServer()
grpc_testing.RegisterTestServiceServer(srv1, &resolverTestServer{
username: "srv1",
})
go func() { _ = srv1.Serve(li1) }()
li2, err := net.Listen("tcp", "127.0.0.1:0")
if !assert.NoError(t, err) {
return
}
defer func() { _ = li2.Close() }()
srv2 := grpc.NewServer()
grpc_testing.RegisterTestServiceServer(srv2, &resolverTestServer{
username: "srv2",
})
go func() { _ = srv2.Serve(li2) }()
cc, err := grpc.Dial("pomerium:///"+strings.Join([]string{
"dns:///" + li1.Addr().String(),
li2.Addr().String(),
}, ","), grpc.WithInsecure(), grpc.WithDefaultServiceConfig(roundRobinServiceConfig))
if !assert.NoError(t, err) {
return
}
defer func() { _ = cc.Close() }()
c := grpc_testing.NewTestServiceClient(cc)
usernames := map[string]int{}
for i := 0; i < 1000; i++ {
res, err := c.UnaryCall(context.Background(), new(grpc_testing.SimpleRequest))
assert.NoError(t, err)
usernames[res.GetUsername()]++
}
assert.Greater(t, usernames["srv1"], 0)
assert.Greater(t, usernames["srv2"], 0)
}