From a8a703218fa6e1a91a667521dc4499c4c12d38ac Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Tue, 26 Jan 2021 14:40:39 -0700 Subject: [PATCH] return errors in xds build methods (#1827) --- internal/controlplane/xds.go | 7 +- internal/controlplane/xds_cluster_test.go | 79 +++--- internal/controlplane/xds_clusters.go | 81 ++++-- internal/controlplane/xds_listeners.go | 270 ++++++++++++-------- internal/controlplane/xds_listeners_test.go | 16 +- internal/controlplane/xds_routes.go | 154 +++++++---- internal/controlplane/xds_routes_test.go | 68 +++-- 7 files changed, 447 insertions(+), 228 deletions(-) diff --git a/internal/controlplane/xds.go b/internal/controlplane/xds.go index 86361cc89..984c4ffc7 100644 --- a/internal/controlplane/xds.go +++ b/internal/controlplane/xds.go @@ -47,7 +47,12 @@ func (srv *Server) buildDiscoveryResources() (map[string][]*envoy_service_discov Resource: any, }) } - for _, listener := range srv.buildListeners(cfg.Config) { + + listeners, err := srv.buildListeners(cfg.Config) + if err != nil { + return nil, err + } + for _, listener := range listeners { any, _ := anypb.New(listener) resources[listenerTypeURL] = append(resources[listenerTypeURL], &envoy_service_discovery_v3.Resource{ Name: listener.Name, diff --git a/internal/controlplane/xds_cluster_test.go b/internal/controlplane/xds_cluster_test.go index 982e6c115..11e2c0eaa 100644 --- a/internal/controlplane/xds_cluster_test.go +++ b/internal/controlplane/xds_cluster_test.go @@ -25,11 +25,17 @@ func Test_buildPolicyTransportSocket(t *testing.T) { rootCA := srv.filemgr.FileDataSource(rootCAPath).GetFilename() t.Run("insecure", func(t *testing.T) { - assert.Nil(t, srv.buildPolicyTransportSocket(&config.Policy{ + ts, err := srv.buildPolicyTransportSocket(&config.Policy{ Destinations: mustParseURLs("http://example.com"), - }, mustParseURL("http://example.com"))) + }, mustParseURL("http://example.com")) + require.NoError(t, err) + assert.Nil(t, ts) }) t.Run("host as sni", func(t *testing.T) { + ts, err := srv.buildPolicyTransportSocket(&config.Policy{ + Destinations: mustParseURLs("https://example.com"), + }, mustParseURL("https://example.com")) + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` { "name": "tls", @@ -57,11 +63,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) { "sni": "example.com" } } - `, srv.buildPolicyTransportSocket(&config.Policy{ - Destinations: mustParseURLs("https://example.com"), - }, mustParseURL("https://example.com"))) + `, ts) }) t.Run("tls_server_name as sni", func(t *testing.T) { + ts, err := srv.buildPolicyTransportSocket(&config.Policy{ + Destinations: mustParseURLs("https://example.com"), + TLSServerName: "use-this-name.example.com", + }, mustParseURL("https://example.com")) + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` { "name": "tls", @@ -89,12 +98,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) { "sni": "use-this-name.example.com" } } - `, srv.buildPolicyTransportSocket(&config.Policy{ - Destinations: mustParseURLs("https://example.com"), - TLSServerName: "use-this-name.example.com", - }, mustParseURL("https://example.com"))) + `, ts) }) t.Run("tls_skip_verify", func(t *testing.T) { + ts, err := srv.buildPolicyTransportSocket(&config.Policy{ + Destinations: mustParseURLs("https://example.com"), + TLSSkipVerify: true, + }, mustParseURL("https://example.com")) + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` { "name": "tls", @@ -123,12 +134,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) { "sni": "example.com" } } - `, srv.buildPolicyTransportSocket(&config.Policy{ - Destinations: mustParseURLs("https://example.com"), - TLSSkipVerify: true, - }, mustParseURL("https://example.com"))) + `, ts) }) t.Run("custom ca", func(t *testing.T) { + ts, err := srv.buildPolicyTransportSocket(&config.Policy{ + Destinations: mustParseURLs("https://example.com"), + TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}), + }, mustParseURL("https://example.com")) + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` { "name": "tls", @@ -156,13 +169,15 @@ func Test_buildPolicyTransportSocket(t *testing.T) { "sni": "example.com" } } - `, srv.buildPolicyTransportSocket(&config.Policy{ - Destinations: mustParseURLs("https://example.com"), - TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}), - }, mustParseURL("https://example.com"))) + `, ts) }) t.Run("client certificate", func(t *testing.T) { clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey) + ts, err := srv.buildPolicyTransportSocket(&config.Policy{ + Destinations: mustParseURLs("https://example.com"), + ClientCertificate: clientCert, + }, mustParseURL("https://example.com")) + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` { "name": "tls", @@ -198,10 +213,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { "sni": "example.com" } } - `, srv.buildPolicyTransportSocket(&config.Policy{ - Destinations: mustParseURLs("https://example.com"), - ClientCertificate: clientCert, - }, mustParseURL("https://example.com"))) + `, ts) }) } @@ -210,12 +222,13 @@ func Test_buildCluster(t *testing.T) { rootCAPath, _ := getRootCertificateAuthority() rootCA := srv.filemgr.FileDataSource(rootCAPath).GetFilename() t.Run("insecure", func(t *testing.T) { - endpoints := srv.buildPolicyEndpoints(&config.Policy{ + endpoints, err := srv.buildPolicyEndpoints(&config.Policy{ Destinations: mustParseURLs("http://example.com", "http://1.2.3.4"), }) + require.NoError(t, err) cluster := newDefaultEnvoyClusterConfig() cluster.DnsLookupFamily = envoy_config_cluster_v3.Cluster_V4_ONLY - err := buildCluster(cluster, "example", endpoints, true) + err = srv.buildCluster(cluster, "example", endpoints, true) require.NoErrorf(t, err, "cluster %+v", cluster) testutil.AssertProtoJSONEqual(t, ` { @@ -257,14 +270,15 @@ func Test_buildCluster(t *testing.T) { `, cluster) }) t.Run("secure", func(t *testing.T) { - endpoints := srv.buildPolicyEndpoints(&config.Policy{ + endpoints, err := srv.buildPolicyEndpoints(&config.Policy{ Destinations: mustParseURLs( "https://example.com", "https://example.com", ), }) + require.NoError(t, err) cluster := newDefaultEnvoyClusterConfig() - err := buildCluster(cluster, "example", endpoints, true) + err = srv.buildCluster(cluster, "example", endpoints, true) require.NoErrorf(t, err, "cluster %+v", cluster) testutil.AssertProtoJSONEqual(t, ` { @@ -351,11 +365,12 @@ func Test_buildCluster(t *testing.T) { `, cluster) }) t.Run("ip addresses", func(t *testing.T) { - endpoints := srv.buildPolicyEndpoints(&config.Policy{ + endpoints, err := srv.buildPolicyEndpoints(&config.Policy{ Destinations: mustParseURLs("http://127.0.0.1", "http://127.0.0.2"), }) + require.NoError(t, err) cluster := newDefaultEnvoyClusterConfig() - err := buildCluster(cluster, "example", endpoints, true) + err = srv.buildCluster(cluster, "example", endpoints, true) require.NoErrorf(t, err, "cluster %+v", cluster) testutil.AssertProtoJSONEqual(t, ` { @@ -396,11 +411,12 @@ func Test_buildCluster(t *testing.T) { `, cluster) }) t.Run("localhost", func(t *testing.T) { - endpoints := srv.buildPolicyEndpoints(&config.Policy{ + endpoints, err := srv.buildPolicyEndpoints(&config.Policy{ Destinations: mustParseURLs("http://localhost"), }) + require.NoError(t, err) cluster := newDefaultEnvoyClusterConfig() - err := buildCluster(cluster, "example", endpoints, true) + err = srv.buildCluster(cluster, "example", endpoints, true) require.NoErrorf(t, err, "cluster %+v", cluster) testutil.AssertProtoJSONEqual(t, ` { @@ -431,16 +447,17 @@ func Test_buildCluster(t *testing.T) { `, cluster) }) t.Run("outlier", func(t *testing.T) { - endpoints := srv.buildPolicyEndpoints(&config.Policy{ + endpoints, err := srv.buildPolicyEndpoints(&config.Policy{ Destinations: mustParseURLs("http://example.com"), }) + require.NoError(t, err) cluster := newDefaultEnvoyClusterConfig() cluster.DnsLookupFamily = envoy_config_cluster_v3.Cluster_V4_ONLY cluster.OutlierDetection = &envoy_config_cluster_v3.OutlierDetection{ EnforcingConsecutive_5Xx: wrapperspb.UInt32(17), SplitExternalLocalOriginErrors: true, } - err := buildCluster(cluster, "example", endpoints, true) + err = srv.buildCluster(cluster, "example", endpoints, true) require.NoErrorf(t, err, "cluster %+v", cluster) testutil.AssertProtoJSONEqual(t, ` { diff --git a/internal/controlplane/xds_clusters.go b/internal/controlplane/xds_clusters.go index 28c562a71..d13477789 100644 --- a/internal/controlplane/xds_clusters.go +++ b/internal/controlplane/xds_clusters.go @@ -96,8 +96,11 @@ func (srv *Server) buildClusters(options *config.Options) ([]*envoy_config_clust func (srv *Server) buildInternalCluster(options *config.Options, name string, dst *url.URL, forceHTTP2 bool) (*envoy_config_cluster_v3.Cluster, error) { cluster := newDefaultEnvoyClusterConfig() cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily) - endpoints := []Endpoint{NewEndpoint(dst, srv.buildInternalTransportSocket(options, dst))} - if err := buildCluster(cluster, name, endpoints, forceHTTP2); err != nil { + endpoints, err := srv.buildInternalEndpoints(options, dst) + if err != nil { + return nil, err + } + if err := srv.buildCluster(cluster, name, endpoints, forceHTTP2); err != nil { return nil, err } return cluster, nil @@ -107,7 +110,10 @@ func (srv *Server) buildPolicyCluster(options *config.Options, policy *config.Po cluster := policy.EnvoyOpts name := getPolicyName(policy) - endpoints := srv.buildPolicyEndpoints(policy) + endpoints, err := srv.buildPolicyEndpoints(policy) + if err != nil { + return nil, err + } if cluster.DnsLookupFamily == envoy_config_cluster_v3.Cluster_AUTO { cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily) @@ -117,24 +123,38 @@ func (srv *Server) buildPolicyCluster(options *config.Options, policy *config.Po cluster.DnsLookupFamily = envoy_config_cluster_v3.Cluster_V4_ONLY } - if err := buildCluster(cluster, name, endpoints, false); err != nil { + if err := srv.buildCluster(cluster, name, endpoints, false); err != nil { return nil, err } return cluster, nil } -func (srv *Server) buildPolicyEndpoints(policy *config.Policy) []Endpoint { +func (srv *Server) buildInternalEndpoints(options *config.Options, dst *url.URL) ([]Endpoint, error) { var endpoints []Endpoint - for _, dst := range policy.Destinations { - endpoints = append(endpoints, NewEndpoint(dst, srv.buildPolicyTransportSocket(policy, dst))) + if ts, err := srv.buildInternalTransportSocket(options, dst); err != nil { + return nil, err + } else { + endpoints = append(endpoints, NewEndpoint(dst, ts)) } - return endpoints + return endpoints, nil } -func (srv *Server) buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *envoy_config_core_v3.TransportSocket { +func (srv *Server) buildPolicyEndpoints(policy *config.Policy) ([]Endpoint, error) { + var endpoints []Endpoint + for _, dst := range policy.Destinations { + if ts, err := srv.buildPolicyTransportSocket(policy, dst); err != nil { + return nil, err + } else { + endpoints = append(endpoints, NewEndpoint(dst, ts)) + } + } + return endpoints, nil +} + +func (srv *Server) buildInternalTransportSocket(options *config.Options, endpoint *url.URL) (*envoy_config_core_v3.TransportSocket, error) { if endpoint.Scheme != "https" { - return nil + return nil, nil } sni := endpoint.Hostname() if options.OverrideCertificateName != "" { @@ -178,12 +198,17 @@ func (srv *Server) buildInternalTransportSocket(options *config.Options, endpoin ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ TypedConfig: tlsConfig, }, - } + }, nil } -func (srv *Server) buildPolicyTransportSocket(policy *config.Policy, dst *url.URL) *envoy_config_core_v3.TransportSocket { +func (srv *Server) buildPolicyTransportSocket(policy *config.Policy, dst *url.URL) (*envoy_config_core_v3.TransportSocket, error) { if dst == nil || dst.Scheme != "https" { - return nil + return nil, nil + } + + vc, err := srv.buildPolicyValidationContext(policy, dst) + if err != nil { + return nil, err } sni := dst.Hostname() @@ -202,7 +227,7 @@ func (srv *Server) buildPolicyTransportSocket(policy *config.Policy, dst *url.UR }, AlpnProtocols: []string{"http/1.1"}, ValidationContextType: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext{ - ValidationContext: srv.buildPolicyValidationContext(policy, dst), + ValidationContext: vc, }, }, Sni: sni, @@ -218,12 +243,12 @@ func (srv *Server) buildPolicyTransportSocket(policy *config.Policy, dst *url.UR ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ TypedConfig: tlsConfig, }, - } + }, nil } -func (srv *Server) buildPolicyValidationContext(policy *config.Policy, dst *url.URL) *envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext { +func (srv *Server) buildPolicyValidationContext(policy *config.Policy, dst *url.URL) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) { if dst == nil { - return nil + return nil, nil } sni := dst.Hostname() @@ -258,10 +283,10 @@ func (srv *Server) buildPolicyValidationContext(policy *config.Policy, dst *url. validationContext.TrustChainVerification = envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext_ACCEPT_UNTRUSTED } - return validationContext + return validationContext, nil } -func buildCluster( +func (srv *Server) buildCluster( cluster *envoy_config_cluster_v3.Cluster, name string, endpoints []Endpoint, @@ -275,7 +300,10 @@ func buildCluster( cluster.ConnectTimeout = defaultConnectionTimeout } cluster.RespectDnsTtl = true - lbEndpoints := buildLbEndpoints(endpoints) + lbEndpoints, err := srv.buildLbEndpoints(endpoints) + if err != nil { + return err + } cluster.Name = name cluster.LoadAssignment = &envoy_config_endpoint_v3.ClusterLoadAssignment{ ClusterName: name, @@ -283,7 +311,10 @@ func buildCluster( LbEndpoints: lbEndpoints, }}, } - cluster.TransportSocketMatches = buildTransportSocketMatches(endpoints) + cluster.TransportSocketMatches, err = srv.buildTransportSocketMatches(endpoints) + if err != nil { + return err + } if forceHTTP2 { cluster.Http2ProtocolOptions = &envoy_config_core_v3.Http2ProtocolOptions{ @@ -307,7 +338,7 @@ func buildCluster( return cluster.Validate() } -func buildLbEndpoints(endpoints []Endpoint) []*envoy_config_endpoint_v3.LbEndpoint { +func (srv *Server) buildLbEndpoints(endpoints []Endpoint) ([]*envoy_config_endpoint_v3.LbEndpoint, error) { var lbes []*envoy_config_endpoint_v3.LbEndpoint for _, e := range endpoints { defaultPort := 80 @@ -343,10 +374,10 @@ func buildLbEndpoints(endpoints []Endpoint) []*envoy_config_endpoint_v3.LbEndpoi } lbes = append(lbes, lbe) } - return lbes + return lbes, nil } -func buildTransportSocketMatches(endpoints []Endpoint) []*envoy_config_cluster_v3.Cluster_TransportSocketMatch { +func (srv *Server) buildTransportSocketMatches(endpoints []Endpoint) ([]*envoy_config_cluster_v3.Cluster_TransportSocketMatch, error) { var tsms []*envoy_config_cluster_v3.Cluster_TransportSocketMatch seen := map[string]struct{}{} for _, e := range endpoints { @@ -371,5 +402,5 @@ func buildTransportSocketMatches(endpoints []Endpoint) []*envoy_config_cluster_v TransportSocket: e.transportSocket, }) } - return tsms + return tsms, nil } diff --git a/internal/controlplane/xds_listeners.go b/internal/controlplane/xds_listeners.go index 498b65f26..02bbd3378 100644 --- a/internal/controlplane/xds_listeners.go +++ b/internal/controlplane/xds_listeners.go @@ -39,21 +39,29 @@ func init() { }) } -func (srv *Server) buildListeners(cfg *config.Config) []*envoy_config_listener_v3.Listener { +func (srv *Server) buildListeners(cfg *config.Config) ([]*envoy_config_listener_v3.Listener, error) { var listeners []*envoy_config_listener_v3.Listener if config.IsAuthenticate(cfg.Options.Services) || config.IsProxy(cfg.Options.Services) { - listeners = append(listeners, srv.buildMainListener(cfg)) + if li, err := srv.buildMainListener(cfg); err != nil { + return nil, err + } else { + listeners = append(listeners, li) + } } if config.IsAuthorize(cfg.Options.Services) || config.IsDataBroker(cfg.Options.Services) { - listeners = append(listeners, srv.buildGRPCListener(cfg)) + if li, err := srv.buildGRPCListener(cfg); err != nil { + return nil, err + } else { + listeners = append(listeners, li) + } } - return listeners + return listeners, nil } -func (srv *Server) buildMainListener(cfg *config.Config) *envoy_config_listener_v3.Listener { +func (srv *Server) buildMainListener(cfg *config.Config) (*envoy_config_listener_v3.Listener, error) { listenerFilters := []*envoy_config_listener_v3.ListenerFilter{} if cfg.Options.UseProxyProtocol { proxyCfg := marshalAny(&envoy_extensions_filters_listener_proxy_protocol_v3.ProxyProtocol{}) @@ -66,8 +74,11 @@ func (srv *Server) buildMainListener(cfg *config.Config) *envoy_config_listener_ } if cfg.Options.InsecureServer { - filter := buildMainHTTPConnectionManagerFilter(cfg.Options, + filter, err := srv.buildMainHTTPConnectionManagerFilter(cfg.Options, getAllRouteableDomains(cfg.Options, cfg.Options.Addr), "") + if err != nil { + return nil, err + } return &envoy_config_listener_v3.Listener{ Name: "http-ingress", @@ -78,7 +89,7 @@ func (srv *Server) buildMainListener(cfg *config.Config) *envoy_config_listener_ filter, }, }}, - } + }, nil } tlsInspectorCfg := marshalAny(new(emptypb.Empty)) @@ -89,58 +100,75 @@ func (srv *Server) buildMainListener(cfg *config.Config) *envoy_config_listener_ }, }) + chains, err := srv.buildFilterChains(cfg.Options, cfg.Options.Addr, + func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) { + filter, err := srv.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, tlsDomain) + if err != nil { + return nil, err + } + filterChain := &envoy_config_listener_v3.FilterChain{ + Filters: []*envoy_config_listener_v3.Filter{filter}, + } + if tlsDomain != "*" { + filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{ + ServerNames: []string{tlsDomain}, + } + } + tlsContext := srv.buildDownstreamTLSContext(cfg, tlsDomain) + if tlsContext != nil { + tlsConfig := marshalAny(tlsContext) + filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ + Name: "tls", + ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ + TypedConfig: tlsConfig, + }, + } + } + return filterChain, nil + }) + if err != nil { + return nil, err + } + li := &envoy_config_listener_v3.Listener{ Name: "https-ingress", Address: buildAddress(cfg.Options.Addr, 443), ListenerFilters: listenerFilters, - FilterChains: buildFilterChains(cfg.Options, cfg.Options.Addr, - func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain { - filter := buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, tlsDomain) - filterChain := &envoy_config_listener_v3.FilterChain{ - Filters: []*envoy_config_listener_v3.Filter{filter}, - } - if tlsDomain != "*" { - filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{ - ServerNames: []string{tlsDomain}, - } - } - tlsContext := srv.buildDownstreamTLSContext(cfg, tlsDomain) - if tlsContext != nil { - tlsConfig := marshalAny(tlsContext) - filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ - Name: "tls", - ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ - TypedConfig: tlsConfig, - }, - } - } - return filterChain - }), + FilterChains: chains, } - return li + return li, nil } -func buildFilterChains( +func (srv *Server) buildFilterChains( options *config.Options, addr string, - callback func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain, -) []*envoy_config_listener_v3.FilterChain { + callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error), +) ([]*envoy_config_listener_v3.FilterChain, error) { allDomains := getAllRouteableDomains(options, addr) tlsDomains := getAllTLSDomains(options, addr) var chains []*envoy_config_listener_v3.FilterChain for _, domain := range tlsDomains { // first we match on SNI - chains = append(chains, callback(domain, getRouteableDomainsForTLSDomain(options, addr, domain))) + if chain, err := callback(domain, getRouteableDomainsForTLSDomain(options, addr, domain)); err != nil { + return nil, err + } else { + chains = append(chains, chain) + } } + // if there are no SNI matches we match on HTTP host - chains = append(chains, callback("*", allDomains)) - return chains + if chain, err := callback("*", allDomains); err != nil { + return nil, err + } else { + chains = append(chains, chain) + } + return chains, nil } -func buildMainHTTPConnectionManagerFilter( +func (srv *Server) buildMainHTTPConnectionManagerFilter( options *config.Options, domains []string, tlsDomain string, -) *envoy_config_listener_v3.Filter { +) (*envoy_config_listener_v3.Filter, error) { var virtualHosts []*envoy_config_route_v3.VirtualHost for _, domain := range domains { vh := &envoy_config_route_v3.VirtualHost{ @@ -152,27 +180,44 @@ func buildMainHTTPConnectionManagerFilter( // if this is a gRPC service domain and we're supposed to handle that, add those routes if (config.IsAuthorize(options.Services) && hostMatchesDomain(options.GetAuthorizeURL(), domain)) || (config.IsDataBroker(options.Services) && hostMatchesDomain(options.GetDataBrokerURL(), domain)) { - vh.Routes = append(vh.Routes, buildGRPCRoutes()...) + if rs, err := srv.buildGRPCRoutes(); err != nil { + return nil, err + } else { + vh.Routes = append(vh.Routes, rs...) + } } } // these routes match /.pomerium/... and similar paths - vh.Routes = append(vh.Routes, buildPomeriumHTTPRoutes(options, domain)...) + if rs, err := srv.buildPomeriumHTTPRoutes(options, domain); err != nil { + return nil, err + } else { + vh.Routes = append(vh.Routes, rs...) + } // if we're the proxy, add all the policy routes if config.IsProxy(options.Services) { - vh.Routes = append(vh.Routes, buildPolicyRoutes(options, domain)...) + if rs, err := srv.buildPolicyRoutes(options, domain); err != nil { + return nil, err + } else { + vh.Routes = append(vh.Routes, rs...) + } } if len(vh.Routes) > 0 { virtualHosts = append(virtualHosts, vh) } } - virtualHosts = append(virtualHosts, &envoy_config_route_v3.VirtualHost{ - Name: "catch-all", - Domains: []string{"*"}, - Routes: buildPomeriumHTTPRoutes(options, "*"), - }) + + if rs, err := srv.buildPomeriumHTTPRoutes(options, "*"); err != nil { + return nil, err + } else { + virtualHosts = append(virtualHosts, &envoy_config_route_v3.VirtualHost{ + Name: "catch-all", + Domains: []string{"*"}, + Routes: rs, + }) + } var grpcClientTimeout *durationpb.Duration if options.GRPCClientTimeout != 0 { @@ -254,11 +299,15 @@ func buildMainHTTPConnectionManagerFilter( maxStreamDuration = ptypes.DurationProto(options.WriteTimeout) } + rc, err := srv.buildRouteConfiguration("main", virtualHosts) + if err != nil { + return nil, err + } tc := marshalAny(&envoy_http_connection_manager.HttpConnectionManager{ CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO, StatPrefix: "ingress", RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{ - RouteConfig: buildRouteConfiguration("main", virtualHosts), + RouteConfig: rc, }, HttpFilters: filters, AccessLog: buildAccessLogs(options), @@ -280,11 +329,14 @@ func buildMainHTTPConnectionManagerFilter( ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{ TypedConfig: tc, }, - } + }, nil } -func (srv *Server) buildGRPCListener(cfg *config.Config) *envoy_config_listener_v3.Listener { - filter := buildGRPCHTTPConnectionManagerFilter() +func (srv *Server) buildGRPCListener(cfg *config.Config) (*envoy_config_listener_v3.Listener, error) { + filter, err := srv.buildGRPCHTTPConnectionManagerFilter() + if err != nil { + return nil, err + } if cfg.Options.GRPCInsecure { return &envoy_config_listener_v3.Listener{ @@ -295,7 +347,33 @@ func (srv *Server) buildGRPCListener(cfg *config.Config) *envoy_config_listener_ filter, }, }}, - } + }, nil + } + + chains, err := srv.buildFilterChains(cfg.Options, cfg.Options.Addr, + func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) { + filterChain := &envoy_config_listener_v3.FilterChain{ + Filters: []*envoy_config_listener_v3.Filter{filter}, + } + if tlsDomain != "*" { + filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{ + ServerNames: []string{tlsDomain}, + } + } + tlsContext := srv.buildDownstreamTLSContext(cfg, tlsDomain) + if tlsContext != nil { + tlsConfig := marshalAny(tlsContext) + filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ + Name: "tls", + ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ + TypedConfig: tlsConfig, + }, + } + } + return filterChain, nil + }) + if err != nil { + return nil, err } tlsInspectorCfg := marshalAny(new(emptypb.Empty)) @@ -308,33 +386,41 @@ func (srv *Server) buildGRPCListener(cfg *config.Config) *envoy_config_listener_ TypedConfig: tlsInspectorCfg, }, }}, - FilterChains: buildFilterChains(cfg.Options, cfg.Options.Addr, - func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain { - filterChain := &envoy_config_listener_v3.FilterChain{ - Filters: []*envoy_config_listener_v3.Filter{filter}, - } - if tlsDomain != "*" { - filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{ - ServerNames: []string{tlsDomain}, - } - } - tlsContext := srv.buildDownstreamTLSContext(cfg, tlsDomain) - if tlsContext != nil { - tlsConfig := marshalAny(tlsContext) - filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ - Name: "tls", - ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ - TypedConfig: tlsConfig, - }, - } - } - return filterChain - }), + FilterChains: chains, } - return li + return li, nil } -func buildGRPCHTTPConnectionManagerFilter() *envoy_config_listener_v3.Filter { +func (srv *Server) buildGRPCHTTPConnectionManagerFilter() (*envoy_config_listener_v3.Filter, error) { + rc, err := srv.buildRouteConfiguration("grpc", []*envoy_config_route_v3.VirtualHost{{ + Name: "grpc", + Domains: []string{"*"}, + Routes: []*envoy_config_route_v3.Route{{ + Name: "grpc", + Match: &envoy_config_route_v3.RouteMatch{ + PathSpecifier: &envoy_config_route_v3.RouteMatch_Prefix{Prefix: "/"}, + Grpc: &envoy_config_route_v3.RouteMatch_GrpcRouteMatchOptions{}, + }, + Action: &envoy_config_route_v3.Route_Route{ + Route: &envoy_config_route_v3.RouteAction{ + ClusterSpecifier: &envoy_config_route_v3.RouteAction_Cluster{ + Cluster: "pomerium-control-plane-grpc", + }, + // disable the timeout to support grpc streaming + Timeout: &durationpb.Duration{ + Seconds: 0, + }, + IdleTimeout: &durationpb.Duration{ + Seconds: 0, + }, + }, + }, + }}, + }}) + if err != nil { + return nil, err + } + tc := marshalAny(&envoy_http_connection_manager.HttpConnectionManager{ CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO, StatPrefix: "grpc_ingress", @@ -343,31 +429,7 @@ func buildGRPCHTTPConnectionManagerFilter() *envoy_config_listener_v3.Filter { Seconds: 15, }, RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{ - RouteConfig: buildRouteConfiguration("grpc", []*envoy_config_route_v3.VirtualHost{{ - Name: "grpc", - Domains: []string{"*"}, - Routes: []*envoy_config_route_v3.Route{{ - Name: "grpc", - Match: &envoy_config_route_v3.RouteMatch{ - PathSpecifier: &envoy_config_route_v3.RouteMatch_Prefix{Prefix: "/"}, - Grpc: &envoy_config_route_v3.RouteMatch_GrpcRouteMatchOptions{}, - }, - Action: &envoy_config_route_v3.Route_Route{ - Route: &envoy_config_route_v3.RouteAction{ - ClusterSpecifier: &envoy_config_route_v3.RouteAction_Cluster{ - Cluster: "pomerium-control-plane-grpc", - }, - // disable the timeout to support grpc streaming - Timeout: &durationpb.Duration{ - Seconds: 0, - }, - IdleTimeout: &durationpb.Duration{ - Seconds: 0, - }, - }, - }, - }}, - }}), + RouteConfig: rc, }, HttpFilters: []*envoy_http_connection_manager.HttpFilter{{ Name: "envoy.filters.http.router", @@ -378,16 +440,16 @@ func buildGRPCHTTPConnectionManagerFilter() *envoy_config_listener_v3.Filter { ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{ TypedConfig: tc, }, - } + }, nil } -func buildRouteConfiguration(name string, virtualHosts []*envoy_config_route_v3.VirtualHost) *envoy_config_route_v3.RouteConfiguration { +func (srv *Server) buildRouteConfiguration(name string, virtualHosts []*envoy_config_route_v3.VirtualHost) (*envoy_config_route_v3.RouteConfiguration, error) { return &envoy_config_route_v3.RouteConfiguration{ Name: name, VirtualHosts: virtualHosts, // disable cluster validation since the order of LDS/CDS updates isn't guaranteed ValidateClusters: &wrappers.BoolValue{Value: false}, - } + }, nil } func (srv *Server) buildDownstreamTLSContext(cfg *config.Config, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { diff --git a/internal/controlplane/xds_listeners_test.go b/internal/controlplane/xds_listeners_test.go index f628b5dc1..5131c79a5 100644 --- a/internal/controlplane/xds_listeners_test.go +++ b/internal/controlplane/xds_listeners_test.go @@ -8,6 +8,7 @@ import ( envoy_config_route_v3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/controlplane/filemgr" @@ -21,9 +22,12 @@ const ( ) func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) { + srv, _ := NewServer("TEST") + options := config.NewDefaultOptions() options.SkipXffAppend = true - filter := buildMainHTTPConnectionManagerFilter(options, []string{"example.com"}, "*") + filter, err := srv.buildMainHTTPConnectionManagerFilter(options, []string{"example.com"}, "*") + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, `{ "name": "envoy.filters.network.http_connection_manager", "typedConfig": { @@ -497,8 +501,10 @@ func Test_hostMatchesDomain(t *testing.T) { } func Test_buildRouteConfiguration(t *testing.T) { + srv := &Server{filemgr: filemgr.NewManager()} virtualHosts := make([]*envoy_config_route_v3.VirtualHost, 10) - routeConfig := buildRouteConfiguration("test-route-configuration", virtualHosts) + routeConfig, err := srv.buildRouteConfiguration("test-route-configuration", virtualHosts) + require.NoError(t, err) assert.Equal(t, "test-route-configuration", routeConfig.GetName()) assert.Equal(t, virtualHosts, routeConfig.GetVirtualHosts()) assert.False(t, routeConfig.GetValidateClusters().GetValue()) @@ -509,10 +515,11 @@ func Test_requireProxyProtocol(t *testing.T) { filemgr: filemgr.NewManager(), } t.Run("required", func(t *testing.T) { - li := srv.buildMainListener(&config.Config{Options: &config.Options{ + li, err := srv.buildMainListener(&config.Config{Options: &config.Options{ UseProxyProtocol: true, InsecureServer: true, }}) + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, `[ { "name": "envoy.filters.listener.proxy_protocol", @@ -523,10 +530,11 @@ func Test_requireProxyProtocol(t *testing.T) { ]`, li.GetListenerFilters()) }) t.Run("not required", func(t *testing.T) { - li := srv.buildMainListener(&config.Config{Options: &config.Options{ + li, err := srv.buildMainListener(&config.Config{Options: &config.Options{ UseProxyProtocol: false, InsecureServer: true, }}) + require.NoError(t, err) assert.Len(t, li.GetListenerFilters(), 0) }) } diff --git a/internal/controlplane/xds_routes.go b/internal/controlplane/xds_routes.go index e5b1a3364..e24842ece 100644 --- a/internal/controlplane/xds_routes.go +++ b/internal/controlplane/xds_routes.go @@ -24,7 +24,7 @@ const ( httpCluster = "pomerium-control-plane-http" ) -func buildGRPCRoutes() []*envoy_config_route_v3.Route { +func (srv *Server) buildGRPCRoutes() ([]*envoy_config_route_v3.Route, error) { action := &envoy_config_route_v3.Route_Route{ Route: &envoy_config_route_v3.RouteAction{ ClusterSpecifier: &envoy_config_route_v3.RouteAction_Cluster{ @@ -44,47 +44,107 @@ func buildGRPCRoutes() []*envoy_config_route_v3.Route { TypedPerFilterConfig: map[string]*any.Any{ "envoy.filters.http.ext_authz": disableExtAuthz, }, - }} + }}, nil } -func buildPomeriumHTTPRoutes(options *config.Options, domain string) []*envoy_config_route_v3.Route { - routes := []*envoy_config_route_v3.Route{ - // enable ext_authz - buildControlPlanePathRoute("/.pomerium/jwt", true), - // disable ext_authz and passthrough to proxy handlers - buildControlPlanePathRoute("/ping", false), - buildControlPlanePathRoute("/healthz", false), - buildControlPlanePathRoute("/.pomerium/admin", true), - buildControlPlanePrefixRoute("/.pomerium/admin/", true), - buildControlPlanePathRoute("/.pomerium", false), - buildControlPlanePrefixRoute("/.pomerium/", false), - buildControlPlanePathRoute("/.well-known/pomerium", false), - buildControlPlanePrefixRoute("/.well-known/pomerium/", false), +func (srv *Server) buildPomeriumHTTPRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) { + var routes []*envoy_config_route_v3.Route + // enable ext_authz + if r, err := srv.buildControlPlanePathRoute("/.pomerium/jwt", true); err != nil { + return nil, err + } else { + routes = append(routes, r) + } + + // disable ext_authz and passthrough to proxy handlers + if r, err := srv.buildControlPlanePathRoute("/ping", false); err != nil { + return nil, err + } else { + routes = append(routes, r) + } + if r, err := srv.buildControlPlanePathRoute("/healthz", false); err != nil { + return nil, err + } else { + routes = append(routes, r) + } + if r, err := srv.buildControlPlanePathRoute("/.pomerium/admin", true); err != nil { + return nil, err + } else { + routes = append(routes, r) + } + if r, err := srv.buildControlPlanePrefixRoute("/.pomerium/admin/", true); err != nil { + return nil, err + } else { + routes = append(routes, r) + } + if r, err := srv.buildControlPlanePathRoute("/.pomerium", false); err != nil { + return nil, err + } else { + routes = append(routes, r) + } + if r, err := srv.buildControlPlanePrefixRoute("/.pomerium/", false); err != nil { + return nil, err + } else { + routes = append(routes, r) + } + if r, err := srv.buildControlPlanePathRoute("/.well-known/pomerium", false); err != nil { + return nil, err + } else { + routes = append(routes, r) + } + if r, err := srv.buildControlPlanePrefixRoute("/.well-known/pomerium/", false); err != nil { + return nil, err + } else { + routes = append(routes, r) } // per #837, only add robots.txt if there are no unauthenticated routes if !hasPublicPolicyMatchingURL(options, mustParseURL("https://"+domain+"/robots.txt")) { - routes = append(routes, buildControlPlanePathRoute("/robots.txt", false)) + if r, err := srv.buildControlPlanePathRoute("/robots.txt", false); err != nil { + return nil, err + } else { + routes = append(routes, r) + } } // if we're handling authentication, add the oauth2 callback url if config.IsAuthenticate(options.Services) && hostMatchesDomain(options.GetAuthenticateURL(), domain) { - routes = append(routes, buildControlPlanePathRoute(options.AuthenticateCallbackPath, false)) + if r, err := srv.buildControlPlanePathRoute(options.AuthenticateCallbackPath, false); err != nil { + return nil, err + } else { + routes = append(routes, r) + } } // if we're the proxy and this is the forward-auth url if config.IsProxy(options.Services) && options.ForwardAuthURL != nil && hostMatchesDomain(options.GetForwardAuthURL(), domain) { - routes = append(routes, - // disable ext_authz and pass request to proxy handlers that enable authN flow - buildControlPlanePathAndQueryRoute("/verify", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI}), - buildControlPlanePathAndQueryRoute("/", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI}), - buildControlPlanePathAndQueryRoute("/", []string{urlutil.QueryForwardAuthURI}), - // otherwise, enforce ext_authz; pass all other requests through to an upstream - // handler that will simply respond with http status 200 / OK indicating that - // the fronting forward-auth proxy can continue. - buildControlPlaneProtectedPrefixRoute("/")) + // disable ext_authz and pass request to proxy handlers that enable authN flow + if r, err := srv.buildControlPlanePathAndQueryRoute("/verify", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI}); err != nil { + return nil, err + } else { + routes = append(routes, r) + } + if r, err := srv.buildControlPlanePathAndQueryRoute("/", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI}); err != nil { + return nil, err + } else { + routes = append(routes, r) + } + if r, err := srv.buildControlPlanePathAndQueryRoute("/", []string{urlutil.QueryForwardAuthURI}); err != nil { + return nil, err + } else { + routes = append(routes, r) + } + + // otherwise, enforce ext_authz; pass all other requests through to an upstream + // handler that will simply respond with http status 200 / OK indicating that + // the fronting forward-auth proxy can continue. + if r, err := srv.buildControlPlaneProtectedPrefixRoute("/"); err != nil { + return nil, err + } else { + routes = append(routes, r) + } } - return routes + return routes, nil } -func buildControlPlaneProtectedPrefixRoute(prefix string) *envoy_config_route_v3.Route { +func (srv *Server) buildControlPlaneProtectedPrefixRoute(prefix string) (*envoy_config_route_v3.Route, error) { return &envoy_config_route_v3.Route{ Name: "pomerium-protected-prefix-" + prefix, Match: &envoy_config_route_v3.RouteMatch{ @@ -97,10 +157,10 @@ func buildControlPlaneProtectedPrefixRoute(prefix string) *envoy_config_route_v3 }, }, }, - } + }, nil } -func buildControlPlanePathAndQueryRoute(path string, queryparams []string) *envoy_config_route_v3.Route { +func (srv *Server) buildControlPlanePathAndQueryRoute(path string, queryparams []string) (*envoy_config_route_v3.Route, error) { var queryParameterMatchers []*envoy_config_route_v3.QueryParameterMatcher for _, q := range queryparams { queryParameterMatchers = append(queryParameterMatchers, @@ -126,10 +186,10 @@ func buildControlPlanePathAndQueryRoute(path string, queryparams []string) *envo TypedPerFilterConfig: map[string]*any.Any{ "envoy.filters.http.ext_authz": disableExtAuthz, }, - } + }, nil } -func buildControlPlanePathRoute(path string, protected bool) *envoy_config_route_v3.Route { +func (srv *Server) buildControlPlanePathRoute(path string, protected bool) (*envoy_config_route_v3.Route, error) { r := &envoy_config_route_v3.Route{ Name: "pomerium-path-" + path, Match: &envoy_config_route_v3.RouteMatch{ @@ -148,10 +208,10 @@ func buildControlPlanePathRoute(path string, protected bool) *envoy_config_route "envoy.filters.http.ext_authz": disableExtAuthz, } } - return r + return r, nil } -func buildControlPlanePrefixRoute(prefix string, protected bool) *envoy_config_route_v3.Route { +func (srv *Server) buildControlPlanePrefixRoute(prefix string, protected bool) (*envoy_config_route_v3.Route, error) { r := &envoy_config_route_v3.Route{ Name: "pomerium-prefix-" + prefix, Match: &envoy_config_route_v3.RouteMatch{ @@ -170,14 +230,14 @@ func buildControlPlanePrefixRoute(prefix string, protected bool) *envoy_config_r "envoy.filters.http.ext_authz": disableExtAuthz, } } - return r + return r, nil } var getPolicyName = func(policy *config.Policy) string { return fmt.Sprintf("policy-%x", policy.RouteID()) } -func buildPolicyRoutes(options *config.Options, domain string) []*envoy_config_route_v3.Route { +func (srv *Server) buildPolicyRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) { var routes []*envoy_config_route_v3.Route responseHeadersToAdd := toEnvoyHeaders(options.Headers) @@ -222,19 +282,25 @@ func buildPolicyRoutes(options *config.Options, domain string) []*envoy_config_r ResponseHeadersToAdd: responseHeadersToAdd, } if policy.Redirect != nil { - envoyRoute.Action = &envoy_config_route_v3.Route_Redirect{ - Redirect: buildPolicyRouteRedirectAction(policy.Redirect), + action, err := srv.buildPolicyRouteRedirectAction(policy.Redirect) + if err != nil { + return nil, err } + envoyRoute.Action = &envoy_config_route_v3.Route_Redirect{Redirect: action} } else { - envoyRoute.Action = &envoy_config_route_v3.Route_Route{Route: buildPolicyRouteRouteAction(options, &policy)} + action, err := srv.buildPolicyRouteRouteAction(options, &policy) + if err != nil { + return nil, err + } + envoyRoute.Action = &envoy_config_route_v3.Route_Route{Route: action} } routes = append(routes, envoyRoute) } - return routes + return routes, nil } -func buildPolicyRouteRedirectAction(r *config.PolicyRedirect) *envoy_config_route_v3.RedirectAction { +func (srv *Server) buildPolicyRouteRedirectAction(r *config.PolicyRedirect) (*envoy_config_route_v3.RedirectAction, error) { action := &envoy_config_route_v3.RedirectAction{} switch { case r.HTTPSRedirect != nil: @@ -268,10 +334,10 @@ func buildPolicyRouteRedirectAction(r *config.PolicyRedirect) *envoy_config_rout if r.StripQuery != nil { action.StripQuery = *r.StripQuery } - return action + return action, nil } -func buildPolicyRouteRouteAction(options *config.Options, policy *config.Policy) *envoy_config_route_v3.RouteAction { +func (srv *Server) buildPolicyRouteRouteAction(options *config.Options, policy *config.Policy) (*envoy_config_route_v3.RouteAction, error) { clusterName := getPolicyName(policy) routeTimeout := getRouteTimeout(options, policy) idleTimeout := getRouteIdleTimeout(policy) @@ -307,7 +373,7 @@ func buildPolicyRouteRouteAction(options *config.Options, policy *config.Policy) RegexRewrite: regexRewrite, } setHostRewriteOptions(policy, action) - return action + return action, nil } func mkEnvoyHeader(k, v string) *envoy_config_core_v3.HeaderValueOption { diff --git a/internal/controlplane/xds_routes_test.go b/internal/controlplane/xds_routes_test.go index 977a63855..8aa3258c8 100644 --- a/internal/controlplane/xds_routes_test.go +++ b/internal/controlplane/xds_routes_test.go @@ -7,9 +7,11 @@ import ( envoy_config_route_v3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/controlplane/filemgr" "github.com/pomerium/pomerium/internal/testutil" ) @@ -22,7 +24,9 @@ func policyNameFunc() func(*config.Policy) string { } func Test_buildGRPCRoutes(t *testing.T) { - routes := buildGRPCRoutes() + srv := &Server{filemgr: filemgr.NewManager()} + routes, err := srv.buildGRPCRoutes() + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` [ { @@ -46,6 +50,7 @@ func Test_buildGRPCRoutes(t *testing.T) { } func Test_buildPomeriumHTTPRoutes(t *testing.T) { + srv := &Server{filemgr: filemgr.NewManager()} routeString := func(typ, name string, protected bool) string { str := `{ "name": "pomerium-` + typ + `-` + name + `", @@ -76,7 +81,8 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) { AuthenticateCallbackPath: "/oauth2/callback", ForwardAuthURL: mustParseURL("https://forward-auth.example.com"), } - routes := buildPomeriumHTTPRoutes(options, "authenticate.example.com") + routes, err := srv.buildPomeriumHTTPRoutes(options, "authenticate.example.com") + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, `[ `+routeString("path", "/.pomerium/jwt", true)+`, @@ -105,7 +111,8 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) { }}, } _ = options.Policies[0].Validate() - routes := buildPomeriumHTTPRoutes(options, "from.example.com") + routes, err := srv.buildPomeriumHTTPRoutes(options, "from.example.com") + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, `[ `+routeString("path", "/.pomerium/jwt", true)+`, @@ -134,7 +141,8 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) { }}, } _ = options.Policies[0].Validate() - routes := buildPomeriumHTTPRoutes(options, "from.example.com") + routes, err := srv.buildPomeriumHTTPRoutes(options, "from.example.com") + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, `[ `+routeString("path", "/.pomerium/jwt", true)+`, @@ -151,7 +159,9 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) { } func Test_buildControlPlanePathRoute(t *testing.T) { - route := buildControlPlanePathRoute("/hello/world", false) + srv := &Server{filemgr: filemgr.NewManager()} + route, err := srv.buildControlPlanePathRoute("/hello/world", false) + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` { "name": "pomerium-path-/hello/world", @@ -172,7 +182,9 @@ func Test_buildControlPlanePathRoute(t *testing.T) { } func Test_buildControlPlanePrefixRoute(t *testing.T) { - route := buildControlPlanePrefixRoute("/hello/world/", false) + srv := &Server{filemgr: filemgr.NewManager()} + route, err := srv.buildControlPlanePrefixRoute("/hello/world/", false) + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` { "name": "pomerium-prefix-/hello/world/", @@ -197,7 +209,9 @@ func Test_buildPolicyRoutes(t *testing.T) { getPolicyName = f }(getPolicyName) getPolicyName = policyNameFunc() - routes := buildPolicyRoutes(&config.Options{ + + srv := &Server{filemgr: filemgr.NewManager()} + routes, err := srv.buildPolicyRoutes(&config.Options{ CookieName: "pomerium", DefaultUpstreamTimeout: time.Second * 3, Policies: []config.Policy{ @@ -260,6 +274,7 @@ func Test_buildPolicyRoutes(t *testing.T) { }, }, }, "example.com") + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` [ @@ -473,7 +488,7 @@ func Test_buildPolicyRoutes(t *testing.T) { `, routes) t.Run("tcp", func(t *testing.T) { - routes = buildPolicyRoutes(&config.Options{ + routes, err := srv.buildPolicyRoutes(&config.Options{ CookieName: "pomerium", DefaultUpstreamTimeout: time.Second * 3, Policies: []config.Policy{ @@ -488,6 +503,7 @@ func Test_buildPolicyRoutes(t *testing.T) { }, }, }, "example.com:22") + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` [ @@ -555,7 +571,8 @@ func TestAddOptionsHeadersToResponse(t *testing.T) { getPolicyName = f }(getPolicyName) getPolicyName = policyNameFunc() - routes := buildPolicyRoutes(&config.Options{ + srv := &Server{filemgr: filemgr.NewManager()} + routes, err := srv.buildPolicyRoutes(&config.Options{ CookieName: "pomerium", DefaultUpstreamTimeout: time.Second * 3, Policies: []config.Policy{ @@ -566,6 +583,7 @@ func TestAddOptionsHeadersToResponse(t *testing.T) { }, Headers: map[string]string{"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload"}, }, "example.com") + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` [ @@ -609,7 +627,8 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) { getPolicyName = f }(getPolicyName) getPolicyName = policyNameFunc() - routes := buildPolicyRoutes(&config.Options{ + srv := &Server{filemgr: filemgr.NewManager()} + routes, err := srv.buildPolicyRoutes(&config.Options{ CookieName: "pomerium", DefaultUpstreamTimeout: time.Second * 3, Policies: []config.Policy{ @@ -652,6 +671,7 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) { }, }, }, "example.com") + require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` [ @@ -822,19 +842,22 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) { } func Test_buildPolicyRouteRedirectAction(t *testing.T) { + srv := &Server{filemgr: filemgr.NewManager()} t.Run("HTTPSRedirect", func(t *testing.T) { - action := buildPolicyRouteRedirectAction(&config.PolicyRedirect{ + action, err := srv.buildPolicyRouteRedirectAction(&config.PolicyRedirect{ HTTPSRedirect: proto.Bool(true), }) + require.NoError(t, err) assert.Equal(t, &envoy_config_route_v3.RedirectAction{ SchemeRewriteSpecifier: &envoy_config_route_v3.RedirectAction_HttpsRedirect{ HttpsRedirect: true, }, }, action) - action = buildPolicyRouteRedirectAction(&config.PolicyRedirect{ + action, err = srv.buildPolicyRouteRedirectAction(&config.PolicyRedirect{ HTTPSRedirect: proto.Bool(false), }) + require.NoError(t, err) assert.Equal(t, &envoy_config_route_v3.RedirectAction{ SchemeRewriteSpecifier: &envoy_config_route_v3.RedirectAction_HttpsRedirect{ HttpsRedirect: false, @@ -842,9 +865,10 @@ func Test_buildPolicyRouteRedirectAction(t *testing.T) { }, action) }) t.Run("SchemeRedirect", func(t *testing.T) { - action := buildPolicyRouteRedirectAction(&config.PolicyRedirect{ + action, err := srv.buildPolicyRouteRedirectAction(&config.PolicyRedirect{ SchemeRedirect: proto.String("https"), }) + require.NoError(t, err) assert.Equal(t, &envoy_config_route_v3.RedirectAction{ SchemeRewriteSpecifier: &envoy_config_route_v3.RedirectAction_SchemeRedirect{ SchemeRedirect: "https", @@ -852,25 +876,28 @@ func Test_buildPolicyRouteRedirectAction(t *testing.T) { }, action) }) t.Run("HostRedirect", func(t *testing.T) { - action := buildPolicyRouteRedirectAction(&config.PolicyRedirect{ + action, err := srv.buildPolicyRouteRedirectAction(&config.PolicyRedirect{ HostRedirect: proto.String("HOST"), }) + require.NoError(t, err) assert.Equal(t, &envoy_config_route_v3.RedirectAction{ HostRedirect: "HOST", }, action) }) t.Run("PortRedirect", func(t *testing.T) { - action := buildPolicyRouteRedirectAction(&config.PolicyRedirect{ + action, err := srv.buildPolicyRouteRedirectAction(&config.PolicyRedirect{ PortRedirect: proto.Uint32(1234), }) + require.NoError(t, err) assert.Equal(t, &envoy_config_route_v3.RedirectAction{ PortRedirect: 1234, }, action) }) t.Run("PathRedirect", func(t *testing.T) { - action := buildPolicyRouteRedirectAction(&config.PolicyRedirect{ + action, err := srv.buildPolicyRouteRedirectAction(&config.PolicyRedirect{ PathRedirect: proto.String("PATH"), }) + require.NoError(t, err) assert.Equal(t, &envoy_config_route_v3.RedirectAction{ PathRewriteSpecifier: &envoy_config_route_v3.RedirectAction_PathRedirect{ PathRedirect: "PATH", @@ -878,9 +905,10 @@ func Test_buildPolicyRouteRedirectAction(t *testing.T) { }, action) }) t.Run("PrefixRewrite", func(t *testing.T) { - action := buildPolicyRouteRedirectAction(&config.PolicyRedirect{ + action, err := srv.buildPolicyRouteRedirectAction(&config.PolicyRedirect{ PrefixRewrite: proto.String("PREFIX_REWRITE"), }) + require.NoError(t, err) assert.Equal(t, &envoy_config_route_v3.RedirectAction{ PathRewriteSpecifier: &envoy_config_route_v3.RedirectAction_PrefixRewrite{ PrefixRewrite: "PREFIX_REWRITE", @@ -888,17 +916,19 @@ func Test_buildPolicyRouteRedirectAction(t *testing.T) { }, action) }) t.Run("ResponseCode", func(t *testing.T) { - action := buildPolicyRouteRedirectAction(&config.PolicyRedirect{ + action, err := srv.buildPolicyRouteRedirectAction(&config.PolicyRedirect{ ResponseCode: proto.Int32(301), }) + require.NoError(t, err) assert.Equal(t, &envoy_config_route_v3.RedirectAction{ ResponseCode: 301, }, action) }) t.Run("StripQuery", func(t *testing.T) { - action := buildPolicyRouteRedirectAction(&config.PolicyRedirect{ + action, err := srv.buildPolicyRouteRedirectAction(&config.PolicyRedirect{ StripQuery: proto.Bool(true), }) + require.NoError(t, err) assert.Equal(t, &envoy_config_route_v3.RedirectAction{ StripQuery: true, }, action)