diff --git a/cmd/pomerium/main.go b/cmd/pomerium/main.go index 9fbfcf1ae..3a265ff1a 100644 --- a/cmd/pomerium/main.go +++ b/cmd/pomerium/main.go @@ -14,6 +14,7 @@ import ( "github.com/pomerium/pomerium/internal/https" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/middleware" + "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/version" pbAuthenticate "github.com/pomerium/pomerium/proto/authenticate" pbAuthorize "github.com/pomerium/pomerium/proto/authorize" @@ -58,7 +59,7 @@ func main() { if err != nil { log.Fatal().Err(err).Msg("cmd/pomerium: new authenticate") } - authHost = opts.AuthenticateURL.Host + authHost = urlutil.StripPort(opts.AuthenticateURL.Host) pbAuthenticate.RegisterAuthenticatorServer(grpcServer, authenticateService) } diff --git a/internal/urlutil/url.go b/internal/urlutil/url.go new file mode 100644 index 000000000..6182ddc18 --- /dev/null +++ b/internal/urlutil/url.go @@ -0,0 +1,14 @@ +package urlutil // import "github.com/pomerium/pomerium/internal/urlutil" + +import "strings" + +func StripPort(hostport string) string { + colon := strings.IndexByte(hostport, ':') + if colon == -1 { + return hostport + } + if i := strings.IndexByte(hostport, ']'); i != -1 { + return strings.TrimPrefix(hostport[:i], "[") + } + return hostport[:colon] +} diff --git a/internal/urlutil/url_test.go b/internal/urlutil/url_test.go new file mode 100644 index 000000000..40dd1c4eb --- /dev/null +++ b/internal/urlutil/url_test.go @@ -0,0 +1,29 @@ +package urlutil // import "github.com/pomerium/pomerium/internal/urlutil" + +import "testing" + +func Test_StripPort(t *testing.T) { + t.Parallel() + tests := []struct { + name string + hostport string + want string + }{ + {"localhost", "localhost", "localhost"}, + {"localhost with port", "localhost:443", "localhost"}, + {"IPv6 localhost", "[::1]:80", "::1"}, + {"IPv6 localhost without port", "[::1]", "::1"}, + {"domain with port", "example.org:8080", "example.org"}, + {"domain without port", "example.org", "example.org"}, + {"long domain with port", "some.super.long.domain.example.org:8080", "some.super.long.domain.example.org"}, + {"IPv6 with port", "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:17000", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"}, + {"IPv6 without port", "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := StripPort(tt.hostport); got != tt.want { + t.Errorf("StripPort() = %v, want %v", got, tt.want) + } + }) + } +}