mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-04 09:19:39 +02:00
config: fix url type regression (#253)
Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
790619ef01
commit
a962877ad4
12 changed files with 117 additions and 57 deletions
|
@ -25,6 +25,9 @@ func StripPort(hostport string) string {
|
|||
// ParseAndValidateURL wraps standard library's default url.Parse because
|
||||
// it's much more lenient about what type of urls it accepts than pomerium.
|
||||
func ParseAndValidateURL(rawurl string) (*url.URL, error) {
|
||||
if rawurl == "" {
|
||||
return nil, fmt.Errorf("url cannot be empty")
|
||||
}
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -37,3 +40,10 @@ func ParseAndValidateURL(rawurl string) (*url.URL, error) {
|
|||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func DeepCopy(u *url.URL) (*url.URL, error) {
|
||||
if u == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return ParseAndValidateURL(u.String())
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package urlutil
|
|||
|
||||
import (
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
@ -45,6 +46,7 @@ func TestParseAndValidateURL(t *testing.T) {
|
|||
{"bad schema", "//some.example", nil, true},
|
||||
{"bad hostname", "https://", nil, true},
|
||||
{"bad parse", "https://^", nil, true},
|
||||
{"empty string error", "", nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -59,3 +61,29 @@ func TestParseAndValidateURL(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeepCopy(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
u *url.URL
|
||||
want *url.URL
|
||||
wantErr bool
|
||||
}{
|
||||
{"nil", nil, nil, false},
|
||||
{"good", &url.URL{Scheme: "https", Host: "some.example"}, &url.URL{Scheme: "https", Host: "some.example"}, false},
|
||||
{"bad no scheme", &url.URL{Host: "some.example"}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := DeepCopy(tt.u)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("DeepCopy() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("DeepCopy() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue