package main import ( "context" "crypto/tls" "encoding/json" "fmt" "net" "net/http" "testing" "time" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/pomerium/pomerium/integration/internal/flows" "github.com/pomerium/pomerium/integration/internal/netutil" ) func TestCORS(t *testing.T) { ctx := mainCtx ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) defer clearTimeout() t.Run("enabled", func(t *testing.T) { client := testcluster.NewHTTPClient() req, err := http.NewRequestWithContext(ctx, "OPTIONS", "https://httpdetails.localhost.pomerium.io/cors-enabled", nil) if err != nil { t.Fatal(err) } req.Header.Set("Access-Control-Request-Method", "GET") req.Header.Set("Origin", "https://httpdetails.localhost.pomerium.io") res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() assert.Equal(t, http.StatusOK, res.StatusCode, "unexpected status code") }) t.Run("disabled", func(t *testing.T) { client := testcluster.NewHTTPClient() req, err := http.NewRequestWithContext(ctx, "OPTIONS", "https://httpdetails.localhost.pomerium.io/cors-disabled", nil) if err != nil { t.Fatal(err) } req.Header.Set("Access-Control-Request-Method", "GET") req.Header.Set("Origin", "https://httpdetails.localhost.pomerium.io") res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() assert.NotEqual(t, http.StatusOK, res.StatusCode, "unexpected status code") }) } func TestPreserveHostHeader(t *testing.T) { ctx := mainCtx ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) defer clearTimeout() t.Run("enabled", func(t *testing.T) { client := testcluster.NewHTTPClient() req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/preserve-host-header-enabled", nil) if err != nil { t.Fatal(err) } res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() var result struct { Host string `json:"host"` } err = json.NewDecoder(res.Body).Decode(&result) if !assert.NoError(t, err) { return } assert.Equal(t, "httpdetails.localhost.pomerium.io", result.Host, "destination host should be preserved in %v", result) }) t.Run("disabled", func(t *testing.T) { client := testcluster.NewHTTPClient() req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/preserve-host-header-disabled", nil) if err != nil { t.Fatal(err) } res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() var result struct { Host string `json:"host"` } err = json.NewDecoder(res.Body).Decode(&result) if !assert.NoError(t, err) { return } assert.NotEqual(t, "httpdetails.localhost.pomerium.io", result.Host, "destination host should not be preserved in %v", result) }) } func TestSetRequestHeaders(t *testing.T) { ctx := mainCtx ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) defer clearTimeout() client := testcluster.NewHTTPClient() req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/", nil) if err != nil { t.Fatal(err) } res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() var result struct { Headers map[string]string `json:"headers"` } err = json.NewDecoder(res.Body).Decode(&result) if !assert.NoError(t, err) { return } assert.Equal(t, "custom-request-header-value", result.Headers["X-Custom-Request-Header"], "expected custom request header to be sent upstream") } func TestRemoveRequestHeaders(t *testing.T) { ctx := mainCtx ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) defer clearTimeout() client := testcluster.NewHTTPClient() req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/", nil) if err != nil { t.Fatal(err) } req.Header.Add("X-Custom-Request-Header-To-Remove", "foo") res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() var result struct { Headers map[string]string `json:"headers"` } err = json.NewDecoder(res.Body).Decode(&result) if !assert.NoError(t, err) { return } _, exist := result.Headers["X-Custom-Request-Header-To-Remove"] assert.False(t, exist, "expected X-Custom-Request-Header-To-Remove not to be present.") } func TestWebsocket(t *testing.T) { ctx := mainCtx ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) defer clearTimeout() t.Run("disabled", func(t *testing.T) { ws, _, err := (&websocket.Dialer{ NetDialContext: testcluster.Transport.DialContext, TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, }).DialContext(ctx, "wss://disabled-ws-echo.localhost.pomerium.io", nil) if !assert.Error(t, err, "expected bad handshake when websocket is not enabled") { ws.Close() return } }) t.Run("enabled", func(t *testing.T) { ws, _, err := (&websocket.Dialer{ NetDialContext: testcluster.Transport.DialContext, TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, }).DialContext(ctx, "wss://enabled-ws-echo.localhost.pomerium.io", nil) if !assert.NoError(t, err, "expected no error when creating websocket") { return } defer ws.Close() msg := "hello world" err = ws.WriteJSON("hello world") assert.NoError(t, err, "expected no error when writing json to websocket") err = ws.ReadJSON(&msg) assert.NoError(t, err, "expected no error when reading json from websocket") }) } func TestTLSSkipVerify(t *testing.T) { ctx := mainCtx ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) defer clearTimeout() t.Run("enabled", func(t *testing.T) { client := testcluster.NewHTTPClient() req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-skip-verify-enabled", nil) if err != nil { t.Fatal(err) } res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() assert.Equal(t, http.StatusOK, res.StatusCode) }) t.Run("disabled", func(t *testing.T) { client := testcluster.NewHTTPClient() req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-skip-verify-disabled", nil) if err != nil { t.Fatal(err) } res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() assert.Contains(t, []int{http.StatusBadGateway, http.StatusServiceUnavailable}, res.StatusCode) }) } func TestTLSServerName(t *testing.T) { ctx := mainCtx ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) defer clearTimeout() t.Run("enabled", func(t *testing.T) { client := testcluster.NewHTTPClient() req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-server-name-enabled", nil) if err != nil { t.Fatal(err) } res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() assert.Equal(t, http.StatusOK, res.StatusCode) }) t.Run("disabled", func(t *testing.T) { client := testcluster.NewHTTPClient() req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-server-name-disabled", nil) if err != nil { t.Fatal(err) } res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() assert.Contains(t, []int{http.StatusBadGateway, http.StatusServiceUnavailable}, res.StatusCode) }) } func TestTLSCustomCA(t *testing.T) { ctx := mainCtx ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) defer clearTimeout() t.Run("enabled", func(t *testing.T) { client := testcluster.NewHTTPClient() req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-custom-ca-enabled", nil) if err != nil { t.Fatal(err) } res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() assert.Equal(t, http.StatusOK, res.StatusCode) }) t.Run("disabled", func(t *testing.T) { client := testcluster.NewHTTPClient() req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-custom-ca-disabled", nil) if err != nil { t.Fatal(err) } res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() assert.Contains(t, []int{http.StatusBadGateway, http.StatusServiceUnavailable}, res.StatusCode) }) } func TestTLSClientCert(t *testing.T) { ctx := mainCtx ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) defer clearTimeout() t.Run("enabled", func(t *testing.T) { client := testcluster.NewHTTPClient() req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-client-cert-enabled", nil) if err != nil { t.Fatal(err) } res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() assert.Equal(t, http.StatusOK, res.StatusCode) }) t.Run("disabled", func(t *testing.T) { client := testcluster.NewHTTPClient() req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-client-cert-disabled", nil) if err != nil { t.Fatal(err) } res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() assert.Contains(t, []int{http.StatusBadGateway, http.StatusServiceUnavailable}, res.StatusCode) }) } func TestSNIMismatch(t *testing.T) { ctx := mainCtx ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) defer clearTimeout() // Browsers will coalesce connections for the same IP address and TLS certificate // even if the request was made to different domain names. We need to support this // so this test makes a request with an incorrect TLS server name to make sure it // gets routed properly hostport, err := testcluster.GetNodePortAddr(ctx, "default", "pomerium-proxy-nodeport") if err != nil { t.Fatal(err) } client := testcluster.NewHTTPClientWithTransport(&http.Transport{ DialContext: netutil.NewLocalDialer((&net.Dialer{}), map[string]string{ "443": hostport, }).DialContext, TLSClientConfig: &tls.Config{ ServerName: "ws-echo.localhost.pomerium.io", }, }) req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/ping", nil) if err != nil { t.Fatal(err) } res, err := client.Do(req) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() assert.Equal(t, http.StatusOK, res.StatusCode) } func TestAttestationJWT(t *testing.T) { ctx := mainCtx ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) defer clearTimeout() client := testcluster.NewHTTPClient() res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-user"), nil, flows.WithEmail("bob@dogs.test"), flows.WithGroups("user")) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() var result struct { Headers map[string]string `json:"headers"` } err = json.NewDecoder(res.Body).Decode(&result) if !assert.NoError(t, err) { return } assert.NotEmpty(t, result.Headers["X-Pomerium-Jwt-Assertion"], "Expected JWT assertion") } func TestPassIdentityHeaders(t *testing.T) { ctx := mainCtx ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) defer clearTimeout() tests := []struct { name string path string wantExist bool }{ {"enabled", "/by-user", true}, {"disabled", "/by-domain", false}, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { client := testcluster.NewHTTPClient() res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io"+tc.path), nil, flows.WithEmail("bob@dogs.test"), flows.WithGroups("user")) if !assert.NoError(t, err, "unexpected http error") { return } defer res.Body.Close() var result struct { Headers map[string]string `json:"headers"` } err = json.NewDecoder(res.Body).Decode(&result) if !assert.NoError(t, err) { return } for _, header := range []string{"X-Pomerium-Jwt-Assertion", "X-Pomerium-Claim-Email"} { _, exist := result.Headers[header] assert.True(t, exist == tc.wantExist, fmt.Sprintf("Header %s, expected: %v, got: %v", header, tc.wantExist, exist)) } }) } }