proxy: allow custom redirect url to be set following signout

This commit is contained in:
Bobby DeSimone 2019-06-13 21:09:19 -07:00
parent fb3ed64fa1
commit a7637cdf49
No known key found for this signature in database
GPG key ID: AEE4CF12FE86D07E
4 changed files with 79 additions and 15 deletions

View file

@ -8,6 +8,8 @@
### CHANGED
- Proxy's sign out handler `{}/.pomerium/sign_out` now accepts an optional `redirect_uri` parameter which can be used to specify a custom redirect page, so long as it is under the same top-level domain. [GH-183]
### FIXED
## v0.0.5

View file

@ -136,7 +136,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
"state=example",
"redirect_uri=some.example",
&sessions.MockSessionStore{
LoadError: errors.New("unexpeted"),
LoadError: errors.New("error"),
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",

View file

@ -37,7 +37,7 @@ func (p *Proxy) Handler() http.Handler {
mux.HandleFunc("/robots.txt", p.RobotsTxt)
mux.HandleFunc("/.pomerium", p.UserDashboard)
mux.HandleFunc("/.pomerium/impersonate", p.Impersonate) // POST
mux.HandleFunc("/.pomerium/sign_out", p.SignOutCallback)
mux.HandleFunc("/.pomerium/sign_out", p.SignOut)
// handlers handlers with validation
mux.Handle("/.pomerium/callback", validate.ThenFunc(p.OAuthCallback))
mux.Handle("/.pomerium/refresh", validate.ThenFunc(p.Refresh))
@ -51,12 +51,28 @@ func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
}
// SignOutCallback redirects the request to the sign out url. It's the responsibility
// SignOut redirects the request to the sign out url. It's the responsibility
// of the authenticate service to revoke the remote session and clear
// the local session state.
func (p *Proxy) SignOutCallback(w http.ResponseWriter, r *http.Request) {
func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
switch r.Method {
case http.MethodPost:
if err := r.ParseForm(); err != nil {
httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusBadRequest})
return
}
uri, err := url.Parse(r.Form.Get("redirect_uri"))
if err == nil && uri.String() != "" {
redirectURL = uri
}
default:
uri, err := url.Parse(r.URL.Query().Get("redirect_uri"))
if err == nil && uri.String() != "" {
redirectURL = uri
}
}
http.Redirect(w, r, p.GetSignOutURL(p.AuthenticateURL, redirectURL).String(), http.StatusFound)
}
// OAuthStart begins the authenticate flow, encrypting the redirect url

View file

@ -49,7 +49,7 @@ func (a mockCipher) Unmarshal(s string, i interface{}) error {
func TestProxy_RobotsTxt(t *testing.T) {
proxy := Proxy{}
req := httptest.NewRequest("GET", "/robots.txt", nil)
req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil)
rr := httptest.NewRecorder()
proxy.RobotsTxt(rr, req)
if status := rr.Code; status != http.StatusOK {
@ -62,7 +62,6 @@ func TestProxy_RobotsTxt(t *testing.T) {
}
func TestProxy_GetRedirectURL(t *testing.T) {
tests := []struct {
name string
host string
@ -103,7 +102,6 @@ func TestProxy_signRedirectURL(t *testing.T) {
}
func TestProxy_GetSignOutURL(t *testing.T) {
tests := []struct {
name string
authenticate string
@ -159,13 +157,17 @@ func TestProxy_Signout(t *testing.T) {
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("GET", "/.pomerium/sign_out", nil)
req := httptest.NewRequest(http.MethodGet, "/.pomerium/sign_out", nil)
rr := httptest.NewRecorder()
proxy.SignOutCallback(rr, req)
proxy.SignOut(rr, req)
if status := rr.Code; status != http.StatusFound {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
}
// todo(bdd) : good way of mocking auth then serving a simple favicon?
body := rr.Body.String()
want := (proxy.AuthenticateURL.String())
if !strings.Contains(body, want) {
t.Errorf("handler returned unexpected body: got %v want %s ", body, want)
}
}
func TestProxy_OAuthStart(t *testing.T) {
@ -173,7 +175,7 @@ func TestProxy_OAuthStart(t *testing.T) {
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("GET", "/oauth-start", nil)
req := httptest.NewRequest(http.MethodGet, "/oauth-start", nil)
rr := httptest.NewRecorder()
proxy.OAuthStart(rr, req)
@ -199,7 +201,7 @@ func TestProxy_Handler(t *testing.T) {
}
mux := http.NewServeMux()
mux.Handle("/", h)
req := httptest.NewRequest("GET", "/", nil)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusNotFound {
@ -254,7 +256,7 @@ func TestProxy_router(t *testing.T) {
p.AuthenticateClient = clients.MockAuthenticate{}
p.cipher = mockCipher{}
req := httptest.NewRequest("GET", tt.host, nil)
req := httptest.NewRequest(http.MethodGet, tt.host, nil)
_, ok := p.router(req)
if ok != tt.wantOk {
t.Errorf("Proxy.router() ok = %v, want %v", ok, tt.wantOk)
@ -555,5 +557,49 @@ func TestProxy_OAuthCallback(t *testing.T) {
}
})
}
}
func TestProxy_SignOut(t *testing.T) {
tests := []struct {
name string
verb string
redirectURL string
wantStatus int
}{
{"good post", http.MethodPost, "https://test.example", http.StatusFound},
{"good get", http.MethodGet, "https://test.example", http.StatusFound},
{"good empty default", http.MethodGet, "", http.StatusFound},
{"malformed", http.MethodPost, "", http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := testOptions()
p, err := New(opts)
if err != nil {
t.Fatal(err)
}
postForm := url.Values{}
postForm.Add("redirect_uri", tt.redirectURL)
uri := &url.URL{Path: "/"}
if tt.name == "malformed" {
uri.RawQuery = "redirect_uri=%zzzzz"
}
query, _ := url.ParseQuery(uri.RawQuery)
if tt.verb == http.MethodGet {
query.Add("redirect_uri", tt.redirectURL)
uri.RawQuery = query.Encode()
}
r := httptest.NewRequest(tt.verb, uri.String(), bytes.NewBufferString(postForm.Encode()))
w := httptest.NewRecorder()
if tt.verb == http.MethodPost {
r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
}
p.SignOut(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
}
})
}
}