mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 18:06:34 +02:00
80 lines
1.9 KiB
Go
80 lines
1.9 KiB
Go
package config
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"encoding/pem"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// getLocalCertPEM returns the certificate used by the TLS test server s,
// encoded as a single PEM block of type "CERTIFICATE".
func getLocalCertPEM(s *httptest.Server) []byte {
	block := &pem.Block{Type: "CERTIFICATE"}
	block.Bytes = s.Certificate().Raw
	return pem.EncodeToMemory(block)
}
|
|
|
|
func TestHTTPTransport(t *testing.T) {
|
|
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer s.Close()
|
|
|
|
src := NewStaticSource(&Config{
|
|
Options: &Options{
|
|
CA: base64.StdEncoding.EncodeToString(getLocalCertPEM(s)),
|
|
},
|
|
})
|
|
transport := NewHTTPTransport(src)
|
|
client := &http.Client{
|
|
Transport: transport,
|
|
}
|
|
_, err := client.Get(s.URL)
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestPolicyHTTPTransport(t *testing.T) {
|
|
originalTransport := http.DefaultTransport
|
|
defer func() {
|
|
http.DefaultTransport = originalTransport
|
|
}()
|
|
src := NewStaticSource(&Config{Options: &Options{}})
|
|
http.DefaultTransport = NewHTTPTransport(src)
|
|
|
|
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer s.Close()
|
|
|
|
get := func(options *Options, policy *Policy) (*http.Response, error) {
|
|
transport := NewPolicyHTTPTransport(options, policy, false)
|
|
client := &http.Client{
|
|
Transport: transport,
|
|
}
|
|
return client.Get(s.URL)
|
|
}
|
|
|
|
t.Run("default", func(t *testing.T) {
|
|
_, err := get(&Options{}, &Policy{})
|
|
assert.Error(t, err)
|
|
})
|
|
t.Run("skip verify", func(t *testing.T) {
|
|
_, err := get(&Options{}, &Policy{TLSSkipVerify: true})
|
|
assert.NoError(t, err)
|
|
})
|
|
t.Run("ca", func(t *testing.T) {
|
|
_, err := get(&Options{
|
|
CA: base64.StdEncoding.EncodeToString(getLocalCertPEM(s)),
|
|
}, &Policy{})
|
|
assert.NoError(t, err)
|
|
})
|
|
t.Run("custom ca", func(t *testing.T) {
|
|
_, err := get(&Options{}, &Policy{
|
|
TLSCustomCA: base64.StdEncoding.EncodeToString(getLocalCertPEM(s)),
|
|
})
|
|
assert.NoError(t, err)
|
|
})
|
|
}
|