mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-30 23:09:23 +02:00
proxy: allow custom redirect url to be set following signout
This commit is contained in:
parent
fb3ed64fa1
commit
a7637cdf49
4 changed files with 79 additions and 15 deletions
|
@ -8,6 +8,8 @@
|
||||||
|
|
||||||
### CHANGED
|
### CHANGED
|
||||||
|
|
||||||
|
- Proxy's sign out handler `{}/.pomerium/sign_out` now accepts an optional `redirect_uri` parameter which can be used to specify a custom redirect page, so long as it is under the same top-level domain. [GH-183]
|
||||||
|
|
||||||
### FIXED
|
### FIXED
|
||||||
|
|
||||||
## v0.0.5
|
## v0.0.5
|
||||||
|
|
|
@ -136,7 +136,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
"state=example",
|
"state=example",
|
||||||
"redirect_uri=some.example",
|
"redirect_uri=some.example",
|
||||||
&sessions.MockSessionStore{
|
&sessions.MockSessionStore{
|
||||||
LoadError: errors.New("unexpeted"),
|
LoadError: errors.New("error"),
|
||||||
Session: &sessions.SessionState{
|
Session: &sessions.SessionState{
|
||||||
AccessToken: "AccessToken",
|
AccessToken: "AccessToken",
|
||||||
RefreshToken: "RefreshToken",
|
RefreshToken: "RefreshToken",
|
||||||
|
|
|
@ -37,7 +37,7 @@ func (p *Proxy) Handler() http.Handler {
|
||||||
mux.HandleFunc("/robots.txt", p.RobotsTxt)
|
mux.HandleFunc("/robots.txt", p.RobotsTxt)
|
||||||
mux.HandleFunc("/.pomerium", p.UserDashboard)
|
mux.HandleFunc("/.pomerium", p.UserDashboard)
|
||||||
mux.HandleFunc("/.pomerium/impersonate", p.Impersonate) // POST
|
mux.HandleFunc("/.pomerium/impersonate", p.Impersonate) // POST
|
||||||
mux.HandleFunc("/.pomerium/sign_out", p.SignOutCallback)
|
mux.HandleFunc("/.pomerium/sign_out", p.SignOut)
|
||||||
// handlers handlers with validation
|
// handlers handlers with validation
|
||||||
mux.Handle("/.pomerium/callback", validate.ThenFunc(p.OAuthCallback))
|
mux.Handle("/.pomerium/callback", validate.ThenFunc(p.OAuthCallback))
|
||||||
mux.Handle("/.pomerium/refresh", validate.ThenFunc(p.Refresh))
|
mux.Handle("/.pomerium/refresh", validate.ThenFunc(p.Refresh))
|
||||||
|
@ -51,12 +51,28 @@ func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
|
||||||
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignOutCallback redirects the request to the sign out url. It's the responsibility
|
// SignOut redirects the request to the sign out url. It's the responsibility
|
||||||
// of the authenticate service to revoke the remote session and clear
|
// of the authenticate service to revoke the remote session and clear
|
||||||
// the local session state.
|
// the local session state.
|
||||||
func (p *Proxy) SignOutCallback(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||||
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
|
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
|
||||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
switch r.Method {
|
||||||
|
case http.MethodPost:
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusBadRequest})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
uri, err := url.Parse(r.Form.Get("redirect_uri"))
|
||||||
|
if err == nil && uri.String() != "" {
|
||||||
|
redirectURL = uri
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
uri, err := url.Parse(r.URL.Query().Get("redirect_uri"))
|
||||||
|
if err == nil && uri.String() != "" {
|
||||||
|
redirectURL = uri
|
||||||
|
}
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, p.GetSignOutURL(p.AuthenticateURL, redirectURL).String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuthStart begins the authenticate flow, encrypting the redirect url
|
// OAuthStart begins the authenticate flow, encrypting the redirect url
|
||||||
|
|
|
@ -49,7 +49,7 @@ func (a mockCipher) Unmarshal(s string, i interface{}) error {
|
||||||
|
|
||||||
func TestProxy_RobotsTxt(t *testing.T) {
|
func TestProxy_RobotsTxt(t *testing.T) {
|
||||||
proxy := Proxy{}
|
proxy := Proxy{}
|
||||||
req := httptest.NewRequest("GET", "/robots.txt", nil)
|
req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
proxy.RobotsTxt(rr, req)
|
proxy.RobotsTxt(rr, req)
|
||||||
if status := rr.Code; status != http.StatusOK {
|
if status := rr.Code; status != http.StatusOK {
|
||||||
|
@ -62,7 +62,6 @@ func TestProxy_RobotsTxt(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxy_GetRedirectURL(t *testing.T) {
|
func TestProxy_GetRedirectURL(t *testing.T) {
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
host string
|
host string
|
||||||
|
@ -103,7 +102,6 @@ func TestProxy_signRedirectURL(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxy_GetSignOutURL(t *testing.T) {
|
func TestProxy_GetSignOutURL(t *testing.T) {
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
authenticate string
|
authenticate string
|
||||||
|
@ -159,13 +157,17 @@ func TestProxy_Signout(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
req := httptest.NewRequest("GET", "/.pomerium/sign_out", nil)
|
req := httptest.NewRequest(http.MethodGet, "/.pomerium/sign_out", nil)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
proxy.SignOutCallback(rr, req)
|
proxy.SignOut(rr, req)
|
||||||
if status := rr.Code; status != http.StatusFound {
|
if status := rr.Code; status != http.StatusFound {
|
||||||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
|
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
|
||||||
}
|
}
|
||||||
// todo(bdd) : good way of mocking auth then serving a simple favicon?
|
body := rr.Body.String()
|
||||||
|
want := (proxy.AuthenticateURL.String())
|
||||||
|
if !strings.Contains(body, want) {
|
||||||
|
t.Errorf("handler returned unexpected body: got %v want %s ", body, want)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxy_OAuthStart(t *testing.T) {
|
func TestProxy_OAuthStart(t *testing.T) {
|
||||||
|
@ -173,7 +175,7 @@ func TestProxy_OAuthStart(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
req := httptest.NewRequest("GET", "/oauth-start", nil)
|
req := httptest.NewRequest(http.MethodGet, "/oauth-start", nil)
|
||||||
|
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
proxy.OAuthStart(rr, req)
|
proxy.OAuthStart(rr, req)
|
||||||
|
@ -199,7 +201,7 @@ func TestProxy_Handler(t *testing.T) {
|
||||||
}
|
}
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.Handle("/", h)
|
mux.Handle("/", h)
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
mux.ServeHTTP(rr, req)
|
mux.ServeHTTP(rr, req)
|
||||||
if rr.Code != http.StatusNotFound {
|
if rr.Code != http.StatusNotFound {
|
||||||
|
@ -254,7 +256,7 @@ func TestProxy_router(t *testing.T) {
|
||||||
p.AuthenticateClient = clients.MockAuthenticate{}
|
p.AuthenticateClient = clients.MockAuthenticate{}
|
||||||
p.cipher = mockCipher{}
|
p.cipher = mockCipher{}
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", tt.host, nil)
|
req := httptest.NewRequest(http.MethodGet, tt.host, nil)
|
||||||
_, ok := p.router(req)
|
_, ok := p.router(req)
|
||||||
if ok != tt.wantOk {
|
if ok != tt.wantOk {
|
||||||
t.Errorf("Proxy.router() ok = %v, want %v", ok, tt.wantOk)
|
t.Errorf("Proxy.router() ok = %v, want %v", ok, tt.wantOk)
|
||||||
|
@ -555,5 +557,49 @@ func TestProxy_OAuthCallback(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
func TestProxy_SignOut(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
verb string
|
||||||
|
redirectURL string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{"good post", http.MethodPost, "https://test.example", http.StatusFound},
|
||||||
|
{"good get", http.MethodGet, "https://test.example", http.StatusFound},
|
||||||
|
{"good empty default", http.MethodGet, "", http.StatusFound},
|
||||||
|
{"malformed", http.MethodPost, "", http.StatusBadRequest},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
opts := testOptions()
|
||||||
|
p, err := New(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
postForm := url.Values{}
|
||||||
|
postForm.Add("redirect_uri", tt.redirectURL)
|
||||||
|
uri := &url.URL{Path: "/"}
|
||||||
|
if tt.name == "malformed" {
|
||||||
|
uri.RawQuery = "redirect_uri=%zzzzz"
|
||||||
|
}
|
||||||
|
|
||||||
|
query, _ := url.ParseQuery(uri.RawQuery)
|
||||||
|
if tt.verb == http.MethodGet {
|
||||||
|
query.Add("redirect_uri", tt.redirectURL)
|
||||||
|
uri.RawQuery = query.Encode()
|
||||||
|
}
|
||||||
|
r := httptest.NewRequest(tt.verb, uri.String(), bytes.NewBufferString(postForm.Encode()))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
if tt.verb == http.MethodPost {
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
||||||
|
}
|
||||||
|
p.SignOut(w, r)
|
||||||
|
if status := w.Code; status != tt.wantStatus {
|
||||||
|
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue