diff --git a/integration/internal/cluster/cluster.go b/integration/internal/cluster/cluster.go index 5d0d0d28a..93d6d3632 100644 --- a/integration/internal/cluster/cluster.go +++ b/integration/internal/cluster/cluster.go @@ -24,15 +24,20 @@ func New(workingDir string) *Cluster { } } -// NewHTTPClient creates a new *http.Client, with a cookie jar, and a LocalRoundTripper -// which routes traffic to the nginx ingress controller. +// NewHTTPClient calls NewHTTPClientWithTransport with the default cluster transport. func (cluster *Cluster) NewHTTPClient() *http.Client { + return cluster.NewHTTPClientWithTransport(cluster.Transport) +} + +// NewHTTPClientWithTransport creates a new *http.Client, with a cookie jar, and a LocalRoundTripper +// which routes traffic to the nginx ingress controller. +func (cluster *Cluster) NewHTTPClientWithTransport(transport http.RoundTripper) *http.Client { jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) if err != nil { panic(err) } return &http.Client{ - Transport: &loggingRoundTripper{cluster.Transport}, + Transport: &loggingRoundTripper{transport}, CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, diff --git a/integration/internal/cluster/setup.go b/integration/internal/cluster/setup.go index dcf1e3dc0..3ffcac7cc 100644 --- a/integration/internal/cluster/setup.go +++ b/integration/internal/cluster/setup.go @@ -51,7 +51,7 @@ func (cluster *Cluster) Setup(ctx context.Context) error { return err } - hostport, err := cluster.getNodeHTTPSAddr(ctx) + hostport, err := cluster.GetNodePortAddr(ctx, "ingress-nginx", "ingress-nginx-nodeport") if err != nil { return err } @@ -68,11 +68,12 @@ func (cluster *Cluster) Setup(ctx context.Context) error { return nil } -func (cluster *Cluster) getNodeHTTPSAddr(ctx context.Context) (hostport string, err error) { +// GetNodePortAddr returns the node:port address for a NodePort kubernetes service. +func (cluster *Cluster) GetNodePortAddr(ctx context.Context, namespace, svcName string) (hostport string, err error) { var buf bytes.Buffer - args := []string{"get", "service", "--namespace", "ingress-nginx", "--output", "json", - "ingress-nginx-nodeport"} + args := []string{"get", "service", "--namespace", namespace, "--output", "json", + svcName} err = run(ctx, "kubectl", withArgs(args...), withStdout(&buf)) if err != nil { return "", fmt.Errorf("error getting service details with kubectl: %w", err) @@ -94,7 +95,7 @@ func (cluster *Cluster) getNodeHTTPSAddr(ctx context.Context) (hostport string, buf.Reset() - args = []string{"get", "pods", "--namespace", "ingress-nginx", "--output", "json"} + args = []string{"get", "pods", "--namespace", namespace, "--output", "json"} var sel []string for k, v := range svcResult.Spec.Selector { sel = append(sel, k+"="+v) diff --git a/integration/manifests/lib/pomerium.libsonnet b/integration/manifests/lib/pomerium.libsonnet index bf62034ca..af05629ba 100644 --- a/integration/manifests/lib/pomerium.libsonnet +++ b/integration/manifests/lib/pomerium.libsonnet @@ -178,7 +178,7 @@ local PomeriumDeployment = function(svc) { ip: '10.96.1.1', hostnames: [ 'openid.localhost.pomerium.io', - 'authenticate.localhost.pomerium.io' + 'authenticate.localhost.pomerium.io', ], }], initContainers: [{ @@ -269,6 +269,28 @@ local PomeriumService = function(svc) { }, }; +local PomeriumNodePortServce = function() { + apiVersion: 'v1', + kind: 'Service', + metadata: { + namespace: 'default', + name: 'pomerium-proxy-nodeport', + labels: { + app: 'pomerium-proxy', + 'app.kubernetes.io/part-of': 'pomerium', + }, + }, + spec: { + type: 'NodePort', + ports: [ + { name: 'https', port: 443, protocol: 'TCP', targetPort: 'https', nodePort: 31443 }, + ], + selector: { + app: 'pomerium-proxy', + }, + }, +}; + local PomeriumIngress = function() { local proxyHosts = [ 'forward-authenticate.localhost.pomerium.io', @@ -392,6 +414,7 @@ local PomeriumForwardAuthIngress = function() { PomeriumDeployment('cache'), PomeriumService('proxy'), PomeriumDeployment('proxy'), + PomeriumNodePortServce(), PomeriumIngress(), PomeriumForwardAuthIngress(), ], diff --git a/integration/policy_test.go b/integration/policy_test.go index 49412a490..be6463a5b 100644 --- a/integration/policy_test.go +++ b/integration/policy_test.go @@ -4,11 +4,13 @@ import ( "context" "crypto/tls" "encoding/json" + "net" "net/http" "testing" "time" "github.com/gorilla/websocket" + "github.com/pomerium/pomerium/integration/internal/netutil" "github.com/stretchr/testify/assert" ) @@ -180,3 +182,41 @@ func TestWebsocket(t *testing.T) { assert.NoError(t, err, "expected no error when reading json from websocket") }) } + +func TestSNIMismatch(t *testing.T) { + // Browsers will coalesce connections for the same IP address and TLS certificate + // even if the request was made to different domain names. We need to support this + // so this test makes a request with an incorrect TLS server name to make sure it + // gets routed properly + + ctx := mainCtx + ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) + defer clearTimeout() + + hostport, err := testcluster.GetNodePortAddr(ctx, "default", "pomerium-proxy-nodeport") + if err != nil { + t.Fatal(err) + } + + client := testcluster.NewHTTPClientWithTransport(&http.Transport{ + DialContext: netutil.NewLocalDialer((&net.Dialer{}), map[string]string{ + "443": hostport, + }).DialContext, + TLSClientConfig: &tls.Config{ + ServerName: "ws-echo.localhost.pomerium.io", + }, + }) + + req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/ping", nil) + if err != nil { + t.Fatal(err) + } + + res, err := client.Do(req) + if !assert.NoError(t, err, "unexpected http error") { + return + } + defer res.Body.Close() + + assert.Equal(t, http.StatusOK, res.StatusCode) +} diff --git a/internal/controlplane/xds_listeners.go b/internal/controlplane/xds_listeners.go index e7ab2b409..04eae8316 100644 --- a/internal/controlplane/xds_listeners.go +++ b/internal/controlplane/xds_listeners.go @@ -109,7 +109,7 @@ func (srv *Server) buildFilterChains( var chains []*envoy_config_listener_v3.FilterChain for _, domain := range allDomains { // first we match on SNI - chains = append(chains, callback(domain, []string{domain})) + chains = append(chains, callback(domain, allDomains)) } // if there are no SNI matches we match on HTTP host chains = append(chains, callback("*", allDomains))