mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-30 06:51:30 +02:00
proxy: allow custom redirect url to be set following signout
This commit is contained in:
parent
fb3ed64fa1
commit
a7637cdf49
4 changed files with 79 additions and 15 deletions
|
@ -8,6 +8,8 @@
|
|||
|
||||
### CHANGED
|
||||
|
||||
- Proxy's sign out handler `{}/.pomerium/sign_out` now accepts an optional `redirect_uri` parameter which can be used to specify a custom redirect page, so long as it is under the same top-level domain. [GH-183]
|
||||
|
||||
### FIXED
|
||||
|
||||
## v0.0.5
|
||||
|
|
|
@ -136,7 +136,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
"state=example",
|
||||
"redirect_uri=some.example",
|
||||
&sessions.MockSessionStore{
|
||||
LoadError: errors.New("unexpeted"),
|
||||
LoadError: errors.New("error"),
|
||||
Session: &sessions.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
|
|
|
@ -37,7 +37,7 @@ func (p *Proxy) Handler() http.Handler {
|
|||
mux.HandleFunc("/robots.txt", p.RobotsTxt)
|
||||
mux.HandleFunc("/.pomerium", p.UserDashboard)
|
||||
mux.HandleFunc("/.pomerium/impersonate", p.Impersonate) // POST
|
||||
mux.HandleFunc("/.pomerium/sign_out", p.SignOutCallback)
|
||||
mux.HandleFunc("/.pomerium/sign_out", p.SignOut)
|
||||
// handlers handlers with validation
|
||||
mux.Handle("/.pomerium/callback", validate.ThenFunc(p.OAuthCallback))
|
||||
mux.Handle("/.pomerium/refresh", validate.ThenFunc(p.Refresh))
|
||||
|
@ -51,12 +51,28 @@ func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
|
|||
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
||||
}
|
||||
|
||||
// SignOutCallback redirects the request to the sign out url. It's the responsibility
|
||||
// SignOut redirects the request to the sign out url. It's the responsibility
|
||||
// of the authenticate service to revoke the remote session and clear
|
||||
// the local session state.
|
||||
func (p *Proxy) SignOutCallback(w http.ResponseWriter, r *http.Request) {
|
||||
func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
if err := r.ParseForm(); err != nil {
|
||||
httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusBadRequest})
|
||||
return
|
||||
}
|
||||
uri, err := url.Parse(r.Form.Get("redirect_uri"))
|
||||
if err == nil && uri.String() != "" {
|
||||
redirectURL = uri
|
||||
}
|
||||
default:
|
||||
uri, err := url.Parse(r.URL.Query().Get("redirect_uri"))
|
||||
if err == nil && uri.String() != "" {
|
||||
redirectURL = uri
|
||||
}
|
||||
}
|
||||
http.Redirect(w, r, p.GetSignOutURL(p.AuthenticateURL, redirectURL).String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// OAuthStart begins the authenticate flow, encrypting the redirect url
|
||||
|
|
|
@ -49,7 +49,7 @@ func (a mockCipher) Unmarshal(s string, i interface{}) error {
|
|||
|
||||
func TestProxy_RobotsTxt(t *testing.T) {
|
||||
proxy := Proxy{}
|
||||
req := httptest.NewRequest("GET", "/robots.txt", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
proxy.RobotsTxt(rr, req)
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
|
@ -62,7 +62,6 @@ func TestProxy_RobotsTxt(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestProxy_GetRedirectURL(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
|
@ -103,7 +102,6 @@ func TestProxy_signRedirectURL(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestProxy_GetSignOutURL(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authenticate string
|
||||
|
@ -159,13 +157,17 @@ func TestProxy_Signout(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := httptest.NewRequest("GET", "/.pomerium/sign_out", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/.pomerium/sign_out", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
proxy.SignOutCallback(rr, req)
|
||||
proxy.SignOut(rr, req)
|
||||
if status := rr.Code; status != http.StatusFound {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
|
||||
}
|
||||
// todo(bdd) : good way of mocking auth then serving a simple favicon?
|
||||
body := rr.Body.String()
|
||||
want := (proxy.AuthenticateURL.String())
|
||||
if !strings.Contains(body, want) {
|
||||
t.Errorf("handler returned unexpected body: got %v want %s ", body, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_OAuthStart(t *testing.T) {
|
||||
|
@ -173,7 +175,7 @@ func TestProxy_OAuthStart(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := httptest.NewRequest("GET", "/oauth-start", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/oauth-start", nil)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
proxy.OAuthStart(rr, req)
|
||||
|
@ -199,7 +201,7 @@ func TestProxy_Handler(t *testing.T) {
|
|||
}
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/", h)
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusNotFound {
|
||||
|
@ -254,7 +256,7 @@ func TestProxy_router(t *testing.T) {
|
|||
p.AuthenticateClient = clients.MockAuthenticate{}
|
||||
p.cipher = mockCipher{}
|
||||
|
||||
req := httptest.NewRequest("GET", tt.host, nil)
|
||||
req := httptest.NewRequest(http.MethodGet, tt.host, nil)
|
||||
_, ok := p.router(req)
|
||||
if ok != tt.wantOk {
|
||||
t.Errorf("Proxy.router() ok = %v, want %v", ok, tt.wantOk)
|
||||
|
@ -555,5 +557,49 @@ func TestProxy_OAuthCallback(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
func TestProxy_SignOut(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
verb string
|
||||
redirectURL string
|
||||
wantStatus int
|
||||
}{
|
||||
{"good post", http.MethodPost, "https://test.example", http.StatusFound},
|
||||
{"good get", http.MethodGet, "https://test.example", http.StatusFound},
|
||||
{"good empty default", http.MethodGet, "", http.StatusFound},
|
||||
{"malformed", http.MethodPost, "", http.StatusBadRequest},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
opts := testOptions()
|
||||
p, err := New(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
postForm := url.Values{}
|
||||
postForm.Add("redirect_uri", tt.redirectURL)
|
||||
uri := &url.URL{Path: "/"}
|
||||
if tt.name == "malformed" {
|
||||
uri.RawQuery = "redirect_uri=%zzzzz"
|
||||
}
|
||||
|
||||
query, _ := url.ParseQuery(uri.RawQuery)
|
||||
if tt.verb == http.MethodGet {
|
||||
query.Add("redirect_uri", tt.redirectURL)
|
||||
uri.RawQuery = query.Encode()
|
||||
}
|
||||
r := httptest.NewRequest(tt.verb, uri.String(), bytes.NewBufferString(postForm.Encode()))
|
||||
w := httptest.NewRecorder()
|
||||
if tt.verb == http.MethodPost {
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
||||
}
|
||||
p.SignOut(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue