diff --git a/internal/controlplane/xds.go b/internal/controlplane/xds.go index d88b4c934..6ed173bc1 100644 --- a/internal/controlplane/xds.go +++ b/internal/controlplane/xds.go @@ -30,7 +30,7 @@ import ( func (srv *Server) buildDiscoveryResponse(version string, typeURL string, options *config.Options) (*envoy_service_discovery_v3.DiscoveryResponse, error) { switch typeURL { case "type.googleapis.com/envoy.config.listener.v3.Listener": - listeners := srv.buildListeners(options) + listeners := buildListeners(options) anys := make([]*any.Any, len(listeners)) for i, listener := range listeners { a, err := ptypes.MarshalAny(listener) @@ -64,7 +64,7 @@ func (srv *Server) buildDiscoveryResponse(version string, typeURL string, option } } -func (srv *Server) buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.AccessLog { +func buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.AccessLog { lvl := options.ProxyLogLevel if lvl == "" { lvl = options.LogLevel @@ -130,7 +130,7 @@ func inlineBytes(bs []byte) *envoy_config_core_v3.DataSource { func inlineBytesAsFilename(name string, bs []byte) *envoy_config_core_v3.DataSource { ext := filepath.Ext(name) - name = fmt.Sprintf("%s-%x%s", name[:len(ext)], xxhash.Sum64(bs), ext) + name = fmt.Sprintf("%s-%x%s", name[:len(name)-len(ext)], xxhash.Sum64(bs), ext) cacheDir, err := os.UserCacheDir() if err != nil { diff --git a/internal/controlplane/xds_cluster_test.go b/internal/controlplane/xds_cluster_test.go new file mode 100644 index 000000000..6361b0e4a --- /dev/null +++ b/internal/controlplane/xds_cluster_test.go @@ -0,0 +1,274 @@ +package controlplane + +import ( + "encoding/base64" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/cryptutil" + "github.com/pomerium/pomerium/internal/testutil" +) + +func Test_buildPolicyTransportSocket(t *testing.T) { + rootCA, _ := getRootCertificateAuthority() + cacheDir, _ := os.UserCacheDir() + t.Run("insecure", func(t *testing.T) { + assert.Nil(t, buildPolicyTransportSocket(&config.Policy{ + Destination: mustParseURL("http://example.com"), + })) + }) + t.Run("host as sni", func(t *testing.T) { + testutil.AssertProtoJSONEqual(t, ` + { + "name": "tls", + "typedConfig": { + "@type": "type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext", + "commonTlsContext": { + "alpnProtocols": ["http/1.1"], + "validationContext": { + "matchSubjectAltNames": [{ + "exact": "example.com" + }], + "trustedCa": { + "filename": "`+rootCA+`" + } + } + }, + "sni": "example.com" + } + } + `, buildPolicyTransportSocket(&config.Policy{ + Destination: mustParseURL("https://example.com"), + })) + }) + t.Run("tls_server_name as sni", func(t *testing.T) { + testutil.AssertProtoJSONEqual(t, ` + { + "name": "tls", + "typedConfig": { + "@type": "type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext", + "commonTlsContext": { + "alpnProtocols": ["http/1.1"], + "validationContext": { + "matchSubjectAltNames": [{ + "exact": "use-this-name.example.com" + }], + "trustedCa": { + "filename": "`+rootCA+`" + } + } + }, + "sni": "use-this-name.example.com" + } + } + `, buildPolicyTransportSocket(&config.Policy{ + Destination: mustParseURL("https://example.com"), + TLSServerName: "use-this-name.example.com", + })) + }) + t.Run("tls_skip_verify", func(t *testing.T) { + testutil.AssertProtoJSONEqual(t, ` + { + "name": "tls", + "typedConfig": { + "@type": "type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext", + "commonTlsContext": { + "alpnProtocols": ["http/1.1"], + "validationContext": { + "matchSubjectAltNames": [{ + "exact": "example.com" + }], + "trustedCa": { + "filename": "`+rootCA+`" + }, + "trustChainVerification": "ACCEPT_UNTRUSTED" + } + }, + "sni": "example.com" + } + } + `, buildPolicyTransportSocket(&config.Policy{ + Destination: mustParseURL("https://example.com"), + TLSSkipVerify: true, + })) + }) + t.Run("custom ca", func(t *testing.T) { + testutil.AssertProtoJSONEqual(t, ` + { + "name": "tls", + "typedConfig": { + "@type": "type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext", + "commonTlsContext": { + "alpnProtocols": ["http/1.1"], + "validationContext": { + "matchSubjectAltNames": [{ + "exact": "example.com" + }], + "trustedCa": { + "filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-3aefa6fd5cf2deb4.pem")+`" + } + } + }, + "sni": "example.com" + } + } + `, buildPolicyTransportSocket(&config.Policy{ + Destination: mustParseURL("https://example.com"), + TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}), + })) + }) + t.Run("client certificate", func(t *testing.T) { + clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey) + testutil.AssertProtoJSONEqual(t, ` + { + "name": "tls", + "typedConfig": { + "@type": "type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext", + "commonTlsContext": { + "alpnProtocols": ["http/1.1"], + "tlsCertificates": [{ + "certificateChain":{ + "filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-921a8294d2e2ec54.pem")+`" + }, + "privateKey": { + "filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-d5cf35b1e8533e4a.pem")+`" + } + }], + "validationContext": { + "matchSubjectAltNames": [{ + "exact": "example.com" + }], + "trustedCa": { + "filename": "`+rootCA+`" + } + } + }, + "sni": "example.com" + } + } + `, buildPolicyTransportSocket(&config.Policy{ + Destination: mustParseURL("https://example.com"), + ClientCertificate: clientCert, + })) + }) +} + +func Test_buildCluster(t *testing.T) { + rootCA, _ := getRootCertificateAuthority() + t.Run("insecure", func(t *testing.T) { + cluster := buildCluster("example", mustParseURL("http://example.com"), nil, true) + testutil.AssertProtoJSONEqual(t, ` + { + "name": "example", + "type": "STRICT_DNS", + "connectTimeout": "10s", + "respectDnsTtl": true, + "http2ProtocolOptions": { + "allowConnect": true + }, + "loadAssignment": { + "clusterName": "example", + "endpoints": [{ + "lbEndpoints": [{ + "endpoint": { + "address": { + "socketAddress": { + "address": "example.com", + "ipv4Compat": true, + "portValue": 80 + } + } + } + }] + }] + } + } + `, cluster) + }) + t.Run("secure", func(t *testing.T) { + u := mustParseURL("https://example.com") + transportSocket := buildPolicyTransportSocket(&config.Policy{ + Destination: u, + }) + cluster := buildCluster("example", u, transportSocket, true) + testutil.AssertProtoJSONEqual(t, ` + { + "name": "example", + "type": "STRICT_DNS", + "connectTimeout": "10s", + "respectDnsTtl": true, + "transportSocket": { + "name": "tls", + "typedConfig": { + "@type": "type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext", + "commonTlsContext": { + "alpnProtocols": ["http/1.1"], + "validationContext": { + "matchSubjectAltNames": [{ + "exact": "example.com" + }], + "trustedCa": { + "filename": "`+rootCA+`" + } + } + }, + "sni": "example.com" + } + }, + "http2ProtocolOptions": { + "allowConnect": true + }, + "loadAssignment": { + "clusterName": "example", + "endpoints": [{ + "lbEndpoints": [{ + "endpoint": { + "address": { + "socketAddress": { + "address": "example.com", + "ipv4Compat": true, + "portValue": 443 + } + } + } + }] + }] + } + } + `, cluster) + }) + t.Run("ip address", func(t *testing.T) { + cluster := buildCluster("example", mustParseURL("http://127.0.0.1"), nil, true) + testutil.AssertProtoJSONEqual(t, ` + { + "name": "example", + "type": "STATIC", + "connectTimeout": "10s", + "respectDnsTtl": true, + "http2ProtocolOptions": { + "allowConnect": true + }, + "loadAssignment": { + "clusterName": "example", + "endpoints": [{ + "lbEndpoints": [{ + "endpoint": { + "address": { + "socketAddress": { + "address": "127.0.0.1", + "ipv4Compat": true, + "portValue": 80 + } + } + } + }] + }] + } + } + `, cluster) + }) +} diff --git a/internal/controlplane/xds_clusters.go b/internal/controlplane/xds_clusters.go index 32b0f8227..599606d22 100644 --- a/internal/controlplane/xds_clusters.go +++ b/internal/controlplane/xds_clusters.go @@ -33,77 +33,80 @@ func (srv *Server) buildClusters(options *config.Options) []*envoy_config_cluste } clusters := []*envoy_config_cluster_v3.Cluster{ - srv.buildInternalCluster(options, "pomerium-control-plane-grpc", grpcURL, true), - srv.buildInternalCluster(options, "pomerium-control-plane-http", httpURL, false), + buildInternalCluster(options, "pomerium-control-plane-grpc", grpcURL, true), + buildInternalCluster(options, "pomerium-control-plane-http", httpURL, false), } - clusters = append(clusters, srv.buildInternalCluster(options, "pomerium-authz", authzURL, true)) + clusters = append(clusters, buildInternalCluster(options, "pomerium-authz", authzURL, true)) if config.IsProxy(options.Services) { for _, policy := range options.Policies { - clusters = append(clusters, srv.buildPolicyCluster(&policy)) + clusters = append(clusters, buildPolicyCluster(&policy)) } } return clusters } -func (srv *Server) buildInternalCluster(options *config.Options, name string, endpoint *url.URL, forceHTTP2 bool) *envoy_config_cluster_v3.Cluster { - var transportSocket *envoy_config_core_v3.TransportSocket - if endpoint.Scheme == "https" { - sni := endpoint.Hostname() - if options.OverrideCertificateName != "" { - sni = options.OverrideCertificateName +func buildInternalCluster(options *config.Options, name string, endpoint *url.URL, forceHTTP2 bool) *envoy_config_cluster_v3.Cluster { + return buildCluster(name, endpoint, buildInternalTransportSocket(options, endpoint), forceHTTP2) +} + +func buildPolicyCluster(policy *config.Policy) *envoy_config_cluster_v3.Cluster { + name := getPolicyName(policy) + return buildCluster(name, policy.Destination, buildPolicyTransportSocket(policy), false) +} + +func buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *envoy_config_core_v3.TransportSocket { + if endpoint.Scheme != "https" { + return nil + } + sni := endpoint.Hostname() + if options.OverrideCertificateName != "" { + sni = options.OverrideCertificateName + } + validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{ + MatchSubjectAltNames: []*envoy_type_matcher_v3.StringMatcher{{ + MatchPattern: &envoy_type_matcher_v3.StringMatcher_Exact{ + Exact: sni, + }, + }}, + } + if options.CAFile != "" { + validationContext.TrustedCa = inlineFilename(options.CAFile) + } else if options.CA != "" { + bs, err := base64.StdEncoding.DecodeString(options.CA) + if err != nil { + log.Error().Err(err).Msg("invalid custom CA certificate") } - validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{ - MatchSubjectAltNames: []*envoy_type_matcher_v3.StringMatcher{{ - MatchPattern: &envoy_type_matcher_v3.StringMatcher_Exact{ - Exact: sni, - }, - }}, - } - if options.CAFile != "" { - validationContext.TrustedCa = inlineFilename(options.CAFile) - } else if options.CA != "" { - bs, err := base64.StdEncoding.DecodeString(options.CA) - if err != nil { - log.Error().Err(err).Msg("invalid custom CA certificate") - } - validationContext.TrustedCa = inlineBytesAsFilename("custom-ca.pem", bs) + validationContext.TrustedCa = inlineBytesAsFilename("custom-ca.pem", bs) + } else { + rootCA, err := getRootCertificateAuthority() + if err != nil { + log.Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found") } else { - rootCA, err := getRootCertificateAuthority() - if err != nil { - log.Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found") - } else { - validationContext.TrustedCa = inlineFilename(rootCA) - } - } - tlsContext := &envoy_extensions_transport_sockets_tls_v3.UpstreamTlsContext{ - CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{ - AlpnProtocols: []string{"h2", "http/1.1"}, - ValidationContextType: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext{ - ValidationContext: validationContext, - }, - }, - Sni: sni, - } - tlsConfig, _ := ptypes.MarshalAny(tlsContext) - transportSocket = &envoy_config_core_v3.TransportSocket{ - Name: "tls", - ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ - TypedConfig: tlsConfig, - }, + validationContext.TrustedCa = inlineFilename(rootCA) } } - return srv.buildCluster(name, endpoint, transportSocket, forceHTTP2) + tlsContext := &envoy_extensions_transport_sockets_tls_v3.UpstreamTlsContext{ + CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{ + AlpnProtocols: []string{"h2", "http/1.1"}, + ValidationContextType: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext{ + ValidationContext: validationContext, + }, + }, + Sni: sni, + } + tlsConfig, _ := ptypes.MarshalAny(tlsContext) + return &envoy_config_core_v3.TransportSocket{ + Name: "tls", + ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ + TypedConfig: tlsConfig, + }, + } } -func (srv *Server) buildPolicyCluster(policy *config.Policy) *envoy_config_cluster_v3.Cluster { - name := getPolicyName(policy) - return srv.buildCluster(name, policy.Destination, srv.buildPolicyTransportSocket(policy), false) -} - -func (srv *Server) buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.TransportSocket { +func buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.TransportSocket { if policy.Destination.Scheme != "https" { return nil } @@ -116,7 +119,7 @@ func (srv *Server) buildPolicyTransportSocket(policy *config.Policy) *envoy_conf CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{ AlpnProtocols: []string{"http/1.1"}, ValidationContextType: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext{ - ValidationContext: srv.buildPolicyValidationContext(policy), + ValidationContext: buildPolicyValidationContext(policy), }, }, Sni: sni, @@ -135,7 +138,7 @@ func (srv *Server) buildPolicyTransportSocket(policy *config.Policy) *envoy_conf } } -func (srv *Server) buildPolicyValidationContext(policy *config.Policy) *envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext { +func buildPolicyValidationContext(policy *config.Policy) *envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext { sni := policy.Destination.Hostname() if policy.TLSServerName != "" { sni = policy.TLSServerName @@ -171,7 +174,7 @@ func (srv *Server) buildPolicyValidationContext(policy *config.Policy) *envoy_ex return validationContext } -func (srv *Server) buildCluster( +func buildCluster( name string, endpoint *url.URL, transportSocket *envoy_config_core_v3.TransportSocket, diff --git a/internal/controlplane/xds_listeners.go b/internal/controlplane/xds_listeners.go index 97a6808c6..1bf06ecd1 100644 --- a/internal/controlplane/xds_listeners.go +++ b/internal/controlplane/xds_listeners.go @@ -33,24 +33,24 @@ func init() { }) } -func (srv *Server) buildListeners(options *config.Options) []*envoy_config_listener_v3.Listener { +func buildListeners(options *config.Options) []*envoy_config_listener_v3.Listener { var listeners []*envoy_config_listener_v3.Listener if config.IsAuthenticate(options.Services) || config.IsProxy(options.Services) { - listeners = append(listeners, srv.buildMainListener(options)) + listeners = append(listeners, buildMainListener(options)) } if config.IsAuthorize(options.Services) || config.IsCache(options.Services) { - listeners = append(listeners, srv.buildGRPCListener(options)) + listeners = append(listeners, buildGRPCListener(options)) } return listeners } -func (srv *Server) buildMainListener(options *config.Options) *envoy_config_listener_v3.Listener { +func buildMainListener(options *config.Options) *envoy_config_listener_v3.Listener { if options.InsecureServer { - filter := srv.buildMainHTTPConnectionManagerFilter(options, - srv.getAllRouteableDomains(options, options.Addr)) + filter := buildMainHTTPConnectionManagerFilter(options, + getAllRouteableDomains(options, options.Addr)) return &envoy_config_listener_v3.Listener{ Name: "http-ingress", @@ -73,9 +73,9 @@ func (srv *Server) buildMainListener(options *config.Options) *envoy_config_list TypedConfig: tlsInspectorCfg, }, }}, - FilterChains: srv.buildFilterChains(options, options.Addr, + FilterChains: buildFilterChains(options, options.Addr, func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain { - filter := srv.buildMainHTTPConnectionManagerFilter(options, httpDomains) + filter := buildMainHTTPConnectionManagerFilter(options, httpDomains) filterChain := &envoy_config_listener_v3.FilterChain{ Filters: []*envoy_config_listener_v3.Filter{filter}, } @@ -84,7 +84,7 @@ func (srv *Server) buildMainListener(options *config.Options) *envoy_config_list ServerNames: []string{tlsDomain}, } } - tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain) + tlsContext := buildDownstreamTLSContext(options, tlsDomain) if tlsContext != nil { tlsConfig, _ := ptypes.MarshalAny(tlsContext) filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ @@ -100,11 +100,11 @@ func (srv *Server) buildMainListener(options *config.Options) *envoy_config_list return li } -func (srv *Server) buildFilterChains( +func buildFilterChains( options *config.Options, addr string, callback func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain, ) []*envoy_config_listener_v3.FilterChain { - allDomains := srv.getAllRouteableDomains(options, addr) + allDomains := getAllRouteableDomains(options, addr) var chains []*envoy_config_listener_v3.FilterChain for _, domain := range allDomains { // first we match on SNI @@ -115,7 +115,7 @@ func (srv *Server) buildFilterChains( return chains } -func (srv *Server) buildMainHTTPConnectionManagerFilter(options *config.Options, domains []string) *envoy_config_listener_v3.Filter { +func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []string) *envoy_config_listener_v3.Filter { var virtualHosts []*envoy_config_route_v3.VirtualHost for _, domain := range domains { vh := &envoy_config_route_v3.VirtualHost{ @@ -127,16 +127,16 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter(options *config.Options, // if this is a gRPC service domain and we're supposed to handle that, add those routes if (config.IsAuthorize(options.Services) && domain == options.AuthorizeURL.Host) || (config.IsCache(options.Services) && domain == options.CacheURL.Host) { - vh.Routes = append(vh.Routes, srv.buildGRPCRoutes()...) + vh.Routes = append(vh.Routes, buildGRPCRoutes()...) } } // these routes match /.pomerium/... and similar paths - vh.Routes = append(vh.Routes, srv.buildPomeriumHTTPRoutes(options, domain)...) + vh.Routes = append(vh.Routes, buildPomeriumHTTPRoutes(options, domain)...) // if we're the proxy, add all the policy routes if config.IsProxy(options.Services) { - vh.Routes = append(vh.Routes, srv.buildPolicyRoutes(options, domain)...) + vh.Routes = append(vh.Routes, buildPolicyRoutes(options, domain)...) } if len(vh.Routes) > 0 { @@ -212,7 +212,7 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter(options *config.Options, Name: "envoy.filters.http.router", }, }, - AccessLog: srv.buildAccessLogs(options), + AccessLog: buildAccessLogs(options), CommonHttpProtocolOptions: &envoy_config_core_v3.HttpProtocolOptions{ IdleTimeout: ptypes.DurationProto(options.IdleTimeout), MaxStreamDuration: maxStreamDuration, @@ -231,8 +231,8 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter(options *config.Options, } } -func (srv *Server) buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listener { - filter := srv.buildGRPCHTTPConnectionManagerFilter() +func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listener { + filter := buildGRPCHTTPConnectionManagerFilter() if options.GRPCInsecure { return &envoy_config_listener_v3.Listener{ @@ -256,7 +256,7 @@ func (srv *Server) buildGRPCListener(options *config.Options) *envoy_config_list TypedConfig: tlsInspectorCfg, }, }}, - FilterChains: srv.buildFilterChains(options, options.Addr, + FilterChains: buildFilterChains(options, options.Addr, func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain { filterChain := &envoy_config_listener_v3.FilterChain{ Filters: []*envoy_config_listener_v3.Filter{filter}, @@ -266,7 +266,7 @@ func (srv *Server) buildGRPCListener(options *config.Options) *envoy_config_list ServerNames: []string{tlsDomain}, } } - tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain) + tlsContext := buildDownstreamTLSContext(options, tlsDomain) if tlsContext != nil { tlsConfig, _ := ptypes.MarshalAny(tlsContext) filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ @@ -282,7 +282,7 @@ func (srv *Server) buildGRPCListener(options *config.Options) *envoy_config_list return li } -func (srv *Server) buildGRPCHTTPConnectionManagerFilter() *envoy_config_listener_v3.Filter { +func buildGRPCHTTPConnectionManagerFilter() *envoy_config_listener_v3.Filter { tc, _ := ptypes.MarshalAny(&envoy_http_connection_manager.HttpConnectionManager{ CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO, StatPrefix: "grpc_ingress", @@ -321,7 +321,7 @@ func (srv *Server) buildGRPCHTTPConnectionManagerFilter() *envoy_config_listener } } -func (srv *Server) buildDownstreamTLSContext(options *config.Options, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { +func buildDownstreamTLSContext(options *config.Options, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { cert, err := cryptutil.GetCertificateForDomain(options.Certificates, domain) if err != nil { log.Warn().Str("domain", domain).Err(err).Msg("failed to get certificate for domain") @@ -354,7 +354,7 @@ func (srv *Server) buildDownstreamTLSContext(options *config.Options, domain str } } -func (srv *Server) getAllRouteableDomains(options *config.Options, addr string) []string { +func getAllRouteableDomains(options *config.Options, addr string) []string { lookup := map[string]struct{}{} if config.IsAuthenticate(options.Services) && addr == options.Addr { lookup[options.AuthenticateURL.Host] = struct{}{} diff --git a/internal/controlplane/xds_listeners_test.go b/internal/controlplane/xds_listeners_test.go new file mode 100644 index 000000000..c3d642a38 --- /dev/null +++ b/internal/controlplane/xds_listeners_test.go @@ -0,0 +1,87 @@ +package controlplane + +import ( + "crypto/tls" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/cryptutil" + "github.com/pomerium/pomerium/internal/testutil" +) + +const ( + aExampleComCert = `LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVQVENDQXFXZ0F3SUJBZ0lSQUlWMDhHSVFYTWRVT0NXV3FocXlGR3N3RFFZSktvWklodmNOQVFFTEJRQXcKY3pFZU1Cd0dBMVVFQ2hNVmJXdGpaWEowSUdSbGRtVnNiM0J0Wlc1MElFTkJNU1F3SWdZRFZRUUxEQnRqWVd4bApZa0J3YjNBdGIzTWdLRU5oYkdWaUlFUnZlSE5sZVNreEt6QXBCZ05WQkFNTUltMXJZMlZ5ZENCallXeGxZa0J3CmIzQXRiM01nS0VOaGJHVmlJRVJ2ZUhObGVTa3dIaGNOTVRrd05qQXhNREF3TURBd1doY05NekF3TlRJeU1qRXoKT0RRMFdqQlBNU2N3SlFZRFZRUUtFeDV0YTJObGNuUWdaR1YyWld4dmNHMWxiblFnWTJWeWRHbG1hV05oZEdVeApKREFpQmdOVkJBc01HMk5oYkdWaVFIQnZjQzF2Y3lBb1EyRnNaV0lnUkc5NGMyVjVLVENDQVNJd0RRWUpLb1pJCmh2Y05BUUVCQlFBRGdnRVBBRENDQVFvQ2dnRUJBTm1HMWFKaXc0L29SMHFqUDMxUjRXeTZkOUVqZHc5K1kyelQKcjBDbGNYTDYxRk11R0YrKzJRclV6Y0VUZlZ2dGM1OXNQa0xkRHNtZ0Y2VlZCOTkyQ3ArWDlicWczWmQwSXZtbApVbjJvdTM5eUNEYnV2Q0E2d1gwbGNHL2JkRDE3TkRrS0poL3g5SDMzU3h4SG5UamlKdFBhbmt1MUI3ajdtRmM5Ck5jNXRyamFvUHBGaFJqMTJ1L0dWajRhWWs3SStpWHRpZHBjZXp2eWNDT0NtQlIwNHkzeWx5Q2sxSWNMTUhWOEEKNXphUFpVck15ZUtnTE1PTGlDSDBPeHhhUzh0Nk5vTjZudDdmOUp1TUxTN2V5SkxkQW05bGg0c092YXBPVklXZgpJQitaYnk5bkQ1dWl4N3V0a3llWTFOeE05SFZhUmZTQzcrejM4TDBWN3lJZlpCNkFLcWNDQXdFQUFhTndNRzR3CkRnWURWUjBQQVFIL0JBUURBZ1dnTUJNR0ExVWRKUVFNTUFvR0NDc0dBUVVGQndNQk1Bd0dBMVVkRXdFQi93UUMKTUFBd0h3WURWUjBqQkJnd0ZvQVVTaG9mWE5rY1hoMnE0d25uV1oyYmNvMjRYRVF3R0FZRFZSMFJCQkV3RDRJTgpZUzVsZUdGdGNHeGxMbU52YlRBTkJna3Foa2lHOXcwQkFRc0ZBQU9DQVlFQVA3aHVraThGeG54azRoVnJYUk93Ck51Uy9OUFhmQ3VaVDZWemJYUVUxbWNrZmhweVNDajVRZkFDQzdodVp6Qkp0NEtsUHViWHdRQ25YMFRMSmg1L0cKUzZBWEFXQ3VTSW5jTTZxNGs4MFAzVllWK3hXOS9rdERnTk1FTlNxSjdKR3lqdzBWWHlhOUZwdWd6Q3ZnN290RQo5STcrZTN0cmJnUDBHY3plSml6WTJBMVBWU082MVdKQ1lNQjNDLzcwVE9KMkZTNy82bURPTG9DSVJCY215cW5KClY2Vk5sRDl3Y2xmUWIrZUp0YlY0Vlg2RUY5UEYybUtncUNKT0FKLzBoMHAydTBhZGgzMkJDS2dIMDRSYUtuSS8KUzY1N0MrN1YzVEgzQ1VIVHgrdDRRRll4UEhRL0loQ3pYdUpVeFQzYWtYNEQ1czJkTHp2RnBJMFIzTVBwUE9VQQpUelpSdDI2T3FVNHlUdUFnb0kvZnZMdk55VTNZekF3ZUQ2Mndxc1hiVHAranNFcWpoODUvakpXWnA4RExKK0w3CmhXQW0rSVNKTzhrNWgwR0lIMFllb01heXBJbjRubWVsbHNSM1dvYzZRVTZ4cFFTd3V1NXE0ckJzOUxDWS9kZkwKNkEzMEhlYXVVK2sydGFUVlBMY2FCZm11NDJPaHMyYzQ0bzNPYnlvVkNDNi8KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=` + aExampleComKey = `LS0tLS1CRUdJTiBQUklWQVRFIEtFWS0tLS0tCk1JSUV2Z0lCQURBTkJna3Foa2lHOXcwQkFRRUZBQVNDQktnd2dnU2tBZ0VBQW9JQkFRRFpodFdpWXNPUDZFZEsKb3o5OVVlRnN1bmZSSTNjUGZtTnMwNjlBcFhGeSt0UlRMaGhmdnRrSzFNM0JFMzFiN1hPZmJENUMzUTdKb0JlbApWUWZmZGdxZmwvVzZvTjJYZENMNXBWSjlxTHQvY2dnMjdyd2dPc0Y5SlhCdjIzUTllelE1Q2lZZjhmUjk5MHNjClI1MDQ0aWJUMnA1THRRZTQrNWhYUFRYT2JhNDJxRDZSWVVZOWRydnhsWStHbUpPeVBvbDdZbmFYSHM3OG5BamcKcGdVZE9NdDhwY2dwTlNIQ3pCMWZBT2MyajJWS3pNbmlvQ3pEaTRnaDlEc2NXa3ZMZWphRGVwN2UzL1NiakMwdQozc2lTM1FKdlpZZUxEcjJxVGxTRm55QWZtVzh2WncrYm9zZTdyWk1ubU5UY1RQUjFXa1gwZ3UvczkvQzlGZThpCkgyUWVnQ3FuQWdNQkFBRUNnZ0VCQUsrclFrLzNyck5EQkgvMFFrdTBtbll5U0p6dkpUR3dBaDlhL01jYVZQcGsKTXFCU000RHZJVnlyNnRZb0pTN2VIbWY3QkhUL0RQZ3JmNjBYZEZvMGUvUFN4ckhIUSswUjcwVHBEQ3RLM3REWAppR2JFZWMwVlpqam95VnFzUWIxOUIvbWdocFY1MHRiL3BQcmJvczdUWkVQbTQ3dUVJUTUwc055VEpDYm5VSy8xCnhla2ZmZ3hMbmZlRUxoaXhDNE1XYjMzWG9GNU5VdWduQ2pUakthUFNNUmpISm9YSFlGWjdZdEdlSEd1aDR2UGwKOU5TM0YxT2l0MWNnQzNCSm1BM28yZmhYbTRGR1FhQzNjYUdXTzE5eHAwRWE1eXQ0RHZOTWp5WlgvSkx1Qko0NQpsZU5jUSs3c3U0dW0vY0hqcFFVenlvZmoydFBIU085QXczWGY0L2lmN0hFQ2dZRUE1SWMzMzVKUUhJVlQwc003CnhkY3haYmppbUE5alBWMDFXSXh0di8zbzFJWm5TUGFocEFuYXVwZGZqRkhKZmJTYlZXaUJTaUZpb2RTR3pIdDgKTlZNTGFyVzVreDl5N1luYXdnZjJuQjc2VG03aFl6L3h5T3AxNXFRbmswVW9DdnQ2MHp6dDl5UE5KQ1pWalFwNgp4cUw4T1c4emNlUGpxZzJBTHRtcVhpNitZRXNDZ1lFQTg2ME5zSHMzNktFZE91Q1o1TXF6NVRLSmVYSzQ5ZkdBCjdxcjM5Sm9RcWYzbEhSSWozUlFlNERkWmQ5NUFXcFRKUEJXdnp6NVROOWdwNHVnb3VGc0tCaG82YWtsUEZTUFIKRkZwWCtGZE56eHJGTlAwZHhydmN0bXU2OW91MFR0QU1jd1hYWFJuR1BuK0xDTnVUUHZndHZTTnRwSEZMb0dzUQorVDFpTjhpWS9aVUNnWUJpMVJQVjdkb1ZxNWVuNCtWYTE0azJlL0lMWDBSRkNxV0NpU0VCMGxhNmF2SUtQUmVFCjhQb1dqbGExUWIzSlRxMkxEMm95M0NOaTU1M3dtMHNKYU1QY1A0RmxYa2wrNzRxYk5ZUnkybmJZS3QzdzVYdTAKcjZtVHVOU2d2VnptK3dHUWo1NCtyczRPWDBIS2dJaStsVWhOc29qbUxXK05ZTTlaODZyWmxvK2c1d0tCZ0VMQQplRXlOSko2c2JCWng2cFo3Vk5hSGhwTm5jdldreDc0WnhiMFM2MWUxL3FwOUNxZ0lXQUR5Q0tkR2tmaCtZN1g2Cjl1TmQzbXdnNGpDUGlvQWVLRnZObVl6K01oVEhjQUlVVVo3dFE1cGxhZnAvRUVZZHRuT2VoV1ArbDFFenV3VlQKWjFEUXU3YnBONHdnb25DUWllOFRJbmoydEZIb29vaTBZUkNJK2lnVkFvR0JBSUxaOXd4WDlnMmVNYU9xUFk1dgo5RGxxNFVEZlpaYkprNFZPbmhjR0pWQUNXbmlpNTU0Y1RCSEkxUTdBT0ZQOHRqK3d3YWJBOWRMaUpDdzJzd0E2ClQrdnhiK1NySGxEUnFON3NNRUQ1Z091REo0eHJxRVdLZ3ZkSEsvME9EMC9ZMUFvSCt2aDlJMHVaV0RRNnNLcXcKeFcrbDk0UTZXSW1xYnpDODZsa3JXa0lCCi0tLS0tRU5EIFBSSVZBVEUgS0VZLS0tLS0K` +) + +func Test_buildDownstreamTLSContext(t *testing.T) { + certA, err := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey) + if !assert.NoError(t, err) { + return + } + + downstreamTLSContext := buildDownstreamTLSContext(&config.Options{ + Certificates: []tls.Certificate{*certA}, + }, "a.example.com") + + cacheDir, _ := os.UserCacheDir() + certFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-921a8294d2e2ec54.pem") + keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-d5cf35b1e8533e4a.pem") + + testutil.AssertProtoJSONEqual(t, `{ + "commonTlsContext": { + "alpnProtocols": ["h2", "http/1.1"], + "tlsCertificates": [ + { + "certificateChain": { + "filename": "`+certFileName+`" + }, + "privateKey": { + "filename": "`+keyFileName+`" + } + } + ], + "validationContext": { + "trustChainVerification": "ACCEPT_UNTRUSTED" + } + } + }`, downstreamTLSContext) +} + +func Test_getAllRouteableDomains(t *testing.T) { + options := &config.Options{ + Addr: "127.0.0.1:9000", + GRPCAddr: "127.0.0.1:9001", + Services: "all", + AuthenticateURL: mustParseURL("https://authenticate.example.com"), + AuthorizeURL: mustParseURL("https://authorize.example.com:9001"), + CacheURL: mustParseURL("https://cache.example.com:9001"), + Policies: []config.Policy{ + {Source: &config.StringURL{URL: mustParseURL("https://a.example.com")}}, + {Source: &config.StringURL{URL: mustParseURL("https://b.example.com")}}, + {Source: &config.StringURL{URL: mustParseURL("https://c.example.com")}}, + }, + } + t.Run("http", func(t *testing.T) { + actual := getAllRouteableDomains(options, "127.0.0.1:9000") + expect := []string{ + "a.example.com", + "authenticate.example.com", + "b.example.com", + "c.example.com", + } + assert.Equal(t, expect, actual) + }) + t.Run("grpc", func(t *testing.T) { + actual := getAllRouteableDomains(options, "127.0.0.1:9001") + expect := []string{ + "authorize.example.com:9001", + "cache.example.com:9001", + } + assert.Equal(t, expect, actual) + }) +} diff --git a/internal/controlplane/xds_routes.go b/internal/controlplane/xds_routes.go index cf29a3166..7920ad995 100644 --- a/internal/controlplane/xds_routes.go +++ b/internal/controlplane/xds_routes.go @@ -15,7 +15,7 @@ import ( "github.com/pomerium/pomerium/config" ) -func (srv *Server) buildGRPCRoutes() []*envoy_config_route_v3.Route { +func buildGRPCRoutes() []*envoy_config_route_v3.Route { action := &envoy_config_route_v3.Route_Route{ Route: &envoy_config_route_v3.RouteAction{ ClusterSpecifier: &envoy_config_route_v3.RouteAction_Cluster{ @@ -38,29 +38,27 @@ func (srv *Server) buildGRPCRoutes() []*envoy_config_route_v3.Route { }} } -func (srv *Server) buildPomeriumHTTPRoutes(options *config.Options, domain string) []*envoy_config_route_v3.Route { +func buildPomeriumHTTPRoutes(options *config.Options, domain string) []*envoy_config_route_v3.Route { routes := []*envoy_config_route_v3.Route{ - srv.buildControlPlanePathRoute("/ping"), - srv.buildControlPlanePathRoute("/healthz"), - srv.buildControlPlanePathRoute("/.pomerium"), - srv.buildControlPlanePrefixRoute("/.pomerium/"), - srv.buildControlPlanePathRoute("/.well-known/pomerium"), - srv.buildControlPlanePrefixRoute("/.well-known/pomerium/"), + buildControlPlanePathRoute("/ping"), + buildControlPlanePathRoute("/healthz"), + buildControlPlanePathRoute("/.pomerium"), + buildControlPlanePrefixRoute("/.pomerium/"), + buildControlPlanePathRoute("/.well-known/pomerium"), + buildControlPlanePrefixRoute("/.well-known/pomerium/"), } // if we're handling authentication, add the oauth2 callback url if config.IsAuthenticate(options.Services) && domain == options.AuthenticateURL.Host { - routes = append(routes, - srv.buildControlPlanePathRoute(options.AuthenticateCallbackPath)) + routes = append(routes, buildControlPlanePathRoute(options.AuthenticateCallbackPath)) } // if we're the proxy and this is the forward-auth url if config.IsProxy(options.Services) && options.ForwardAuthURL != nil && domain == options.ForwardAuthURL.Host { - routes = append(routes, - srv.buildControlPlanePrefixRoute("/")) + routes = append(routes, buildControlPlanePrefixRoute("/")) } return routes } -func (srv *Server) buildControlPlanePathRoute(path string) *envoy_config_route_v3.Route { +func buildControlPlanePathRoute(path string) *envoy_config_route_v3.Route { return &envoy_config_route_v3.Route{ Name: "pomerium-path-" + path, Match: &envoy_config_route_v3.RouteMatch{ @@ -79,7 +77,7 @@ func (srv *Server) buildControlPlanePathRoute(path string) *envoy_config_route_v } } -func (srv *Server) buildControlPlanePrefixRoute(prefix string) *envoy_config_route_v3.Route { +func buildControlPlanePrefixRoute(prefix string) *envoy_config_route_v3.Route { return &envoy_config_route_v3.Route{ Name: "pomerium-prefix-" + prefix, Match: &envoy_config_route_v3.RouteMatch{ @@ -98,7 +96,7 @@ func (srv *Server) buildControlPlanePrefixRoute(prefix string) *envoy_config_rou } } -func (srv *Server) buildPolicyRoutes(options *config.Options, domain string) []*envoy_config_route_v3.Route { +func buildPolicyRoutes(options *config.Options, domain string) []*envoy_config_route_v3.Route { var routes []*envoy_config_route_v3.Route for i, policy := range options.Policies { if policy.Source.Host != domain { diff --git a/internal/controlplane/xds_routes_test.go b/internal/controlplane/xds_routes_test.go new file mode 100644 index 000000000..c872205e6 --- /dev/null +++ b/internal/controlplane/xds_routes_test.go @@ -0,0 +1,340 @@ +package controlplane + +import ( + "net/url" + "testing" + "time" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/testutil" +) + +func Test_buildGRPCRoutes(t *testing.T) { + routes := buildGRPCRoutes() + testutil.AssertProtoJSONEqual(t, ` + [ + { + "name": "pomerium-grpc", + "match": { + "grpc": {}, + "prefix": "/" + }, + "route": { + "cluster": "pomerium-control-plane-grpc" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + } + ] + `, routes) +} + +func Test_buildPomeriumHTTPRoutes(t *testing.T) { + routes := buildPomeriumHTTPRoutes(&config.Options{ + Services: "all", + AuthenticateURL: mustParseURL("https://authenticate.example.com"), + AuthenticateCallbackPath: "/oauth2/callback", + ForwardAuthURL: mustParseURL("https://forward-auth.example.com"), + }, "authenticate.example.com") + + testutil.AssertProtoJSONEqual(t, ` + [ + { + "name": "pomerium-path-/ping", + "match": { + "path": "/ping" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + }, + { + "name": "pomerium-path-/healthz", + "match": { + "path": "/healthz" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + }, + { + "name": "pomerium-path-/.pomerium", + "match": { + "path": "/.pomerium" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + }, + { + "name": "pomerium-prefix-/.pomerium/", + "match": { + "prefix": "/.pomerium/" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + }, + { + "name": "pomerium-path-/.well-known/pomerium", + "match": { + "path": "/.well-known/pomerium" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + }, + { + "name": "pomerium-prefix-/.well-known/pomerium/", + "match": { + "prefix": "/.well-known/pomerium/" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + }, + { + "name": "pomerium-path-/oauth2/callback", + "match": { + "path": "/oauth2/callback" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + } + ] + `, routes) +} + +func Test_buildControlPlanePathRoute(t *testing.T) { + route := buildControlPlanePathRoute("/hello/world") + testutil.AssertProtoJSONEqual(t, ` + { + "name": "pomerium-path-/hello/world", + "match": { + "path": "/hello/world" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + } + `, route) +} + +func Test_buildControlPlanePrefixRoute(t *testing.T) { + route := buildControlPlanePrefixRoute("/hello/world/") + testutil.AssertProtoJSONEqual(t, ` + { + "name": "pomerium-prefix-/hello/world/", + "match": { + "prefix": "/hello/world/" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + } + `, route) +} + +func Test_buildPolicyRoutes(t *testing.T) { + routes := buildPolicyRoutes(&config.Options{ + CookieName: "pomerium", + DefaultUpstreamTimeout: time.Second * 3, + Policies: []config.Policy{ + { + Source: &config.StringURL{URL: mustParseURL("https://ignore.example.com")}, + }, + { + Source: &config.StringURL{URL: mustParseURL("https://example.com")}, + }, + { + Source: &config.StringURL{URL: mustParseURL("https://example.com")}, + Path: "/some/path", + AllowWebsockets: true, + PreserveHostHeader: true, + }, + { + Source: &config.StringURL{URL: mustParseURL("https://example.com")}, + Prefix: "/some/prefix/", + SetRequestHeaders: map[string]string{"HEADER-KEY": "HEADER-VALUE"}, + UpstreamTimeout: time.Minute, + }, + { + Source: &config.StringURL{URL: mustParseURL("https://example.com")}, + Regex: `^/[a]+$`, + }, + }, + }, "example.com") + testutil.AssertProtoJSONEqual(t, ` + [ + { + "name": "policy-1", + "match": { + "prefix": "/" + }, + "metadata": { + "filterMetadata": { + "envoy.filters.http.lua": { + "remove_pomerium_authorization": true, + "remove_pomerium_cookie": "pomerium" + } + } + }, + "route": { + "autoHostRewrite": true, + "cluster": "policy-d00072a199d7b614", + "timeout": "3s", + "upgradeConfigs": [{ + "enabled": false, + "upgradeType": "websocket" + }] + } + }, + { + "name": "policy-2", + "match": { + "path": "/some/path" + }, + "metadata": { + "filterMetadata": { + "envoy.filters.http.lua": { + "remove_pomerium_authorization": true, + "remove_pomerium_cookie": "pomerium" + } + } + }, + "route": { + "autoHostRewrite": false, + "cluster": "policy-907a31075a413547", + "timeout": "0s", + "upgradeConfigs": [{ + "enabled": true, + "upgradeType": "websocket" + }] + } + }, + { + "name": "policy-3", + "match": { + "prefix": "/some/prefix/" + }, + "metadata": { + "filterMetadata": { + "envoy.filters.http.lua": { + "remove_pomerium_authorization": true, + "remove_pomerium_cookie": "pomerium" + } + } + }, + "route": { + "autoHostRewrite": true, + "cluster": "policy-f05528f790686bc3", + "timeout": "60s", + "upgradeConfigs": [{ + "enabled": false, + "upgradeType": "websocket" + }] + }, + "requestHeadersToAdd": [{ + "append": false, + "header": { + "key": "HEADER-KEY", + "value": "HEADER-VALUE" + } + }] + }, + { + "name": "policy-4", + "match": { + "safeRegex": { + "googleRe2": {}, + "regex": "^/[a]+$" + } + }, + "metadata": { + "filterMetadata": { + "envoy.filters.http.lua": { + "remove_pomerium_authorization": true, + "remove_pomerium_cookie": "pomerium" + } + } + }, + "route": { + "autoHostRewrite": true, + "cluster": "policy-e5d3a05ff1f97659", + "timeout": "3s", + "upgradeConfigs": [{ + "enabled": false, + "upgradeType": "websocket" + }] + } + } + ] + `, routes) +} + +func mustParseURL(str string) *url.URL { + u, err := url.Parse(str) + if err != nil { + panic(err) + } + return u +} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 000000000..c62536109 --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,34 @@ +// Package testutil contains helper functions for unit tests. +package testutil + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/encoding/protojson" +) + +// AssertProtoJSONEqual asserts that a protobuf message matches the given JSON. The protoMsg can also be a slice +// of protobuf messages. +func AssertProtoJSONEqual(t *testing.T, expected string, protoMsg interface{}, msgAndArgs ...interface{}) bool { + protoMsgVal := reflect.ValueOf(protoMsg) + if protoMsgVal.Kind() == reflect.Slice { + var protoMsgs []json.RawMessage + for i := 0; i < protoMsgVal.Len(); i++ { + protoMsgs = append(protoMsgs, toProtoJSON(protoMsgVal.Index(i).Interface())) + } + bs, _ := json.Marshal(protoMsgs) + return assert.JSONEq(t, expected, string(bs), msgAndArgs...) + } + + return assert.JSONEq(t, expected, string(toProtoJSON(protoMsg)), msgAndArgs...) +} + +func toProtoJSON(protoMsg interface{}) json.RawMessage { + v2 := proto.MessageV2(protoMsg) + bs, _ := protojson.Marshal(v2) + return bs +}