From 8de453dae3eea9326a1f46ac02c76252b9aa2ee2 Mon Sep 17 00:00:00 2001 From: Bobby DeSimone Date: Mon, 3 Jun 2019 08:45:38 -0700 Subject: [PATCH] internal/middleware: validate only top domain (#158) --- authenticate/handlers.go | 6 ++--- go.mod | 2 +- go.sum | 2 ++ internal/middleware/middleware.go | 22 +++++++--------- internal/middleware/middleware_test.go | 36 ++++++++++++++------------ 5 files changed, 35 insertions(+), 33 deletions(-) diff --git a/authenticate/handlers.go b/authenticate/handlers.go index b1d279f46..182dc347e 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -162,14 +162,14 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) { a.csrfStore.SetCSRF(w, r, nonce) // verify redirect uri is from the root domain - if !middleware.SameSubdomain(authRedirectURL, a.RedirectURL) { + if !middleware.SameDomain(authRedirectURL, a.RedirectURL) { httputil.ErrorResponse(w, r, "Invalid redirect parameter: redirect uri not from the root domain", http.StatusBadRequest) return } // verify proxy url is from the root domain proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri")) - if err != nil || !middleware.SameSubdomain(proxyRedirectURL, a.RedirectURL) { + if err != nil || !middleware.SameDomain(proxyRedirectURL, a.RedirectURL) { httputil.ErrorResponse(w, r, "Invalid redirect parameter: proxy url not from the root domain", http.StatusBadRequest) return } @@ -261,7 +261,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Malformed redirect url"} } // sanity check, we are redirecting back to the same subdomain right? - if !middleware.SameSubdomain(redirectURL, a.RedirectURL) { + if !middleware.SameDomain(redirectURL, a.RedirectURL) { return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Invalid Redirect URI domain"} } diff --git a/go.mod b/go.mod index 11f001d5c..075f11811 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/spf13/viper v1.3.2 github.com/stretchr/testify v1.3.0 // indirect golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f - golang.org/x/net v0.0.0-20190522155817-f3200d17e092 + golang.org/x/net v0.0.0-20190603091049-60506f45cf65 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421 golang.org/x/sys v0.0.0-20190524152521-dbbf3f1254d4 // indirect golang.org/x/text v0.3.2 // indirect diff --git a/go.sum b/go.sum index 4dcdf95c5..4de46ba0c 100644 --- a/go.sum +++ b/go.sum @@ -93,6 +93,8 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190522155817-f3200d17e092 h1:4QSRKanuywn15aTZvI/mIDEgPQpswuFndXpOj3rKEco= golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65 h1:+rhAzEzT3f4JtomfC371qB+0Ola2caSKcY69NUBZrRQ= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421 h1:Wo7BWFiOk0QRFMLYMqJGFMd9CgUAcGx7V+qEg/h5IBI= diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 2c1f2e2ef..84475f380 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -12,6 +12,7 @@ import ( "github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/httputil" + "golang.org/x/net/publicsuffix" ) // SetHeaders ensures that every response includes some basic security headers @@ -66,7 +67,7 @@ func ValidateRedirectURI(rootDomain *url.URL) func(next http.Handler) http.Handl httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) return } - if !SameSubdomain(redirectURI, rootDomain) { + if !SameDomain(redirectURI, rootDomain) { httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest) return } @@ -75,22 +76,17 @@ func ValidateRedirectURI(rootDomain *url.URL) func(next http.Handler) http.Handl } } -// SameSubdomain checks to see if two URLs share the same root domain. -func SameSubdomain(u, j *url.URL) bool { - if u.Hostname() == "" || j.Hostname() == "" { +// SameDomain checks to see if two URLs share the top level domain (TLD Plus One). +func SameDomain(u, j *url.URL) bool { + a, err := publicsuffix.EffectiveTLDPlusOne(u.Hostname()) + if err != nil { return false } - uParts := strings.Split(u.Hostname(), ".") - jParts := strings.Split(j.Hostname(), ".") - if len(uParts) != len(jParts) { + b, err := publicsuffix.EffectiveTLDPlusOne(j.Hostname()) + if err != nil { return false } - for i := 1; i < len(uParts); i++ { - if uParts[i] != jParts[i] { - return false - } - } - return true + return a == b } // ValidateSignature ensures the request is valid and has been signed with diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index 3f859722f..de7b15c6b 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -10,7 +10,7 @@ import ( "time" ) -func Test_SameSubdomain(t *testing.T) { +func Test_SameDomain(t *testing.T) { tests := []struct { name string @@ -19,8 +19,10 @@ func Test_SameSubdomain(t *testing.T) { want bool }{ {"good url redirect", "https://example.com/redirect", "https://example.com", true}, + {"good multilevel", "https://httpbin.a.corp.example.com", "https://auth.b.corp.example.com", true}, + {"good complex tld", "https://httpbin.a.corp.example.co.uk", "https://auth.b.corp.example.co.uk", true}, + {"bad complex tld", "https://httpbin.a.corp.notexample.co.uk", "https://auth.b.corp.example.co.uk", false}, {"simple sub", "https://auth.example.com", "https://test.example.com", true}, - {"mismatched lengths", "https://auth.auth.example.com", "https://test.example.com", false}, {"bad domain", "https://auth.example.com/redirect", "https://test.notexample.com", false}, {"malformed url", "^example.com/redirect", "https://notexample.com", false}, {"empty domain list", "https://example.com/redirect", ".com", false}, @@ -31,8 +33,8 @@ func Test_SameSubdomain(t *testing.T) { t.Run(tt.name, func(t *testing.T) { u, _ := url.Parse(tt.uri) j, _ := url.Parse(tt.rootDomains) - if got := SameSubdomain(u, j); got != tt.want { - t.Errorf("SameSubdomain() = %v, want %v", got, tt.want) + if got := SameDomain(u, j); got != tt.want { + t.Errorf("SameDomain() = %v, want %v", got, tt.want) } }) } @@ -127,24 +129,26 @@ func TestValidateRedirectURI(t *testing.T) { redirectURI string status int }{ - {"simple", "https://auth.google.com", "https://b.google.com", http.StatusOK}, - {"deep ok", "https://a.some.really.deep.sub.domain.google.com", "https://b.some.really.deep.sub.domain.google.com", http.StatusOK}, - {"bad match", "https://auth.aol.com", "https://test.google.com", http.StatusBadRequest}, - {"bad simple", "https://auth.corp.aol.com", "https://test.corp.google.com", http.StatusBadRequest}, - {"deep bad", "https://a.some.really.deep.sub.domain.scroogle.com", "https://b.some.really.deep.sub.domain.google.com", http.StatusBadRequest}, - {"with cname", "https://auth.google.com", "https://www.google.com", http.StatusOK}, - {"with path", "https://auth.google.com", "https://www.google.com/path", http.StatusOK}, - {"http mistmatch", "https://auth.google.com", "http://www.google.com/path", http.StatusOK}, - {"http", "http://auth.google.com", "http://www.google.com/path", http.StatusOK}, - {"ip", "http://1.1.1.1", "http://8.8.8.8", http.StatusBadRequest}, - {"malformed, invalid hex digits", "https://auth.google.com", "%zzzzz", http.StatusBadRequest}, + {"simple", "https://auth.google.com", "redirect_uri=https://b.google.com", http.StatusOK}, + {"deep ok", "https://a.some.really.deep.sub.domain.google.com", "redirect_uri=https://b.some.really.deep.sub.domain.google.com", http.StatusOK}, + {"bad match", "https://auth.aol.com", "redirect_uri=https://test.google.com", http.StatusBadRequest}, + {"bad simple", "https://auth.corp.aol.com", "redirect_uri=https://test.corp.google.com", http.StatusBadRequest}, + {"deep bad", "https://a.some.really.deep.sub.domain.scroogle.com", "redirect_uri=https://b.some.really.deep.sub.domain.google.com", http.StatusBadRequest}, + {"with cname", "https://auth.google.com", "redirect_uri=https://www.google.com", http.StatusOK}, + {"with path", "https://auth.google.com", "redirect_uri=https://www.google.com/path", http.StatusOK}, + {"http mistmatch", "https://auth.google.com", "redirect_uri=http://www.google.com/path", http.StatusOK}, + {"http", "http://auth.google.com", "redirect_uri=http://www.google.com/path", http.StatusOK}, + {"ip", "http://1.1.1.1", "redirect_uri=http://8.8.8.8", http.StatusBadRequest}, + {"redirect get param not set", "https://auth.google.com", "not_redirect_uri!=https://b.google.com", http.StatusBadRequest}, + {"malformed, invalid get params", "https://auth.google.com", "redirect_uri=https://%zzzzz", http.StatusBadRequest}, + {"malformed, invalid url", "https://auth.google.com", "redirect_uri=https://accounts.google.^", http.StatusBadRequest}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req := &http.Request{ Method: http.MethodGet, - URL: &url.URL{RawQuery: fmt.Sprintf("redirect_uri=%s", tt.redirectURI)}, + URL: &url.URL{RawQuery: tt.redirectURI}, } testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hi"))