package urlutil

import (
	"net/http"
	"net/url"
	"reflect"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/stretchr/testify/assert"
)

// Test_StripPort checks StripPort against plain hostnames, host:port pairs,
// and bracketed IPv6 literals (with and without a port). In every case the
// expected result is the bare host with neither port nor brackets.
func Test_StripPort(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name     string
		hostport string
		want     string
	}

	// Full-length IPv6 address shared by the two long-form IPv6 cases.
	const longIPv6 = "2001:0db8:85a3:0000:0000:8a2e:0370:7334"

	cases := []testCase{
		{name: "localhost", hostport: "localhost", want: "localhost"},
		{name: "localhost with port", hostport: "localhost:443", want: "localhost"},
		{name: "IPv6 localhost", hostport: "[::1]:80", want: "::1"},
		{name: "IPv6 localhost without port", hostport: "[::1]", want: "::1"},
		{name: "domain with port", hostport: "example.org:8080", want: "example.org"},
		{name: "domain without port", hostport: "example.org", want: "example.org"},
		{name: "long domain with port", hostport: "some.super.long.domain.example.org:8080", want: "some.super.long.domain.example.org"},
		{name: "IPv6 with port", hostport: "[" + longIPv6 + "]:17000", want: longIPv6},
		{name: "IPv6 without port", hostport: "[" + longIPv6 + "]", want: longIPv6},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			got := StripPort(tc.hostport)
			if got == tc.want {
				return
			}
			t.Errorf("StripPort() = %v, want %v", got, tc.want)
		})
	}
}

// TestParseAndValidateURL verifies that ParseAndValidateURL returns the parsed
// URL for a well-formed absolute URL, and that it returns an error together
// with a nil URL for inputs that are empty, lack a scheme or host, or
// otherwise fail to parse.
func TestParseAndValidateURL(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name    string
		rawurl  string
		want    *url.URL
		wantErr bool
	}{
		{"good", "https://some.example", &url.URL{Scheme: "https", Host: "some.example"}, false},
		{"bad schema", "//some.example", nil, true},
		{"bad hostname", "https://", nil, true},
		{"bad parse", "https://^", nil, true},
		{"empty string error", "", nil, true},
		{"path segment", "192.168.0.1:1234/path", nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseAndValidateURL(tt.rawurl)
			if (err != nil) != tt.wantErr {
				t.Errorf("ParseAndValidateURL() error = %v, wantErr %v", err, tt.wantErr)
				// The returned URL is meaningless once the error expectation
				// is violated, so skip the value comparison.
				return
			}
			// cmp.Diff(x, y) prefixes x-only lines with "-" and y-only lines
			// with "+"; passing want first gives the conventional
			// (-want +got) report.
			if diff := cmp.Diff(tt.want, got); diff != "" {
				t.Errorf("ParseAndValidateURL() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}

// TestDeepCopy exercises DeepCopy with a nil URL, a valid absolute URL, and a
// URL without a scheme. The nil input is expected to yield a nil copy without
// error, the valid URL an equal copy, and the scheme-less URL an error.
func TestDeepCopy(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name    string
		u       *url.URL
		want    *url.URL
		wantErr bool
	}

	cases := []testCase{
		{name: "nil"},
		{
			name: "good",
			u:    &url.URL{Scheme: "https", Host: "some.example"},
			want: &url.URL{Scheme: "https", Host: "some.example"},
		},
		{
			name:    "bad no scheme",
			u:       &url.URL{Host: "some.example"},
			wantErr: true,
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			got, err := DeepCopy(tc.u)
			gotErr := err != nil
			switch {
			case gotErr != tc.wantErr:
				// Error expectation violated; the value is not compared.
				t.Errorf("DeepCopy() error = %v, wantErr %v", err, tc.wantErr)
			case !reflect.DeepEqual(got, tc.want):
				t.Errorf("DeepCopy() = %v, want %v", got, tc.want)
			}
		})
	}
}

// TestValidateURL checks that ValidateURL accepts a URL carrying both a
// scheme and a host, and reports an error for a nil URL.
func TestValidateURL(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name    string
		u       *url.URL
		wantErr bool
	}{
		{name: "good", u: &url.URL{Scheme: "https", Host: "some.example"}, wantErr: false},
		{name: "nil", u: nil, wantErr: true},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			err := ValidateURL(tc.u)
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("ValidateURL() error = %v, wantErr %v", err, tc.wantErr)
			}
		})
	}
}

// parseURLHelper parses s into a *url.URL for use as a test fixture.
//
// It panics when s cannot be parsed, so a malformed fixture fails immediately
// with the parse error rather than producing a nil URL that is only noticed
// when a test later dereferences it.
func parseURLHelper(s string) *url.URL {
	u, err := url.Parse(s)
	if err != nil {
		panic("parseURLHelper: " + err.Error())
	}
	return u
}

// TestGetAbsoluteURL builds an http.Request from each fixture URL (using the
// URL's host as the request Host) and compares the result of GetAbsoluteURL
// with the expected URL. Both cases expect an https URL for pomerium.io.
//
// NOTE(review): the "missing scheme" case feeds a URL that already has an
// https scheme; the case name may not describe the input — confirm intent.
func TestGetAbsoluteURL(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name string
		u    *url.URL
		want *url.URL
	}{
		{
			name: "add https",
			u:    parseURLHelper("http://pomerium.io"),
			want: parseURLHelper("https://pomerium.io"),
		},
		{
			name: "missing scheme",
			u:    parseURLHelper("https://pomerium.io"),
			want: parseURLHelper("https://pomerium.io"),
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			req := &http.Request{URL: tc.u, Host: tc.u.Host}
			diff := cmp.Diff(GetAbsoluteURL(req), tc.want)
			if diff == "" {
				return
			}
			t.Errorf("GetAbsoluteURL() = %v", diff)
		})
	}
}

// TestGetDomainsForURL checks the domain list GetDomainsForURL produces for a
// URL. For a host without a port the expected list is the bare host plus the
// host with the scheme's default port (80 for http, 443 for https); for a host
// that already carries a port the expected list is just that host:port.
func TestGetDomainsForURL(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name string
		u    *url.URL
		want []string
	}

	cases := []testCase{
		{
			name: "http",
			u:    &url.URL{Scheme: "http", Host: "example.com"},
			want: []string{"example.com", "example.com:80"},
		},
		{
			name: "http scheme with host contain 443",
			u:    &url.URL{Scheme: "http", Host: "example.com:443"},
			want: []string{"example.com:443"},
		},
		{
			name: "https",
			u:    &url.URL{Scheme: "https", Host: "example.com"},
			want: []string{"example.com", "example.com:443"},
		},
		{
			name: "Host contains other port",
			u:    &url.URL{Scheme: "https", Host: "example.com:1234"},
			want: []string{"example.com:1234"},
		},
	}

	for i := range cases {
		// Per-iteration copy: the subtests run in parallel and must not
		// share a loop variable.
		tt := cases[i]
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			domains := GetDomainsForURL(*tt.u)
			diff := cmp.Diff(domains, tt.want)
			if diff == "" {
				return
			}
			t.Errorf("GetDomainsForURL() = %v", diff)
		})
	}
}

// TestJoin checks that Join yields "/x/y/z/" for every combination of a
// trailing slash on the first element and a leading slash on the second.
func TestJoin(t *testing.T) {
	const want = "/x/y/z/"

	// Each pair is (first element, second element) passed to Join.
	pairs := [][2]string{
		{"/x", "y/z/"},
		{"/x/", "y/z/"},
		{"/x", "/y/z/"},
		{"/x/", "/y/z/"},
	}
	for _, p := range pairs {
		assert.Equal(t, want, Join(p[0], p[1]))
	}
}