diff --git a/integration/cors_test.go b/integration/cors_test.go deleted file mode 100644 index eaa942b0d..000000000 --- a/integration/cors_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package main - -import ( - "context" - "net/http" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestCORS(t *testing.T) { - ctx := mainCtx - ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) - defer clearTimeout() - - t.Run("enabled", func(t *testing.T) { - client := testcluster.NewHTTPClient() - - req, err := http.NewRequestWithContext(ctx, "OPTIONS", "https://httpdetails.localhost.pomerium.io/cors-enabled", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Access-Control-Request-Method", "GET") - req.Header.Set("Origin", "https://httpdetails.localhost.pomerium.io") - - res, err := client.Do(req) - if !assert.NoError(t, err, "unexpected http error") { - return - } - defer res.Body.Close() - - assert.Equal(t, http.StatusOK, res.StatusCode, "unexpected status code") - }) - t.Run("disabled", func(t *testing.T) { - client := testcluster.NewHTTPClient() - - req, err := http.NewRequestWithContext(ctx, "OPTIONS", "https://httpdetails.localhost.pomerium.io/cors-disabled", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Access-Control-Request-Method", "GET") - req.Header.Set("Origin", "https://httpdetails.localhost.pomerium.io") - - res, err := client.Do(req) - if !assert.NoError(t, err, "unexpected http error") { - return - } - defer res.Body.Close() - - assert.NotEqual(t, http.StatusOK, res.StatusCode, "unexpected status code") - }) -} diff --git a/integration/manifests/lib/pomerium.libsonnet b/integration/manifests/lib/pomerium.libsonnet index d3190d980..d1d6a045e 100644 --- a/integration/manifests/lib/pomerium.libsonnet +++ b/integration/manifests/lib/pomerium.libsonnet @@ -20,6 +20,7 @@ local PomeriumPolicy = function() std.flattenArrays([ to: 'http://' + domain + '.default.svc.cluster.local', allowed_groups: ['admin'], }, + // cors_allow_preflight option { from: 'http://' + domain + '.localhost.pomerium.io', to: 'http://' + domain + '.default.svc.cluster.local', @@ -32,10 +33,28 @@ local PomeriumPolicy = function() std.flattenArrays([ prefix: '/cors-disabled', cors_allow_preflight: false, }, + // preserve_host_header option + { + from: 'http://' + domain + '.localhost.pomerium.io', + to: 'http://' + domain + '.default.svc.cluster.local', + path: '/preserve-host-header-enabled', + allow_public_unauthenticated_access: true, + preserve_host_header: true, + }, + { + from: 'http://' + domain + '.localhost.pomerium.io', + to: 'http://' + domain + '.default.svc.cluster.local', + path: '/preserve-host-header-disabled', + allow_public_unauthenticated_access: true, + preserve_host_header: false, + }, { from: 'http://' + domain + '.localhost.pomerium.io', to: 'http://' + domain + '.default.svc.cluster.local', allow_public_unauthenticated_access: true, + set_request_headers: { + 'X-Custom-Request-Header': 'custom-request-header-value', + }, }, ] for domain in ['httpdetails', 'fa-httpdetails', 'ws-echo'] diff --git a/integration/policy_test.go b/integration/policy_test.go new file mode 100644 index 000000000..49412a490 --- /dev/null +++ b/integration/policy_test.go @@ -0,0 +1,182 @@ +package main + +import ( + "context" + "crypto/tls" + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" +) + +func TestCORS(t *testing.T) { + ctx := mainCtx + ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) + defer clearTimeout() + + t.Run("enabled", func(t *testing.T) { + client := testcluster.NewHTTPClient() + + req, err := http.NewRequestWithContext(ctx, "OPTIONS", "https://httpdetails.localhost.pomerium.io/cors-enabled", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Access-Control-Request-Method", "GET") + req.Header.Set("Origin", "https://httpdetails.localhost.pomerium.io") + + res, err := client.Do(req) + if !assert.NoError(t, err, "unexpected http error") { + return + } + defer res.Body.Close() + + assert.Equal(t, http.StatusOK, res.StatusCode, "unexpected status code") + }) + t.Run("disabled", func(t *testing.T) { + client := testcluster.NewHTTPClient() + + req, err := http.NewRequestWithContext(ctx, "OPTIONS", "https://httpdetails.localhost.pomerium.io/cors-disabled", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Access-Control-Request-Method", "GET") + req.Header.Set("Origin", "https://httpdetails.localhost.pomerium.io") + + res, err := client.Do(req) + if !assert.NoError(t, err, "unexpected http error") { + return + } + defer res.Body.Close() + + assert.NotEqual(t, http.StatusOK, res.StatusCode, "unexpected status code") + }) +} + +func TestPreserveHostHeader(t *testing.T) { + ctx := mainCtx + ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) + defer clearTimeout() + + t.Run("enabled", func(t *testing.T) { + client := testcluster.NewHTTPClient() + + req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/preserve-host-header-enabled", nil) + if err != nil { + t.Fatal(err) + } + + res, err := client.Do(req) + if !assert.NoError(t, err, "unexpected http error") { + return + } + defer res.Body.Close() + + var result struct { + Headers map[string]string `json:"headers"` + } + err = json.NewDecoder(res.Body).Decode(&result) + if !assert.NoError(t, err) { + return + } + + assert.Equal(t, "httpdetails.localhost.pomerium.io", result.Headers["host"], + "destination host should be preserved") + }) + t.Run("disabled", func(t *testing.T) { + client := testcluster.NewHTTPClient() + + req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/preserve-host-header-disabled", nil) + if err != nil { + t.Fatal(err) + } + + res, err := client.Do(req) + if !assert.NoError(t, err, "unexpected http error") { + return + } + defer res.Body.Close() + + var result struct { + Headers map[string]string `json:"headers"` + } + err = json.NewDecoder(res.Body).Decode(&result) + if !assert.NoError(t, err) { + return + } + + assert.NotEqual(t, "httpdetails.localhost.pomerium.io", result.Headers["host"], + "destination host should not be preserved") + }) + +} + +func TestSetRequestHeaders(t *testing.T) { + ctx := mainCtx + ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) + defer clearTimeout() + + client := testcluster.NewHTTPClient() + + req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/", nil) + if err != nil { + t.Fatal(err) + } + + res, err := client.Do(req) + if !assert.NoError(t, err, "unexpected http error") { + return + } + defer res.Body.Close() + + var result struct { + Headers map[string]string `json:"headers"` + } + err = json.NewDecoder(res.Body).Decode(&result) + if !assert.NoError(t, err) { + return + } + + assert.Equal(t, "custom-request-header-value", result.Headers["x-custom-request-header"], + "expected custom request header to be sent upstream") + +} + +func TestWebsocket(t *testing.T) { + ctx := mainCtx + ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) + defer clearTimeout() + + t.Run("disabled", func(t *testing.T) { + ws, _, err := (&websocket.Dialer{ + NetDialContext: testcluster.Transport.DialContext, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }).DialContext(ctx, "wss://disabled-ws-echo.localhost.pomerium.io", nil) + if !assert.Error(t, err, "expected bad handshake when websocket is not enabled") { + ws.Close() + return + } + }) + t.Run("enabled", func(t *testing.T) { + ws, _, err := (&websocket.Dialer{ + NetDialContext: testcluster.Transport.DialContext, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }).DialContext(ctx, "wss://enabled-ws-echo.localhost.pomerium.io", nil) + if !assert.NoError(t, err, "expected no error when creating websocket") { + return + } + defer ws.Close() + + msg := "hello world" + err = ws.WriteJSON("hello world") + assert.NoError(t, err, "expected no error when writing json to websocket") + err = ws.ReadJSON(&msg) + assert.NoError(t, err, "expected no error when reading json from websocket") + }) +} diff --git a/integration/websocket_test.go b/integration/websocket_test.go deleted file mode 100644 index 67728b070..000000000 --- a/integration/websocket_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package main - -import ( - "context" - "crypto/tls" - "testing" - "time" - - "github.com/gorilla/websocket" - "github.com/stretchr/testify/assert" -) - -func TestWebsocket(t *testing.T) { - ctx := mainCtx - ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) - defer clearTimeout() - - t.Run("disabled", func(t *testing.T) { - ws, _, err := (&websocket.Dialer{ - NetDialContext: testcluster.Transport.DialContext, - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - }).DialContext(ctx, "wss://disabled-ws-echo.localhost.pomerium.io", nil) - if !assert.Error(t, err, "expected bad handshake when websocket is not enabled") { - ws.Close() - return - } - }) - t.Run("enabled", func(t *testing.T) { - ws, _, err := (&websocket.Dialer{ - NetDialContext: testcluster.Transport.DialContext, - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - }).DialContext(ctx, "wss://enabled-ws-echo.localhost.pomerium.io", nil) - if !assert.NoError(t, err, "expected no error when creating websocket") { - return - } - defer ws.Close() - - msg := "hello world" - err = ws.WriteJSON("hello world") - assert.NoError(t, err, "expected no error when writing json to websocket") - err = ws.ReadJSON(&msg) - assert.NoError(t, err, "expected no error when reading json from websocket") - }) -}