internal/httputil: use error structs for http errors (#159)

The existing implementation used a ErrorResponse method to propogate
and create http error messages. Since we added functionality to
troubleshoot, signout, and do other tasks following an http error
it's useful to use Error struct in place of method arguments.

This fixes #157 where a troubleshooting links were appearing on pages
that it didn't make sense on (e.g. pages without valid sessions).
This commit is contained in:
Bobby DeSimone 2019-06-03 20:00:37 -07:00 committed by GitHub
parent 14403ce388
commit bade9f50e6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 244 additions and 133 deletions

View file

@ -14,7 +14,7 @@ import (
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
) )
// CSPHeaders adds content security headers for authenticate's handlers // CSPHeaders are the content security headers added to the service's handlers
var CSPHeaders = map[string]string{ var CSPHeaders = map[string]string{
"Content-Security-Policy": "default-src 'none'; style-src 'self'" + "Content-Security-Policy": "default-src 'none'; style-src 'self'" +
" 'sha256-z9MsgkMbQjRSLxzAfN55jB3a9pP0PQ4OHFH8b4iDP6s=' " + " 'sha256-z9MsgkMbQjRSLxzAfN55jB3a9pP0PQ4OHFH8b4iDP6s=' " +
@ -80,36 +80,41 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
return return
default: default:
log.FromRequest(r).Error().Err(err).Msg("proxy: unexpected error") log.FromRequest(r).Error().Err(err).Msg("proxy: unexpected error")
httputil.ErrorResponse(w, r, "An unexpected error occurred", http.StatusInternalServerError) httpErr := &httputil.Error{Message: "An unexpected error occurred", Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
} }
err = a.authenticate(w, r, session) err = a.authenticate(w, r, session)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
err = r.ParseForm() if err = r.ParseForm(); err != nil {
if err != nil { httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) httputil.ErrorResponse(w, r, httpErr)
return return
} }
// original `state` parameter received from the proxy application. // original `state` parameter received from the proxy application.
state := r.Form.Get("state") state := r.Form.Get("state")
if state == "" { if state == "" {
httputil.ErrorResponse(w, r, "no state parameter supplied", http.StatusBadRequest) httpErr := &httputil.Error{Message: "no state parameter supplied", Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
redirectURL, err := url.Parse(r.Form.Get("redirect_uri")) redirectURL, err := url.Parse(r.Form.Get("redirect_uri"))
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, "malformed redirect_uri parameter passed", http.StatusBadRequest) httpErr := &httputil.Error{Message: "malformed redirect_uri parameter passed", Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
// encrypt session state as json blob // encrypt session state as json blob
encrypted, err := sessions.MarshalSession(session, a.cipher) encrypted, err := sessions.MarshalSession(session, a.cipher)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
http.Redirect(w, r, getAuthCodeRedirectURL(redirectURL, state, string(encrypted)), http.StatusFound) http.Redirect(w, r, getAuthCodeRedirectURL(redirectURL, state, string(encrypted)), http.StatusFound)
@ -130,9 +135,10 @@ func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string
// SignOut signs the user out by trying to revoke the user's remote identity session along with // SignOut signs the user out by trying to revoke the user's remote identity session along with
// the associated local session state. Handles both GET and POST. // the associated local session state. Handles both GET and POST.
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm() if err := r.ParseForm(); err != nil {
if err != nil { log.Error().Err(err).Msg("authenticate: error SignOut form")
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) httpErr := &httputil.Error{Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
redirectURI := r.Form.Get("redirect_uri") redirectURI := r.Form.Get("redirect_uri")
@ -146,7 +152,8 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
err = a.provider.Revoke(session.AccessToken) err = a.provider.Revoke(session.AccessToken)
if err != nil { if err != nil {
log.Error().Err(err).Msg("authenticate: failed to revoke user session") log.Error().Err(err).Msg("authenticate: failed to revoke user session")
httputil.ErrorResponse(w, r, fmt.Sprintf("could not revoke session: %s ", err.Error()), http.StatusBadRequest) httpErr := &httputil.Error{Message: fmt.Sprintf("could not revoke session: %s ", err.Error()), Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
http.Redirect(w, r, redirectURI, http.StatusFound) http.Redirect(w, r, redirectURI, http.StatusFound)
@ -163,14 +170,16 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
// verify redirect uri is from the root domain // verify redirect uri is from the root domain
if !middleware.SameDomain(authRedirectURL, a.RedirectURL) { if !middleware.SameDomain(authRedirectURL, a.RedirectURL) {
httputil.ErrorResponse(w, r, "Invalid redirect parameter: redirect uri not from the root domain", http.StatusBadRequest) httpErr := &httputil.Error{Message: "Invalid redirect parameter: redirect uri not from the root domain", Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
// verify proxy url is from the root domain // verify proxy url is from the root domain
proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri")) proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri"))
if err != nil || !middleware.SameDomain(proxyRedirectURL, a.RedirectURL) { if err != nil || !middleware.SameDomain(proxyRedirectURL, a.RedirectURL) {
httputil.ErrorResponse(w, r, "Invalid redirect parameter: proxy url not from the root domain", http.StatusBadRequest) httpErr := &httputil.Error{Message: "Invalid redirect parameter: proxy url not from the root domain", Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
@ -178,7 +187,8 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
proxyRedirectSig := authRedirectURL.Query().Get("sig") proxyRedirectSig := authRedirectURL.Query().Get("sig")
ts := authRedirectURL.Query().Get("ts") ts := authRedirectURL.Query().Get("ts")
if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, a.SharedKey) { if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, a.SharedKey) {
httputil.ErrorResponse(w, r, "Invalid redirect parameter: invalid signature", http.StatusBadRequest) httpErr := &httputil.Error{Message: "Invalid redirect parameter: invalid signature", Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
@ -197,36 +207,36 @@ func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
switch h := err.(type) { switch h := err.(type) {
case nil: case nil:
break break
case httputil.HTTPError: case httputil.Error:
log.Error().Err(err).Msg("authenticate: oauth callback error") log.Error().Err(err).Msg("authenticate: oauth callback error")
httputil.ErrorResponse(w, r, h.Message, h.Code) httpErr := &httputil.Error{Message: h.Message, Code: h.Code}
httputil.ErrorResponse(w, r, httpErr)
return return
default: default:
log.Error().Err(err).Msg("authenticate: unexpected oauth callback error") log.Error().Err(err).Msg("authenticate: unexpected oauth callback error")
httputil.ErrorResponse(w, r, "Internal Error", http.StatusInternalServerError) httpErr := &httputil.Error{Message: "Internal Error", Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
// redirect back to the proxy-service via sign_in // redirect back to the proxy-service via sign_in
log.Info().Interface("redirect", redirect).Msg("proxy: OAuthCallback")
http.Redirect(w, r, redirect, http.StatusFound) http.Redirect(w, r, redirect, http.StatusFound)
} }
// getOAuthCallback completes the oauth cycle from an identity provider's callback // getOAuthCallback completes the oauth cycle from an identity provider's callback
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) { func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) {
// handle the callback response from the identity provider // handle the callback response from the identity provider
err := r.ParseForm() if err := r.ParseForm(); err != nil {
if err != nil { return "", httputil.Error{Code: http.StatusInternalServerError, Message: err.Error()}
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
} }
errorString := r.Form.Get("error") errorString := r.Form.Get("error")
if errorString != "" { if errorString != "" {
log.FromRequest(r).Error().Str("Error", errorString).Msg("authenticate: provider returned error") log.FromRequest(r).Error().Str("Error", errorString).Msg("authenticate: provider returned error")
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: errorString} return "", httputil.Error{Code: http.StatusForbidden, Message: errorString}
} }
code := r.Form.Get("code") code := r.Form.Get("code")
if code == "" { if code == "" {
log.FromRequest(r).Error().Err(err).Msg("authenticate: provider missing code") log.FromRequest(r).Error().Msg("authenticate: provider missing code")
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Missing Code"} return "", httputil.Error{Code: http.StatusBadRequest, Message: "Missing Code"}
} }
@ -234,18 +244,18 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
session, err := a.provider.Authenticate(code) session, err := a.provider.Authenticate(code)
if err != nil { if err != nil {
log.FromRequest(r).Error().Err(err).Msg("authenticate: error redeeming authenticate code") log.FromRequest(r).Error().Err(err).Msg("authenticate: error redeeming authenticate code")
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()} return "", httputil.Error{Code: http.StatusInternalServerError, Message: err.Error()}
} }
// okay, time to go back to the proxy service. // okay, time to go back to the proxy service.
bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state")) bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state"))
if err != nil { if err != nil {
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed decoding state") log.FromRequest(r).Error().Err(err).Msg("authenticate: failed decoding state")
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Couldn't decode state"} return "", httputil.Error{Code: http.StatusBadRequest, Message: "Couldn't decode state"}
} }
s := strings.SplitN(string(bytes), ":", 2) s := strings.SplitN(string(bytes), ":", 2)
if len(s) != 2 { if len(s) != 2 {
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Invalid State"} return "", httputil.Error{Code: http.StatusBadRequest, Message: "Invalid State"}
} }
nonce := s[0] nonce := s[0]
redirect := s[1] redirect := s[1]
@ -253,22 +263,22 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
defer a.csrfStore.ClearCSRF(w, r) defer a.csrfStore.ClearCSRF(w, r)
if err != nil || c.Value != nonce { if err != nil || c.Value != nonce {
log.FromRequest(r).Error().Err(err).Msg("authenticate: csrf failure") log.FromRequest(r).Error().Err(err).Msg("authenticate: csrf failure")
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "CSRF failed"} return "", httputil.Error{Code: http.StatusForbidden, Message: "CSRF failed"}
} }
redirectURL, err := url.Parse(redirect) redirectURL, err := url.Parse(redirect)
if err != nil { if err != nil {
log.FromRequest(r).Error().Err(err).Msg("authenticate: malformed redirect url") log.FromRequest(r).Error().Err(err).Msg("authenticate: malformed redirect url")
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Malformed redirect url"} return "", httputil.Error{Code: http.StatusForbidden, Message: "Malformed redirect url"}
} }
// sanity check, we are redirecting back to the same subdomain right? // sanity check, we are redirecting back to the same subdomain right?
if !middleware.SameDomain(redirectURL, a.RedirectURL) { if !middleware.SameDomain(redirectURL, a.RedirectURL) {
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Invalid Redirect URI domain"} return "", httputil.Error{Code: http.StatusBadRequest, Message: "Invalid Redirect URI domain"}
} }
err = a.sessionStore.SaveSession(w, r, session) err = a.sessionStore.SaveSession(w, r, session)
if err != nil { if err != nil {
log.Error().Err(err).Msg("authenticate: failed saving new session") log.Error().Err(err).Msg("authenticate: failed saving new session")
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Internal Error"} return "", httputil.Error{Code: http.StatusInternalServerError, Message: "Internal Error"}
} }
return redirect, nil return redirect, nil

View file

@ -64,12 +64,16 @@ func TestAuthenticate_Handler(t *testing.T) {
func TestAuthenticate_SignIn(t *testing.T) { func TestAuthenticate_SignIn(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
session sessions.SessionStore state string
provider identity.MockProvider redirectURI string
wantCode int session sessions.SessionStore
provider identity.MockProvider
wantCode int
}{ }{
{"good", {"good",
"state=example",
"redirect_uri=some.example",
&sessions.MockSessionStore{ &sessions.MockSessionStore{
Session: &sessions.SessionState{ Session: &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
@ -77,8 +81,10 @@ func TestAuthenticate_SignIn(t *testing.T) {
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(10 * time.Second),
}}, }},
identity.MockProvider{ValidateResponse: true}, identity.MockProvider{ValidateResponse: true},
http.StatusBadRequest}, http.StatusFound},
{"session not valid", {"session not valid",
"state=example",
"redirect_uri=some.example",
&sessions.MockSessionStore{ &sessions.MockSessionStore{
Session: &sessions.SessionState{ Session: &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
@ -87,15 +93,9 @@ func TestAuthenticate_SignIn(t *testing.T) {
}}, }},
identity.MockProvider{ValidateResponse: false}, identity.MockProvider{ValidateResponse: false},
http.StatusInternalServerError}, http.StatusInternalServerError},
{"session fails fails to save", &sessions.MockSessionStore{
SaveError: errors.New("error"),
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second),
}}, identity.MockProvider{ValidateResponse: true},
http.StatusBadRequest},
{"session refresh error", {"session refresh error",
"state=example",
"redirect_uri=some.example",
&sessions.MockSessionStore{ &sessions.MockSessionStore{
Session: &sessions.SessionState{ Session: &sessions.SessionState{
AccessToken: "AccessToken", AccessToken: "AccessToken",
@ -107,6 +107,8 @@ func TestAuthenticate_SignIn(t *testing.T) {
RefreshError: errors.New("error")}, RefreshError: errors.New("error")},
http.StatusInternalServerError}, http.StatusInternalServerError},
{"session save after refresh error", {"session save after refresh error",
"state=example",
"redirect_uri=some.example",
&sessions.MockSessionStore{ &sessions.MockSessionStore{
SaveError: errors.New("error"), SaveError: errors.New("error"),
Session: &sessions.SessionState{ Session: &sessions.SessionState{
@ -119,6 +121,8 @@ func TestAuthenticate_SignIn(t *testing.T) {
}, },
http.StatusInternalServerError}, http.StatusInternalServerError},
{"no cookie found trying to load", {"no cookie found trying to load",
"state=example",
"redirect_uri=some.example",
&sessions.MockSessionStore{ &sessions.MockSessionStore{
LoadError: http.ErrNoCookie, LoadError: http.ErrNoCookie,
Session: &sessions.SessionState{ Session: &sessions.SessionState{
@ -129,6 +133,8 @@ func TestAuthenticate_SignIn(t *testing.T) {
identity.MockProvider{ValidateResponse: true}, identity.MockProvider{ValidateResponse: true},
http.StatusBadRequest}, http.StatusBadRequest},
{"unexpected error trying to load session", {"unexpected error trying to load session",
"state=example",
"redirect_uri=some.example",
&sessions.MockSessionStore{ &sessions.MockSessionStore{
LoadError: errors.New("unexpeted"), LoadError: errors.New("unexpeted"),
Session: &sessions.SessionState{ Session: &sessions.SessionState{
@ -138,23 +144,63 @@ func TestAuthenticate_SignIn(t *testing.T) {
}}, }},
identity.MockProvider{ValidateResponse: true}, identity.MockProvider{ValidateResponse: true},
http.StatusInternalServerError}, http.StatusInternalServerError},
{"malformed form",
"state=example",
"redirect_uri=some.example",
&sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second),
}},
identity.MockProvider{ValidateResponse: true},
http.StatusInternalServerError},
{"empty state",
"state=",
"redirect_uri=some.example",
&sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second),
}},
identity.MockProvider{ValidateResponse: true},
http.StatusBadRequest},
{"malformed redirect uri",
"state=example",
"redirect_uri=https://accounts.google.^",
&sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second),
}},
identity.MockProvider{ValidateResponse: true},
http.StatusBadRequest},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
a := &Authenticate{ a := &Authenticate{
sessionStore: tt.session, sessionStore: tt.session,
provider: tt.provider, provider: tt.provider,
RedirectURL: uriParse("http://www.pomerium.io"), RedirectURL: uriParse(tt.redirectURI),
csrfStore: &sessions.MockCSRFStore{}, csrfStore: &sessions.MockCSRFStore{},
SharedKey: "secret", SharedKey: "secret",
cipher: mockCipher{}, cipher: mockCipher{},
} }
r := httptest.NewRequest("GET", "/sign-in", nil) uri := &url.URL{Path: "/"}
if tt.name == "malformed form" {
uri.RawQuery = "example=%zzzzz"
} else {
uri.RawQuery = fmt.Sprintf("%s&%s", tt.state, tt.redirectURI)
}
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
a.SignIn(w, r) a.SignIn(w, r)
if status := w.Code; status != tt.wantCode { if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode) t.Errorf("handler returned wrong status code: got %v want %v %s", status, tt.wantCode, uri)
t.Errorf("\n%+v", w.Body) t.Errorf("\n%+v", w.Body)
} }
}) })

View file

@ -164,7 +164,7 @@ func NewOptions() *Options {
AuthenticateInternalAddr: new(url.URL), AuthenticateInternalAddr: new(url.URL),
AuthorizeURL: new(url.URL), AuthorizeURL: new(url.URL),
RefreshCooldown: time.Duration(5 * time.Minute), RefreshCooldown: time.Duration(5 * time.Minute),
AllowWebsockets: false, AllowWebsockets: false,
} }
return o return o
} }

View file

@ -10,14 +10,16 @@ import (
"github.com/pomerium/pomerium/internal/templates" "github.com/pomerium/pomerium/internal/templates"
) )
// HTTPError stores the status code and a message for a given HTTP error. // Error reports an http error, its http status code, a custom message, and
type HTTPError struct { // whether it is CanDebug.
Code int type Error struct {
Message string Message string
Code int
CanDebug bool
} }
// Error fulfills the error interface, returning a string representation of the error. // Error fulfills the error interface, returning a string representation of the error.
func (h HTTPError) Error() string { func (h Error) Error() string {
return fmt.Sprintf("%d %s: %s", h.Code, http.StatusText(h.Code), h.Message) return fmt.Sprintf("%d %s: %s", h.Code, http.StatusText(h.Code), h.Message)
} }
@ -32,34 +34,32 @@ func CodeForError(err error) int {
// ErrorResponse renders an error page for errors given a message and a status code. // ErrorResponse renders an error page for errors given a message and a status code.
// If no message is passed, defaults to the text of the status code. // If no message is passed, defaults to the text of the status code.
func ErrorResponse(rw http.ResponseWriter, r *http.Request, message string, code int) { func ErrorResponse(rw http.ResponseWriter, r *http.Request, e *Error) {
if message == "" { requestID := ""
message = http.StatusText(code)
}
reqID := ""
id, ok := log.IDFromRequest(r) id, ok := log.IDFromRequest(r)
if ok { if ok {
reqID = id requestID = id
} }
if r.Header.Get("Accept") == "application/json" { if r.Header.Get("Accept") == "application/json" {
var response struct { var response struct {
Error string `json:"error"` Error string `json:"error"`
} }
response.Error = message response.Error = e.Message
writeJSONResponse(rw, code, response) writeJSONResponse(rw, e.Code, response)
} else { } else {
title := http.StatusText(code) rw.WriteHeader(e.Code)
rw.WriteHeader(code)
t := struct { t := struct {
Code int Code int
Title string Title string
Message string Message string
RequestID string RequestID string
CanDebug bool
}{ }{
Code: code, Code: e.Code,
Title: title, Title: http.StatusText(e.Code),
Message: message, Message: e.Message,
RequestID: reqID, RequestID: requestID,
CanDebug: e.CanDebug,
} }
templates.New().ExecuteTemplate(rw, "error.html", t) templates.New().ExecuteTemplate(rw, "error.html", t)
} }

View file

@ -32,9 +32,9 @@ func SetHeaders(securityHeaders map[string]string) func(next http.Handler) http.
func ValidateClientSecret(sharedSecret string) func(next http.Handler) http.Handler { func ValidateClientSecret(sharedSecret string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm() if err := r.ParseForm(); err != nil {
if err != nil { httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) httputil.ErrorResponse(w, r, httpErr)
return return
} }
clientSecret := r.Form.Get("shared_secret") clientSecret := r.Form.Get("shared_secret")
@ -44,7 +44,7 @@ func ValidateClientSecret(sharedSecret string) func(next http.Handler) http.Hand
} }
if clientSecret != sharedSecret { if clientSecret != sharedSecret {
httputil.ErrorResponse(w, r, "Invalid client secret", http.StatusUnauthorized) httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusInternalServerError})
return return
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
@ -59,16 +59,25 @@ func ValidateRedirectURI(rootDomain *url.URL) func(next http.Handler) http.Handl
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) httpErr := &httputil.Error{
Message: err.Error(),
Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
redirectURI, err := url.Parse(r.Form.Get("redirect_uri")) redirectURI, err := url.Parse(r.Form.Get("redirect_uri"))
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) httpErr := &httputil.Error{
Message: err.Error(),
Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
if !SameDomain(redirectURI, rootDomain) { if !SameDomain(redirectURI, rootDomain) {
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest) httpErr := &httputil.Error{
Message: "Invalid redirect parameter",
Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
@ -96,14 +105,18 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
redirectURI := r.Form.Get("redirect_uri") redirectURI := r.Form.Get("redirect_uri")
sigVal := r.Form.Get("sig") sigVal := r.Form.Get("sig")
timestamp := r.Form.Get("ts") timestamp := r.Form.Get("ts")
if !ValidSignature(redirectURI, sigVal, timestamp, sharedSecret) { if !ValidSignature(redirectURI, sigVal, timestamp, sharedSecret) {
httputil.ErrorResponse(w, r, "Cross service signature failed to validate", http.StatusUnauthorized) httpErr := &httputil.Error{
Message: "Cross service signature failed to validate",
Code: http.StatusUnauthorized}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
@ -117,7 +130,7 @@ func ValidateHost(validHost func(host string) bool) func(next http.Handler) http
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !validHost(r.Host) { if !validHost(r.Host) {
httputil.ErrorResponse(w, r, "Unknown route", http.StatusNotFound) httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusNotFound})
return return
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)

View file

@ -175,8 +175,8 @@ func TestValidateClientSecret(t *testing.T) {
}{ }{
{"simple", "secret", "secret", "secret", http.StatusOK}, {"simple", "secret", "secret", "secret", http.StatusOK},
{"missing get param, valid header", "secret", "", "secret", http.StatusOK}, {"missing get param, valid header", "secret", "", "secret", http.StatusOK},
{"missing both", "secret", "", "", http.StatusUnauthorized}, {"missing both", "secret", "", "", http.StatusInternalServerError},
{"simple bad", "bad-secret", "secret", "", http.StatusUnauthorized}, {"simple bad", "bad-secret", "secret", "", http.StatusInternalServerError},
{"malformed, invalid hex digits", "secret", "%zzzzz", "", http.StatusBadRequest}, {"malformed, invalid hex digits", "secret", "%zzzzz", "", http.StatusBadRequest},
} }

View file

@ -265,9 +265,11 @@ func New() *template.Template {
<svg class="icon error" xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24"><path fill="none" d="M0 0h24v24H0V0z"/><path d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zM4 12c0-4.42 3.58-8 8-8 1.85 0 3.55.63 4.9 1.69L5.69 16.9C4.63 15.55 4 13.85 4 12zm8 8c-1.85 0-3.55-.63-4.9-1.69L18.31 7.1C19.37 8.45 20 10.15 20 12c0 4.42-3.58 8-8 8z"/></svg> <svg class="icon error" xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24"><path fill="none" d="M0 0h24v24H0V0z"/><path d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zM4 12c0-4.42 3.58-8 8-8 1.85 0 3.55.63 4.9 1.69L5.69 16.9C4.63 15.55 4 13.85 4 12zm8 8c-1.85 0-3.55-.63-4.9-1.69L18.31 7.1C19.37 8.45 20 10.15 20 12c0 4.42-3.58 8-8 8z"/></svg>
<h1 class="title">{{.Title}}</h1> <h1 class="title">{{.Title}}</h1>
<section> <section>
<p class="message">{{.Message}}.</p> <p class="message">
<p class="message">Troubleshoot your <a href="/.pomerium">session</a>.</br> {{if .Message}}{{.Message}}</br>{{end}}
{{if .RequestID}} Request {{.RequestID}} {{end}} {{if .CanDebug}}Troubleshoot your <a href="/.pomerium">session</a>.</br>{{end}}
{{if .RequestID}} Request {{.RequestID}}</br>{{end}}
</p> </p>
</section> </section>
</form> </form>

View file

@ -3,12 +3,12 @@ package proxy // import "github.com/pomerium/pomerium/proxy"
import ( import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"github.com/pomerium/pomerium/internal/config"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time" "time"
"github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
@ -74,7 +74,8 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
localState, err := p.cipher.Marshal(state) localState, err := p.cipher.Marshal(state)
if err != nil { if err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: failed to marshal csrf") log.FromRequest(r).Error().Err(err).Msg("proxy: failed to marshal csrf")
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
p.csrfStore.SetCSRF(w, r, localState) p.csrfStore.SetCSRF(w, r, localState)
@ -84,7 +85,8 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
remoteState, err := p.cipher.Marshal(state) remoteState, err := p.cipher.Marshal(state)
if err != nil { if err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: failed to encrypt cookie") log.FromRequest(r).Error().Err(err).Msg("proxy: failed to encrypt cookie")
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
@ -96,7 +98,8 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
if remoteState == localState { if remoteState == localState {
p.sessionStore.ClearSession(w, r) p.sessionStore.ClearSession(w, r)
log.FromRequest(r).Error().Msg("proxy: encrypted state should not match") log.FromRequest(r).Error().Msg("proxy: encrypted state should not match")
httputil.ErrorResponse(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) httpErr := &httputil.Error{Message: http.StatusText(http.StatusBadRequest), Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
@ -113,25 +116,26 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
// other metadata, from the authenticator. // other metadata, from the authenticator.
// finish the oauth cycle // finish the oauth cycle
func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) { func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm() if err := r.ParseForm(); err != nil {
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: failed parsing request form") log.FromRequest(r).Error().Err(err).Msg("proxy: failed parsing request form")
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
errorString := r.Form.Get("error")
if errorString != "" { if errorString := r.Form.Get("error"); errorString != "" {
httputil.ErrorResponse(w, r, errorString, http.StatusBadRequest) httpErr := &httputil.Error{Message: errorString, Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
// Encrypted CSRF passed from authenticate service // Encrypted CSRF passed from authenticate service
remoteStateEncrypted := r.Form.Get("state") remoteStateEncrypted := r.Form.Get("state")
remoteStatePlain := new(StateParameter) remoteStatePlain := new(StateParameter)
err = p.cipher.Unmarshal(remoteStateEncrypted, remoteStatePlain) if err := p.cipher.Unmarshal(remoteStateEncrypted, remoteStatePlain); err != nil {
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: could not unmarshal state") log.FromRequest(r).Error().Err(err).Msg("proxy: could not unmarshal state")
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError) httpErr := &httputil.Error{Message: "Internal error", Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
@ -139,7 +143,8 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
c, err := p.csrfStore.GetCSRF(r) c, err := p.csrfStore.GetCSRF(r)
if err != nil { if err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: failed parsing csrf cookie") log.FromRequest(r).Error().Err(err).Msg("proxy: failed parsing csrf cookie")
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
p.csrfStore.ClearCSRF(w, r) p.csrfStore.ClearCSRF(w, r)
@ -148,7 +153,8 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
err = p.cipher.Unmarshal(localStateEncrypted, localStatePlain) err = p.cipher.Unmarshal(localStateEncrypted, localStatePlain)
if err != nil { if err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: couldn't unmarshal CSRF") log.FromRequest(r).Error().Err(err).Msg("proxy: couldn't unmarshal CSRF")
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError) httpErr := &httputil.Error{Message: "Internal error", Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
@ -157,7 +163,8 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
if remoteStateEncrypted == localStateEncrypted { if remoteStateEncrypted == localStateEncrypted {
p.sessionStore.ClearSession(w, r) p.sessionStore.ClearSession(w, r)
log.FromRequest(r).Error().Msg("proxy: local and remote state should not match") log.FromRequest(r).Error().Msg("proxy: local and remote state should not match")
httputil.ErrorResponse(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) httpErr := &httputil.Error{Message: http.StatusText(http.StatusBadRequest), Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
@ -165,7 +172,8 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
if remoteStatePlain.SessionID != localStatePlain.SessionID { if remoteStatePlain.SessionID != localStatePlain.SessionID {
p.sessionStore.ClearSession(w, r) p.sessionStore.ClearSession(w, r)
log.FromRequest(r).Error().Msg("proxy: CSRF mismatch") log.FromRequest(r).Error().Msg("proxy: CSRF mismatch")
httputil.ErrorResponse(w, r, "CSRF mismatch", http.StatusBadRequest) httpErr := &httputil.Error{Message: "CSRF mismatch", Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
@ -215,7 +223,8 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
return return
default: default:
log.FromRequest(r).Error().Err(err).Msg("proxy: unexpected error") log.FromRequest(r).Error().Err(err).Msg("proxy: unexpected error")
httputil.ErrorResponse(w, r, "An unexpected error occurred", http.StatusInternalServerError) httpErr := &httputil.Error{Message: "An unexpected error occurred", Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
} }
@ -223,13 +232,18 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
if err = p.authenticate(w, r, session); err != nil { if err = p.authenticate(w, r, session); err != nil {
p.sessionStore.ClearSession(w, r) p.sessionStore.ClearSession(w, r)
log.Debug().Err(err).Msg("proxy: user unauthenticated") log.Debug().Err(err).Msg("proxy: user unauthenticated")
httputil.ErrorResponse(w, r, "User unauthenticated", http.StatusForbidden) httpErr := &httputil.Error{
Message: "User unauthenticated",
Code: http.StatusForbidden,
CanDebug: true}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, session) authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, session)
if err != nil || !authorized { if err != nil || !authorized {
log.FromRequest(r).Warn().Err(err).Msg("proxy: user unauthorized") log.FromRequest(r).Warn().Err(err).Msg("proxy: user unauthorized")
httputil.ErrorResponse(w, r, "Access unauthorized", http.StatusUnauthorized) httpErr := &httputil.Error{Code: http.StatusUnauthorized, CanDebug: true}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
} }
@ -237,7 +251,7 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
// We have validated the users request and now proxy their request to the provided upstream. // We have validated the users request and now proxy their request to the provided upstream.
route, ok := p.router(r) route, ok := p.router(r)
if !ok { if !ok {
httputil.ErrorResponse(w, r, "unknown route to proxy", http.StatusNotFound) httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusNotFound})
return return
} }
route.ServeHTTP(w, r) route.ServeHTTP(w, r)
@ -250,13 +264,15 @@ func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
session, err := p.sessionStore.LoadSession(r) session, err := p.sessionStore.LoadSession(r)
if err != nil { if err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: load session failed") log.FromRequest(r).Error().Err(err).Msg("proxy: load session failed")
httputil.ErrorResponse(w, r, "", http.StatusBadRequest) httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
if err := p.authenticate(w, r, session); err != nil { if err := p.authenticate(w, r, session); err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: authenticate failed") log.FromRequest(r).Error().Err(err).Msg("proxy: authenticate failed")
httputil.ErrorResponse(w, r, "", http.StatusUnauthorized) httpErr := &httputil.Error{Code: http.StatusUnauthorized, CanDebug: true}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
@ -264,7 +280,8 @@ func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session) isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
if err != nil { if err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: is admin client") log.FromRequest(r).Error().Err(err).Msg("proxy: is admin client")
httputil.ErrorResponse(w, r, "", http.StatusInternalServerError) httpErr := &httputil.Error{Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
@ -273,7 +290,8 @@ func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
csrfCookie, err := p.cipher.Marshal(csrf) csrfCookie, err := p.cipher.Marshal(csrf)
if err != nil { if err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: failed to marshal csrf") log.FromRequest(r).Error().Err(err).Msg("proxy: failed to marshal csrf")
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) httpErr := &httputil.Error{Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
p.csrfStore.SetCSRF(w, r, csrfCookie) p.csrfStore.SetCSRF(w, r, csrfCookie)
@ -310,13 +328,15 @@ func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
func (p *Proxy) Refresh(w http.ResponseWriter, r *http.Request) { func (p *Proxy) Refresh(w http.ResponseWriter, r *http.Request) {
session, err := p.sessionStore.LoadSession(r) session, err := p.sessionStore.LoadSession(r)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
iss, err := session.IssuedAt() iss, err := session.IssuedAt()
if err != nil { if err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: couldn't get token's create time") log.FromRequest(r).Error().Err(err).Msg("proxy: couldn't get token's create time")
httputil.ErrorResponse(w, r, "", http.StatusInternalServerError) httpErr := &httputil.Error{Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
@ -324,20 +344,23 @@ func (p *Proxy) Refresh(w http.ResponseWriter, r *http.Request) {
// trying to DOS the identity provider. // trying to DOS the identity provider.
if time.Since(iss) < p.refreshCooldown { if time.Since(iss) < p.refreshCooldown {
log.FromRequest(r).Error().Dur("cooldown", p.refreshCooldown).Err(err).Msg("proxy: refresh cooldown") log.FromRequest(r).Error().Dur("cooldown", p.refreshCooldown).Err(err).Msg("proxy: refresh cooldown")
httputil.ErrorResponse(w, r, httpErr := &httputil.Error{
fmt.Sprintf("Session must be %v old before refresh", p.refreshCooldown), Message: fmt.Sprintf("Session must be %v old before refresh", p.refreshCooldown),
http.StatusBadRequest) Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
newSession, err := p.AuthenticateClient.Refresh(r.Context(), session) newSession, err := p.AuthenticateClient.Refresh(r.Context(), session)
if err != nil { if err != nil {
log.FromRequest(r).Warn().Err(err).Msg("proxy: refresh failed") log.FromRequest(r).Warn().Err(err).Msg("proxy: refresh failed")
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
if err = p.sessionStore.SaveSession(w, r, newSession); err != nil { if err = p.sessionStore.SaveSession(w, r, newSession); err != nil {
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
http.Redirect(w, r, "/.pomerium", http.StatusFound) http.Redirect(w, r, "/.pomerium", http.StatusFound)
@ -350,27 +373,34 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost { if r.Method == http.MethodPost {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: impersonate form") log.FromRequest(r).Error().Err(err).Msg("proxy: impersonate form")
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
session, err := p.sessionStore.LoadSession(r) session, err := p.sessionStore.LoadSession(r)
if err != nil { if err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: load session") log.FromRequest(r).Error().Err(err).Msg("proxy: load session")
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
// authorization check -- is this user an admin? // authorization check -- is this user an admin?
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session) isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
if err != nil || !isAdmin { if err != nil || !isAdmin {
log.FromRequest(r).Error().Err(err).Msg("proxy: user must be admin to impersonate") log.FromRequest(r).Error().Err(err).Msg("proxy: user must be admin to impersonate")
httputil.ErrorResponse(w, r, "user must be admin to impersonate", http.StatusForbidden) httpErr := &httputil.Error{
Message: fmt.Sprintf("%s must be and administrator", session.Email),
Code: http.StatusForbidden,
CanDebug: true}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
// CSRF check -- did this request originate from our form? // CSRF check -- did this request originate from our form?
c, err := p.csrfStore.GetCSRF(r) c, err := p.csrfStore.GetCSRF(r)
if err != nil { if err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: failed parsing csrf cookie") log.FromRequest(r).Error().Err(err).Msg("proxy: failed parsing csrf cookie")
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
p.csrfStore.ClearCSRF(w, r) p.csrfStore.ClearCSRF(w, r)
@ -378,12 +408,14 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
decryptedCSRF := new(StateParameter) decryptedCSRF := new(StateParameter)
if err = p.cipher.Unmarshal(encryptedCSRF, decryptedCSRF); err != nil { if err = p.cipher.Unmarshal(encryptedCSRF, decryptedCSRF); err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: couldn't unmarshal CSRF") log.FromRequest(r).Error().Err(err).Msg("proxy: couldn't unmarshal CSRF")
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError) httpErr := &httputil.Error{Message: "Internal error", Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
if decryptedCSRF.SessionID != r.FormValue("csrf") { if decryptedCSRF.SessionID != r.FormValue("csrf") {
log.FromRequest(r).Error().Err(err).Msg("proxy: impersonate CSRF mismatch") log.FromRequest(r).Error().Err(err).Msg("proxy: impersonate CSRF mismatch")
httputil.ErrorResponse(w, r, "CSRF mismatch", http.StatusForbidden) httpErr := &httputil.Error{Message: "CSRF mismatch", Code: http.StatusForbidden}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
@ -393,7 +425,8 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
if err := p.sessionStore.SaveSession(w, r, session); err != nil { if err := p.sessionStore.SaveSession(w, r, session); err != nil {
log.FromRequest(r).Error().Err(err).Msg("proxy: save session") log.FromRequest(r).Error().Err(err).Msg("proxy: save session")
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError}
httputil.ErrorResponse(w, r, httpErr)
return return
} }
} }
@ -510,7 +543,8 @@ func websocketHandlerFunc(baseHandler http.Handler, timeoutHandler http.Handler,
} }
log.FromRequest(r).Warn().Msg("proxy: attempt to proxy a websocket connection, but websocket support is disabled in the configuration") log.FromRequest(r).Warn().Msg("proxy: attempt to proxy a websocket connection, but websocket support is disabled in the configuration")
httputil.ErrorResponse(w, r, "websockets not supported by proxy", http.StatusBadRequest) httpErr := &httputil.Error{Message: "websockets not supported by proxy", Code: http.StatusBadRequest}
httputil.ErrorResponse(w, r, httpErr)
return return
} }

View file

@ -451,6 +451,7 @@ func TestProxy_Impersonate(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
malformed bool
options *config.Options options *config.Options
method string method string
email string email string
@ -463,14 +464,15 @@ func TestProxy_Impersonate(t *testing.T) {
authorizer clients.Authorizer authorizer clients.Authorizer
wantStatus int wantStatus int
}{ }{
{"good", opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, {"good", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"session load error", opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, {"session load error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"non admin users rejected", opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden}, {"non admin users rejected", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
{"non admin users rejected on error", opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden}, {"non admin users rejected on error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
{"csrf from store retrieve failure", opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}, GetError: errors.New("err")}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusBadRequest}, {"csrf from store retrieve failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}, GetError: errors.New("err")}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusBadRequest},
{"can't decrypt csrf value", opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{UnmarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, {"can't decrypt csrf value", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{UnmarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"decrypted csrf mismatch", opts, http.MethodPost, "user@blah.com", "", "CSRF!", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusForbidden}, {"decrypted csrf mismatch", false, opts, http.MethodPost, "user@blah.com", "", "CSRF!", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusForbidden},
{"save session failure", opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, {"save session failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"malformed", true, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusBadRequest},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -487,8 +489,12 @@ func TestProxy_Impersonate(t *testing.T) {
postForm.Add("email", tt.email) postForm.Add("email", tt.email)
postForm.Add("group", tt.groups) postForm.Add("group", tt.groups)
postForm.Set("csrf", tt.csrf) postForm.Set("csrf", tt.csrf)
uri := &url.URL{Path: "/"}
if tt.malformed {
uri.RawQuery = "email=%zzzzz"
}
r := httptest.NewRequest(tt.method, uri.String(), bytes.NewBufferString(postForm.Encode()))
r := httptest.NewRequest(tt.method, "/", bytes.NewBufferString(postForm.Encode()))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
w := httptest.NewRecorder() w := httptest.NewRecorder()
p.Impersonate(w, r) p.Impersonate(w, r)