mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
Currently, the user's identity headers are always inserted into the downstream request. For privacy reasons, it would be better not to insert these headers by default, and to let users choose whether to include them on a per-policy basis. Fixes #702
477 lines
13 KiB
Go
477 lines
13 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/pomerium/pomerium/integration/internal/flows"
|
|
"github.com/pomerium/pomerium/integration/internal/netutil"
|
|
)
|
|
|
|
func TestCORS(t *testing.T) {
|
|
ctx := mainCtx
|
|
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
|
|
defer clearTimeout()
|
|
|
|
t.Run("enabled", func(t *testing.T) {
|
|
client := testcluster.NewHTTPClient()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "OPTIONS", "https://httpdetails.localhost.pomerium.io/cors-enabled", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req.Header.Set("Access-Control-Request-Method", "GET")
|
|
req.Header.Set("Origin", "https://httpdetails.localhost.pomerium.io")
|
|
|
|
res, err := client.Do(req)
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusOK, res.StatusCode, "unexpected status code")
|
|
})
|
|
t.Run("disabled", func(t *testing.T) {
|
|
client := testcluster.NewHTTPClient()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "OPTIONS", "https://httpdetails.localhost.pomerium.io/cors-disabled", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req.Header.Set("Access-Control-Request-Method", "GET")
|
|
req.Header.Set("Origin", "https://httpdetails.localhost.pomerium.io")
|
|
|
|
res, err := client.Do(req)
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
assert.NotEqual(t, http.StatusOK, res.StatusCode, "unexpected status code")
|
|
})
|
|
}
|
|
|
|
func TestPreserveHostHeader(t *testing.T) {
|
|
ctx := mainCtx
|
|
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
|
|
defer clearTimeout()
|
|
|
|
t.Run("enabled", func(t *testing.T) {
|
|
client := testcluster.NewHTTPClient()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/preserve-host-header-enabled", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
var result struct {
|
|
Host string `json:"host"`
|
|
}
|
|
err = json.NewDecoder(res.Body).Decode(&result)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
|
|
assert.Equal(t, "httpdetails.localhost.pomerium.io", result.Host,
|
|
"destination host should be preserved in %v", result)
|
|
})
|
|
t.Run("disabled", func(t *testing.T) {
|
|
client := testcluster.NewHTTPClient()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/preserve-host-header-disabled", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
var result struct {
|
|
Host string `json:"host"`
|
|
}
|
|
err = json.NewDecoder(res.Body).Decode(&result)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
|
|
assert.NotEqual(t, "httpdetails.localhost.pomerium.io", result.Host,
|
|
"destination host should not be preserved in %v", result)
|
|
})
|
|
|
|
}
|
|
|
|
func TestSetRequestHeaders(t *testing.T) {
|
|
ctx := mainCtx
|
|
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
|
|
defer clearTimeout()
|
|
|
|
client := testcluster.NewHTTPClient()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
var result struct {
|
|
Headers map[string]string `json:"headers"`
|
|
}
|
|
err = json.NewDecoder(res.Body).Decode(&result)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
|
|
assert.Equal(t, "custom-request-header-value", result.Headers["X-Custom-Request-Header"],
|
|
"expected custom request header to be sent upstream")
|
|
|
|
}
|
|
|
|
func TestRemoveRequestHeaders(t *testing.T) {
|
|
ctx := mainCtx
|
|
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
|
|
defer clearTimeout()
|
|
|
|
client := testcluster.NewHTTPClient()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req.Header.Add("X-Custom-Request-Header-To-Remove", "foo")
|
|
|
|
res, err := client.Do(req)
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
var result struct {
|
|
Headers map[string]string `json:"headers"`
|
|
}
|
|
err = json.NewDecoder(res.Body).Decode(&result)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
|
|
_, exist := result.Headers["X-Custom-Request-Header-To-Remove"]
|
|
assert.False(t, exist, "expected X-Custom-Request-Header-To-Remove not to be present.")
|
|
|
|
}
|
|
|
|
func TestWebsocket(t *testing.T) {
|
|
ctx := mainCtx
|
|
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
|
|
defer clearTimeout()
|
|
|
|
t.Run("disabled", func(t *testing.T) {
|
|
ws, _, err := (&websocket.Dialer{
|
|
NetDialContext: testcluster.Transport.DialContext,
|
|
TLSClientConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}).DialContext(ctx, "wss://disabled-ws-echo.localhost.pomerium.io", nil)
|
|
if !assert.Error(t, err, "expected bad handshake when websocket is not enabled") {
|
|
ws.Close()
|
|
return
|
|
}
|
|
})
|
|
t.Run("enabled", func(t *testing.T) {
|
|
ws, _, err := (&websocket.Dialer{
|
|
NetDialContext: testcluster.Transport.DialContext,
|
|
TLSClientConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}).DialContext(ctx, "wss://enabled-ws-echo.localhost.pomerium.io", nil)
|
|
if !assert.NoError(t, err, "expected no error when creating websocket") {
|
|
return
|
|
}
|
|
defer ws.Close()
|
|
|
|
msg := "hello world"
|
|
err = ws.WriteJSON("hello world")
|
|
assert.NoError(t, err, "expected no error when writing json to websocket")
|
|
err = ws.ReadJSON(&msg)
|
|
assert.NoError(t, err, "expected no error when reading json from websocket")
|
|
})
|
|
}
|
|
|
|
func TestTLSSkipVerify(t *testing.T) {
|
|
ctx := mainCtx
|
|
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
|
|
defer clearTimeout()
|
|
|
|
t.Run("enabled", func(t *testing.T) {
|
|
client := testcluster.NewHTTPClient()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-skip-verify-enabled", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusOK, res.StatusCode)
|
|
})
|
|
t.Run("disabled", func(t *testing.T) {
|
|
client := testcluster.NewHTTPClient()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-skip-verify-disabled", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
assert.Contains(t, []int{http.StatusBadGateway, http.StatusServiceUnavailable}, res.StatusCode)
|
|
})
|
|
}
|
|
|
|
func TestTLSServerName(t *testing.T) {
|
|
ctx := mainCtx
|
|
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
|
|
defer clearTimeout()
|
|
|
|
t.Run("enabled", func(t *testing.T) {
|
|
client := testcluster.NewHTTPClient()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-server-name-enabled", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusOK, res.StatusCode)
|
|
})
|
|
t.Run("disabled", func(t *testing.T) {
|
|
client := testcluster.NewHTTPClient()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-server-name-disabled", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
assert.Contains(t, []int{http.StatusBadGateway, http.StatusServiceUnavailable}, res.StatusCode)
|
|
})
|
|
}
|
|
|
|
func TestTLSCustomCA(t *testing.T) {
|
|
ctx := mainCtx
|
|
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
|
|
defer clearTimeout()
|
|
|
|
t.Run("enabled", func(t *testing.T) {
|
|
client := testcluster.NewHTTPClient()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-custom-ca-enabled", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusOK, res.StatusCode)
|
|
})
|
|
t.Run("disabled", func(t *testing.T) {
|
|
client := testcluster.NewHTTPClient()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-custom-ca-disabled", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
assert.Contains(t, []int{http.StatusBadGateway, http.StatusServiceUnavailable}, res.StatusCode)
|
|
})
|
|
}
|
|
|
|
func TestTLSClientCert(t *testing.T) {
|
|
ctx := mainCtx
|
|
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
|
|
defer clearTimeout()
|
|
t.Run("enabled", func(t *testing.T) {
|
|
client := testcluster.NewHTTPClient()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-client-cert-enabled", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusOK, res.StatusCode)
|
|
})
|
|
t.Run("disabled", func(t *testing.T) {
|
|
client := testcluster.NewHTTPClient()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/tls-client-cert-disabled", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
assert.Contains(t, []int{http.StatusBadGateway, http.StatusServiceUnavailable}, res.StatusCode)
|
|
})
|
|
}
|
|
|
|
func TestSNIMismatch(t *testing.T) {
|
|
ctx := mainCtx
|
|
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
|
|
defer clearTimeout()
|
|
|
|
// Browsers will coalesce connections for the same IP address and TLS certificate
|
|
// even if the request was made to different domain names. We need to support this
|
|
// so this test makes a request with an incorrect TLS server name to make sure it
|
|
// gets routed properly
|
|
|
|
hostport, err := testcluster.GetNodePortAddr(ctx, "default", "pomerium-proxy-nodeport")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
client := testcluster.NewHTTPClientWithTransport(&http.Transport{
|
|
DialContext: netutil.NewLocalDialer((&net.Dialer{}), map[string]string{
|
|
"443": hostport,
|
|
}).DialContext,
|
|
TLSClientConfig: &tls.Config{
|
|
ServerName: "ws-echo.localhost.pomerium.io",
|
|
},
|
|
})
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/ping", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
res, err := client.Do(req)
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusOK, res.StatusCode)
|
|
}
|
|
|
|
func TestAttestationJWT(t *testing.T) {
|
|
ctx := mainCtx
|
|
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
|
|
defer clearTimeout()
|
|
|
|
client := testcluster.NewHTTPClient()
|
|
res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io/by-user"),
|
|
nil, flows.WithEmail("bob@dogs.test"), flows.WithGroups("user"))
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
var result struct {
|
|
Headers map[string]string `json:"headers"`
|
|
}
|
|
err = json.NewDecoder(res.Body).Decode(&result)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
|
|
assert.NotEmpty(t, result.Headers["X-Pomerium-Jwt-Assertion"], "Expected JWT assertion")
|
|
}
|
|
|
|
func TestPassIdentityHeaders(t *testing.T) {
|
|
ctx := mainCtx
|
|
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
|
|
defer clearTimeout()
|
|
|
|
tests := []struct {
|
|
name string
|
|
path string
|
|
wantExist bool
|
|
}{
|
|
{"enabled", "/by-user", true},
|
|
{"disabled", "/by-domain", false},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
client := testcluster.NewHTTPClient()
|
|
res, err := flows.Authenticate(ctx, client, mustParseURL("https://httpdetails.localhost.pomerium.io"+tc.path),
|
|
nil, flows.WithEmail("bob@dogs.test"), flows.WithGroups("user"))
|
|
if !assert.NoError(t, err, "unexpected http error") {
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
var result struct {
|
|
Headers map[string]string `json:"headers"`
|
|
}
|
|
err = json.NewDecoder(res.Body).Decode(&result)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
|
|
for _, header := range []string{"X-Pomerium-Jwt-Assertion", "X-Pomerium-Claim-Email"} {
|
|
_, exist := result.Headers[header]
|
|
assert.True(t, exist == tc.wantExist, fmt.Sprintf("Header %s, expected: %v, got: %v", header, tc.wantExist, exist))
|
|
}
|
|
})
|
|
}
|
|
}
|