pomerium/internal/urlutil/url_test.go
Bobby DeSimone dc12947241
all: refactor handler logic
- all: prefer `FormValues` to `ParseForm` with subsequent `Form.Get`s
- all: refactor authentication stack to be checked by middleware, and accessible via request context.
- all: replace http.ServeMux with gorilla/mux’s router
- all: replace custom CSRF checks with gorilla/csrf middleware
- authenticate: extract callback path as constant.
- internal/config: implement stringer interface for policy
- internal/cryptutil: add helper func `NewBase64Key`
- internal/cryptutil: rename `GenerateKey` to `NewKey`
- internal/cryptutil: rename `GenerateRandomString` to `NewRandomStringN`
- internal/middleware: removed alice in favor of gorilla/mux
- internal/sessions: remove unused `ValidateRedirectURI` and `ValidateClientSecret`
- internal/sessions: replace custom CSRF with gorilla/csrf fork that supports custom handler protection
- internal/urlutil: add `SignedRedirectURL` to create hmac'd URLs
- internal/urlutil: add `ValidateURL` helper to parse URL options
- internal/urlutil: add `GetAbsoluteURL` which takes a request and returns its absolute URL.
- proxy: remove holdover state verification checks; sessions are no longer set in any proxy route, so these checks are unnecessary.
- proxy: replace un-named http.ServeMux with named domain routes.

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
2019-09-16 18:01:14 -07:00

177 lines
4.7 KiB
Go

package urlutil
import (
"net/http"
"net/url"
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
)
// Test_StripPort checks that StripPort drops an optional ":port" suffix from
// a host, and for bracketed IPv6 literals also removes the square brackets,
// leaving only the bare host.
func Test_StripPort(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name, hostport, want string
	}{
		{"localhost", "localhost", "localhost"},
		{"localhost with port", "localhost:443", "localhost"},
		{"IPv6 localhost", "[::1]:80", "::1"},
		{"IPv6 localhost without port", "[::1]", "::1"},
		{"domain with port", "example.org:8080", "example.org"},
		{"domain without port", "example.org", "example.org"},
		{"long domain with port", "some.super.long.domain.example.org:8080", "some.super.long.domain.example.org"},
		{"IPv6 with port", "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:17000", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"},
		{"IPv6 without port", "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := StripPort(tc.hostport)
			if got == tc.want {
				return
			}
			t.Errorf("StripPort() = %v, want %v", got, tc.want)
		})
	}
}
// TestParseAndValidateURL checks that ParseAndValidateURL returns the parsed
// URL for a well-formed absolute URL. For inputs that lack a scheme or a
// host, cannot be parsed, or are empty, it must return an error together with
// a nil URL. The nil URL is asserted because the diff below also runs for the
// error cases.
func TestParseAndValidateURL(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name    string
		rawurl  string
		want    *url.URL
		wantErr bool
	}{
		{"good", "https://some.example", &url.URL{Scheme: "https", Host: "some.example"}, false},
		{"bad schema", "//some.example", nil, true},
		{"bad hostname", "https://", nil, true},
		{"bad parse", "https://^", nil, true},
		{"empty string error", "", nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseAndValidateURL(tt.rawurl)
			if (err != nil) != tt.wantErr {
				t.Errorf("ParseAndValidateURL() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			// Diff in (want, got) order so the "-want +got" label is accurate.
			if diff := cmp.Diff(tt.want, got); diff != "" {
				t.Errorf("ParseAndValidateURL() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
// TestDeepCopy checks that DeepCopy:
//   - returns nil, nil for a nil input;
//   - returns an equal URL for a valid input;
//   - returns an error for a URL without a scheme.
//
// For non-nil results it also checks that the returned pointer differs from
// the argument, so an implementation that simply returned its input
// (an alias, not a copy) would fail.
func TestDeepCopy(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name    string
		u       *url.URL
		want    *url.URL
		wantErr bool
	}{
		{"nil", nil, nil, false},
		{"good", &url.URL{Scheme: "https", Host: "some.example"}, &url.URL{Scheme: "https", Host: "some.example"}, false},
		{"bad no scheme", &url.URL{Host: "some.example"}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := DeepCopy(tt.u)
			if (err != nil) != tt.wantErr {
				t.Errorf("DeepCopy() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("DeepCopy() = %v, want %v", got, tt.want)
			}
			// Equality alone cannot tell a copy from the original pointer.
			if got != nil && got == tt.u {
				t.Errorf("DeepCopy() returned its argument %p; want a distinct copy", got)
			}
		})
	}
}
// TestValidateURL checks ValidateURL's accept/reject decision. A URL with
// both a scheme and a host is accepted, and a nil URL is rejected.
func TestValidateURL(t *testing.T) {
	t.Parallel()

	for _, tc := range []struct {
		name    string
		in      *url.URL
		wantErr bool
	}{
		{"good", &url.URL{Scheme: "https", Host: "some.example"}, false},
		{"nil", nil, true},
	} {
		t.Run(tc.name, func(t *testing.T) {
			err := ValidateURL(tc.in)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ValidateURL() error = %v, wantErr %v", err, tc.wantErr)
			}
		})
	}
}
// TestSignedRedirectURL checks SignedRedirectURL against a fixed expected URL
// whose query carries a pre-computed signature and timestamp. The timestamp is
// pinned through the package-level testTimeNow hook.
//
// This test deliberately does NOT call t.Parallel. It writes testTimeNow,
// which Test_timestamp also writes, so running both concurrently is a data
// race and can change the "ts" value this test depends on. The previous hook
// value is restored when the test (and its sequential subtests) finish.
func TestSignedRedirectURL(t *testing.T) {
	savedTimeNow := testTimeNow
	defer func() { testTimeNow = savedTimeNow }()

	tests := []struct {
		name        string
		mockedTime  int64
		key         string
		destination *url.URL
		urlToSign   *url.URL
		want        *url.URL
	}{
		// NOTE(review): Scheme is "https://" rather than "https", so the signed
		// redirect_uri encodes "https://://pomerium.io/ok". The expected sig was
		// presumably generated for exactly this input; correcting the scheme
		// requires regenerating the expected sig.
		{"good", 2, "hunter42", &url.URL{Host: "pomerium.io", Scheme: "https://"}, &url.URL{Host: "pomerium.io", Scheme: "https://", Path: "/ok"}, &url.URL{Host: "pomerium.io", Scheme: "https://", RawQuery: "redirect_uri=https%3A%2F%2F%3A%2F%2Fpomerium.io%2Fok&sig=7jdo1XFcmuhjBHnpfVhll5cXflYByeMnbp5kRz87CVQ%3D&ts=2"}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			testTimeNow = tt.mockedTime
			got := SignedRedirectURL(tt.key, tt.destination, tt.urlToSign)
			if diff := cmp.Diff(tt.want, got); diff != "" {
				t.Errorf("SignedRedirectURL() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
// Test_timestamp checks that timestamp() does not return 0 when the
// package-level testTimeNow hook is cleared (set to 0).
//
// This test deliberately does NOT call t.Parallel. It writes testTimeNow,
// which TestSignedRedirectURL also writes, so running both concurrently is a
// data race. The previous hook value is restored when the test finishes.
func Test_timestamp(t *testing.T) {
	savedTimeNow := testTimeNow
	defer func() { testTimeNow = savedTimeNow }()

	tests := []struct {
		name     string
		dontWant int64
	}{
		{"if unset should never return zero", 0},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			testTimeNow = tt.dontWant
			if got := timestamp(); got == tt.dontWant {
				t.Errorf("timestamp() = %v, dontWant %v", got, tt.dontWant)
			}
		})
	}
}
// parseURLHelper parses s with url.Parse for use in test tables.
//
// It panics if s cannot be parsed. A malformed fixture is a bug in the test
// itself, and panicking with the parse error is much clearer than returning
// nil and later failing with an unrelated nil-pointer dereference.
func parseURLHelper(s string) *url.URL {
	u, err := url.Parse(s)
	if err != nil {
		panic("parseURLHelper: " + err.Error())
	}
	return u
}
// TestGetAbsoluteURL checks the URL GetAbsoluteURL returns for a request whose
// URL and Host fields are populated. Both an http and an https input are
// expected to come back as the same https URL.
func TestGetAbsoluteURL(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name string
		u    *url.URL
		want *url.URL
	}{
		{"add https", parseURLHelper("http://pomerium.io"), parseURLHelper("https://pomerium.io")},
		// The input already uses https; this case checks that it is preserved.
		{"keep https", parseURLHelper("https://pomerium.io"), parseURLHelper("https://pomerium.io")},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			r := http.Request{URL: tt.u, Host: tt.u.Host}
			got := GetAbsoluteURL(&r)
			if diff := cmp.Diff(tt.want, got); diff != "" {
				t.Errorf("GetAbsoluteURL() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}