From a7637cdf49d71162c6e172ee8ba1f51353ee0404 Mon Sep 17 00:00:00 2001 From: Bobby DeSimone Date: Thu, 13 Jun 2019 21:09:19 -0700 Subject: [PATCH] proxy: allow custom redirect url to be set following signout --- CHANGELOG.md | 2 ++ authenticate/handlers_test.go | 2 +- proxy/handlers.go | 24 ++++++++++--- proxy/handlers_test.go | 66 +++++++++++++++++++++++++++++------ 4 files changed, 79 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 995589d4c..d49aa97a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ ### CHANGED +- Proxy's sign out handler `{}/.pomerium/sign_out` now accepts an optional `redirect_uri` parameter which can be used to specify a custom redirect page, so long as it is under the same top-level domain. [GH-183] + ### FIXED ## v0.0.5 diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 5fea3ca21..79514a84b 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -136,7 +136,7 @@ func TestAuthenticate_SignIn(t *testing.T) { "state=example", "redirect_uri=some.example", &sessions.MockSessionStore{ - LoadError: errors.New("unexpeted"), + LoadError: errors.New("error"), Session: &sessions.SessionState{ AccessToken: "AccessToken", RefreshToken: "RefreshToken", diff --git a/proxy/handlers.go b/proxy/handlers.go index 03d703410..52de7427b 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -37,7 +37,7 @@ func (p *Proxy) Handler() http.Handler { mux.HandleFunc("/robots.txt", p.RobotsTxt) mux.HandleFunc("/.pomerium", p.UserDashboard) mux.HandleFunc("/.pomerium/impersonate", p.Impersonate) // POST - mux.HandleFunc("/.pomerium/sign_out", p.SignOutCallback) + mux.HandleFunc("/.pomerium/sign_out", p.SignOut) // handlers handlers with validation mux.Handle("/.pomerium/callback", validate.ThenFunc(p.OAuthCallback)) mux.Handle("/.pomerium/refresh", validate.ThenFunc(p.Refresh)) @@ -51,12 +51,28 @@ func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) { fmt.Fprintf(w, "User-agent: *\nDisallow: /") } -// SignOutCallback redirects the request to the sign out url. It's the responsibility +// SignOut redirects the request to the sign out url. It's the responsibility // of the authenticate service to revoke the remote session and clear // the local session state. -func (p *Proxy) SignOutCallback(w http.ResponseWriter, r *http.Request) { +func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) { redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"} - http.Redirect(w, r, redirectURL.String(), http.StatusFound) + switch r.Method { + case http.MethodPost: + if err := r.ParseForm(); err != nil { + httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusBadRequest}) + return + } + uri, err := url.Parse(r.Form.Get("redirect_uri")) + if err == nil && uri.String() != "" { + redirectURL = uri + } + default: + uri, err := url.Parse(r.URL.Query().Get("redirect_uri")) + if err == nil && uri.String() != "" { + redirectURL = uri + } + } + http.Redirect(w, r, p.GetSignOutURL(p.AuthenticateURL, redirectURL).String(), http.StatusFound) } // OAuthStart begins the authenticate flow, encrypting the redirect url diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index 3e27eb9f2..21d9f912a 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -49,7 +49,7 @@ func (a mockCipher) Unmarshal(s string, i interface{}) error { func TestProxy_RobotsTxt(t *testing.T) { proxy := Proxy{} - req := httptest.NewRequest("GET", "/robots.txt", nil) + req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil) rr := httptest.NewRecorder() proxy.RobotsTxt(rr, req) if status := rr.Code; status != http.StatusOK { @@ -62,7 +62,6 @@ func TestProxy_RobotsTxt(t *testing.T) { } func TestProxy_GetRedirectURL(t *testing.T) { - tests := []struct { name string host string @@ -103,7 +102,6 @@ func TestProxy_signRedirectURL(t *testing.T) { } func TestProxy_GetSignOutURL(t *testing.T) { - tests := []struct { name string authenticate string @@ -159,13 +157,17 @@ func TestProxy_Signout(t *testing.T) { if err != nil { t.Fatal(err) } - req := httptest.NewRequest("GET", "/.pomerium/sign_out", nil) + req := httptest.NewRequest(http.MethodGet, "/.pomerium/sign_out", nil) rr := httptest.NewRecorder() - proxy.SignOutCallback(rr, req) + proxy.SignOut(rr, req) if status := rr.Code; status != http.StatusFound { t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound) } - // todo(bdd) : good way of mocking auth then serving a simple favicon? + body := rr.Body.String() + want := (proxy.AuthenticateURL.String()) + if !strings.Contains(body, want) { + t.Errorf("handler returned unexpected body: got %v want %s ", body, want) + } } func TestProxy_OAuthStart(t *testing.T) { @@ -173,7 +175,7 @@ func TestProxy_OAuthStart(t *testing.T) { if err != nil { t.Fatal(err) } - req := httptest.NewRequest("GET", "/oauth-start", nil) + req := httptest.NewRequest(http.MethodGet, "/oauth-start", nil) rr := httptest.NewRecorder() proxy.OAuthStart(rr, req) @@ -199,7 +201,7 @@ func TestProxy_Handler(t *testing.T) { } mux := http.NewServeMux() mux.Handle("/", h) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) rr := httptest.NewRecorder() mux.ServeHTTP(rr, req) if rr.Code != http.StatusNotFound { @@ -254,7 +256,7 @@ func TestProxy_router(t *testing.T) { p.AuthenticateClient = clients.MockAuthenticate{} p.cipher = mockCipher{} - req := httptest.NewRequest("GET", tt.host, nil) + req := httptest.NewRequest(http.MethodGet, tt.host, nil) _, ok := p.router(req) if ok != tt.wantOk { t.Errorf("Proxy.router() ok = %v, want %v", ok, tt.wantOk) @@ -555,5 +557,49 @@ func TestProxy_OAuthCallback(t *testing.T) { } }) } - +} +func TestProxy_SignOut(t *testing.T) { + + tests := []struct { + name string + verb string + redirectURL string + wantStatus int + }{ + {"good post", http.MethodPost, "https://test.example", http.StatusFound}, + {"good get", http.MethodGet, "https://test.example", http.StatusFound}, + {"good empty default", http.MethodGet, "", http.StatusFound}, + {"malformed", http.MethodPost, "", http.StatusBadRequest}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := testOptions() + p, err := New(opts) + if err != nil { + t.Fatal(err) + } + postForm := url.Values{} + postForm.Add("redirect_uri", tt.redirectURL) + uri := &url.URL{Path: "/"} + if tt.name == "malformed" { + uri.RawQuery = "redirect_uri=%zzzzz" + } + + query, _ := url.ParseQuery(uri.RawQuery) + if tt.verb == http.MethodGet { + query.Add("redirect_uri", tt.redirectURL) + uri.RawQuery = query.Encode() + } + r := httptest.NewRequest(tt.verb, uri.String(), bytes.NewBufferString(postForm.Encode())) + w := httptest.NewRecorder() + if tt.verb == http.MethodPost { + r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + } + p.SignOut(w, r) + if status := w.Code; status != tt.wantStatus { + t.Errorf("status code: got %v want %v", status, tt.wantStatus) + } + + }) + } }