diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 182dc347e..659a8a477 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -14,7 +14,7 @@ import ( "github.com/pomerium/pomerium/internal/sessions" ) -// CSPHeaders adds content security headers for authenticate's handlers +// CSPHeaders are the content security headers added to the service's handlers var CSPHeaders = map[string]string{ "Content-Security-Policy": "default-src 'none'; style-src 'self'" + " 'sha256-z9MsgkMbQjRSLxzAfN55jB3a9pP0PQ4OHFH8b4iDP6s=' " + @@ -80,36 +80,41 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) { return default: log.FromRequest(r).Error().Err(err).Msg("proxy: unexpected error") - httputil.ErrorResponse(w, r, "An unexpected error occurred", http.StatusInternalServerError) + httpErr := &httputil.Error{Message: "An unexpected error occurred", Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } } err = a.authenticate(w, r, session) if err != nil { - httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } - err = r.ParseForm() - if err != nil { - httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) + if err = r.ParseForm(); err != nil { + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } // original `state` parameter received from the proxy application. state := r.Form.Get("state") if state == "" { - httputil.ErrorResponse(w, r, "no state parameter supplied", http.StatusBadRequest) + httpErr := &httputil.Error{Message: "no state parameter supplied", Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } redirectURL, err := url.Parse(r.Form.Get("redirect_uri")) if err != nil { - httputil.ErrorResponse(w, r, "malformed redirect_uri parameter passed", http.StatusBadRequest) + httpErr := &httputil.Error{Message: "malformed redirect_uri parameter passed", Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } // encrypt session state as json blob encrypted, err := sessions.MarshalSession(session, a.cipher) if err != nil { - httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } http.Redirect(w, r, getAuthCodeRedirectURL(redirectURL, state, string(encrypted)), http.StatusFound) @@ -130,9 +135,10 @@ func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string // SignOut signs the user out by trying to revoke the user's remote identity session along with // the associated local session state. Handles both GET and POST. func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { - err := r.ParseForm() - if err != nil { - httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) + if err := r.ParseForm(); err != nil { + log.Error().Err(err).Msg("authenticate: error SignOut form") + httpErr := &httputil.Error{Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } redirectURI := r.Form.Get("redirect_uri") @@ -146,7 +152,8 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { err = a.provider.Revoke(session.AccessToken) if err != nil { log.Error().Err(err).Msg("authenticate: failed to revoke user session") - httputil.ErrorResponse(w, r, fmt.Sprintf("could not revoke session: %s ", err.Error()), http.StatusBadRequest) + httpErr := &httputil.Error{Message: fmt.Sprintf("could not revoke session: %s ", err.Error()), Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } http.Redirect(w, r, redirectURI, http.StatusFound) @@ -163,14 +170,16 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) { // verify redirect uri is from the root domain if !middleware.SameDomain(authRedirectURL, a.RedirectURL) { - httputil.ErrorResponse(w, r, "Invalid redirect parameter: redirect uri not from the root domain", http.StatusBadRequest) + httpErr := &httputil.Error{Message: "Invalid redirect parameter: redirect uri not from the root domain", Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } // verify proxy url is from the root domain proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri")) if err != nil || !middleware.SameDomain(proxyRedirectURL, a.RedirectURL) { - httputil.ErrorResponse(w, r, "Invalid redirect parameter: proxy url not from the root domain", http.StatusBadRequest) + httpErr := &httputil.Error{Message: "Invalid redirect parameter: proxy url not from the root domain", Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } @@ -178,7 +187,8 @@ func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) { proxyRedirectSig := authRedirectURL.Query().Get("sig") ts := authRedirectURL.Query().Get("ts") if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, a.SharedKey) { - httputil.ErrorResponse(w, r, "Invalid redirect parameter: invalid signature", http.StatusBadRequest) + httpErr := &httputil.Error{Message: "Invalid redirect parameter: invalid signature", Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } @@ -197,36 +207,36 @@ func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) { switch h := err.(type) { case nil: break - case httputil.HTTPError: + case httputil.Error: log.Error().Err(err).Msg("authenticate: oauth callback error") - httputil.ErrorResponse(w, r, h.Message, h.Code) + httpErr := &httputil.Error{Message: h.Message, Code: h.Code} + httputil.ErrorResponse(w, r, httpErr) return default: log.Error().Err(err).Msg("authenticate: unexpected oauth callback error") - httputil.ErrorResponse(w, r, "Internal Error", http.StatusInternalServerError) + httpErr := &httputil.Error{Message: "Internal Error", Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } // redirect back to the proxy-service via sign_in - log.Info().Interface("redirect", redirect).Msg("proxy: OAuthCallback") http.Redirect(w, r, redirect, http.StatusFound) } // getOAuthCallback completes the oauth cycle from an identity provider's callback func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) { // handle the callback response from the identity provider - err := r.ParseForm() - if err != nil { - return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()} + if err := r.ParseForm(); err != nil { + return "", httputil.Error{Code: http.StatusInternalServerError, Message: err.Error()} } errorString := r.Form.Get("error") if errorString != "" { log.FromRequest(r).Error().Str("Error", errorString).Msg("authenticate: provider returned error") - return "", httputil.HTTPError{Code: http.StatusForbidden, Message: errorString} + return "", httputil.Error{Code: http.StatusForbidden, Message: errorString} } code := r.Form.Get("code") if code == "" { - log.FromRequest(r).Error().Err(err).Msg("authenticate: provider missing code") - return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Missing Code"} + log.FromRequest(r).Error().Msg("authenticate: provider missing code") + return "", httputil.Error{Code: http.StatusBadRequest, Message: "Missing Code"} } @@ -234,18 +244,18 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) session, err := a.provider.Authenticate(code) if err != nil { log.FromRequest(r).Error().Err(err).Msg("authenticate: error redeeming authenticate code") - return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()} + return "", httputil.Error{Code: http.StatusInternalServerError, Message: err.Error()} } // okay, time to go back to the proxy service. bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state")) if err != nil { log.FromRequest(r).Error().Err(err).Msg("authenticate: failed decoding state") - return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Couldn't decode state"} + return "", httputil.Error{Code: http.StatusBadRequest, Message: "Couldn't decode state"} } s := strings.SplitN(string(bytes), ":", 2) if len(s) != 2 { - return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Invalid State"} + return "", httputil.Error{Code: http.StatusBadRequest, Message: "Invalid State"} } nonce := s[0] redirect := s[1] @@ -253,22 +263,22 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) defer a.csrfStore.ClearCSRF(w, r) if err != nil || c.Value != nonce { log.FromRequest(r).Error().Err(err).Msg("authenticate: csrf failure") - return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "CSRF failed"} + return "", httputil.Error{Code: http.StatusForbidden, Message: "CSRF failed"} } redirectURL, err := url.Parse(redirect) if err != nil { log.FromRequest(r).Error().Err(err).Msg("authenticate: malformed redirect url") - return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Malformed redirect url"} + return "", httputil.Error{Code: http.StatusForbidden, Message: "Malformed redirect url"} } // sanity check, we are redirecting back to the same subdomain right? if !middleware.SameDomain(redirectURL, a.RedirectURL) { - return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Invalid Redirect URI domain"} + return "", httputil.Error{Code: http.StatusBadRequest, Message: "Invalid Redirect URI domain"} } err = a.sessionStore.SaveSession(w, r, session) if err != nil { log.Error().Err(err).Msg("authenticate: failed saving new session") - return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Internal Error"} + return "", httputil.Error{Code: http.StatusInternalServerError, Message: "Internal Error"} } return redirect, nil diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 8bb8f32f1..f4f0e9cfc 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -64,12 +64,16 @@ func TestAuthenticate_Handler(t *testing.T) { func TestAuthenticate_SignIn(t *testing.T) { tests := []struct { - name string - session sessions.SessionStore - provider identity.MockProvider - wantCode int + name string + state string + redirectURI string + session sessions.SessionStore + provider identity.MockProvider + wantCode int }{ {"good", + "state=example", + "redirect_uri=some.example", &sessions.MockSessionStore{ Session: &sessions.SessionState{ AccessToken: "AccessToken", @@ -77,8 +81,10 @@ func TestAuthenticate_SignIn(t *testing.T) { RefreshDeadline: time.Now().Add(10 * time.Second), }}, identity.MockProvider{ValidateResponse: true}, - http.StatusBadRequest}, + http.StatusFound}, {"session not valid", + "state=example", + "redirect_uri=some.example", &sessions.MockSessionStore{ Session: &sessions.SessionState{ AccessToken: "AccessToken", @@ -87,15 +93,9 @@ func TestAuthenticate_SignIn(t *testing.T) { }}, identity.MockProvider{ValidateResponse: false}, http.StatusInternalServerError}, - {"session fails fails to save", &sessions.MockSessionStore{ - SaveError: errors.New("error"), - Session: &sessions.SessionState{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - RefreshDeadline: time.Now().Add(10 * time.Second), - }}, identity.MockProvider{ValidateResponse: true}, - http.StatusBadRequest}, {"session refresh error", + "state=example", + "redirect_uri=some.example", &sessions.MockSessionStore{ Session: &sessions.SessionState{ AccessToken: "AccessToken", @@ -107,6 +107,8 @@ func TestAuthenticate_SignIn(t *testing.T) { RefreshError: errors.New("error")}, http.StatusInternalServerError}, {"session save after refresh error", + "state=example", + "redirect_uri=some.example", &sessions.MockSessionStore{ SaveError: errors.New("error"), Session: &sessions.SessionState{ @@ -119,6 +121,8 @@ func TestAuthenticate_SignIn(t *testing.T) { }, http.StatusInternalServerError}, {"no cookie found trying to load", + "state=example", + "redirect_uri=some.example", &sessions.MockSessionStore{ LoadError: http.ErrNoCookie, Session: &sessions.SessionState{ @@ -129,6 +133,8 @@ func TestAuthenticate_SignIn(t *testing.T) { identity.MockProvider{ValidateResponse: true}, http.StatusBadRequest}, {"unexpected error trying to load session", + "state=example", + "redirect_uri=some.example", &sessions.MockSessionStore{ LoadError: errors.New("unexpeted"), Session: &sessions.SessionState{ @@ -138,23 +144,63 @@ func TestAuthenticate_SignIn(t *testing.T) { }}, identity.MockProvider{ValidateResponse: true}, http.StatusInternalServerError}, + {"malformed form", + "state=example", + "redirect_uri=some.example", + &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + RefreshDeadline: time.Now().Add(10 * time.Second), + }}, + identity.MockProvider{ValidateResponse: true}, + http.StatusInternalServerError}, + {"empty state", + "state=", + "redirect_uri=some.example", + &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + RefreshDeadline: time.Now().Add(10 * time.Second), + }}, + identity.MockProvider{ValidateResponse: true}, + http.StatusBadRequest}, + + {"malformed redirect uri", + "state=example", + "redirect_uri=https://accounts.google.^", + &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + RefreshDeadline: time.Now().Add(10 * time.Second), + }}, + identity.MockProvider{ValidateResponse: true}, + http.StatusBadRequest}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &Authenticate{ sessionStore: tt.session, provider: tt.provider, - RedirectURL: uriParse("http://www.pomerium.io"), + RedirectURL: uriParse(tt.redirectURI), csrfStore: &sessions.MockCSRFStore{}, SharedKey: "secret", cipher: mockCipher{}, } - r := httptest.NewRequest("GET", "/sign-in", nil) + uri := &url.URL{Path: "/"} + if tt.name == "malformed form" { + uri.RawQuery = "example=%zzzzz" + } else { + uri.RawQuery = fmt.Sprintf("%s&%s", tt.state, tt.redirectURI) + } + r := httptest.NewRequest(http.MethodGet, uri.String(), nil) w := httptest.NewRecorder() a.SignIn(w, r) if status := w.Code; status != tt.wantCode { - t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode) + t.Errorf("handler returned wrong status code: got %v want %v %s", status, tt.wantCode, uri) t.Errorf("\n%+v", w.Body) } }) diff --git a/internal/config/options.go b/internal/config/options.go index 3c9db9983..ad62fbd5e 100644 --- a/internal/config/options.go +++ b/internal/config/options.go @@ -164,7 +164,7 @@ func NewOptions() *Options { AuthenticateInternalAddr: new(url.URL), AuthorizeURL: new(url.URL), RefreshCooldown: time.Duration(5 * time.Minute), - AllowWebsockets: false, + AllowWebsockets: false, } return o } diff --git a/internal/httputil/errors.go b/internal/httputil/errors.go index 87ab97b03..607ad7e21 100644 --- a/internal/httputil/errors.go +++ b/internal/httputil/errors.go @@ -10,14 +10,16 @@ import ( "github.com/pomerium/pomerium/internal/templates" ) -// HTTPError stores the status code and a message for a given HTTP error. -type HTTPError struct { - Code int - Message string +// Error reports an http error, its http status code, a custom message, and +// whether it is CanDebug. +type Error struct { + Message string + Code int + CanDebug bool } // Error fulfills the error interface, returning a string representation of the error. -func (h HTTPError) Error() string { +func (h Error) Error() string { return fmt.Sprintf("%d %s: %s", h.Code, http.StatusText(h.Code), h.Message) } @@ -32,34 +34,32 @@ func CodeForError(err error) int { // ErrorResponse renders an error page for errors given a message and a status code. // If no message is passed, defaults to the text of the status code. -func ErrorResponse(rw http.ResponseWriter, r *http.Request, message string, code int) { - if message == "" { - message = http.StatusText(code) - } - reqID := "" +func ErrorResponse(rw http.ResponseWriter, r *http.Request, e *Error) { + requestID := "" id, ok := log.IDFromRequest(r) if ok { - reqID = id + requestID = id } if r.Header.Get("Accept") == "application/json" { var response struct { Error string `json:"error"` } - response.Error = message - writeJSONResponse(rw, code, response) + response.Error = e.Message + writeJSONResponse(rw, e.Code, response) } else { - title := http.StatusText(code) - rw.WriteHeader(code) + rw.WriteHeader(e.Code) t := struct { Code int Title string Message string RequestID string + CanDebug bool }{ - Code: code, - Title: title, - Message: message, - RequestID: reqID, + Code: e.Code, + Title: http.StatusText(e.Code), + Message: e.Message, + RequestID: requestID, + CanDebug: e.CanDebug, } templates.New().ExecuteTemplate(rw, "error.html", t) } diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 84475f380..18e65b1ab 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -32,9 +32,9 @@ func SetHeaders(securityHeaders map[string]string) func(next http.Handler) http. func ValidateClientSecret(sharedSecret string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - err := r.ParseForm() - if err != nil { - httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) + if err := r.ParseForm(); err != nil { + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } clientSecret := r.Form.Get("shared_secret") @@ -44,7 +44,7 @@ func ValidateClientSecret(sharedSecret string) func(next http.Handler) http.Hand } if clientSecret != sharedSecret { - httputil.ErrorResponse(w, r, "Invalid client secret", http.StatusUnauthorized) + httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusInternalServerError}) return } next.ServeHTTP(w, r) @@ -59,16 +59,25 @@ func ValidateRedirectURI(rootDomain *url.URL) func(next http.Handler) http.Handl return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := r.ParseForm() if err != nil { - httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) + httpErr := &httputil.Error{ + Message: err.Error(), + Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } redirectURI, err := url.Parse(r.Form.Get("redirect_uri")) if err != nil { - httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) + httpErr := &httputil.Error{ + Message: err.Error(), + Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } if !SameDomain(redirectURI, rootDomain) { - httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest) + httpErr := &httputil.Error{ + Message: "Invalid redirect parameter", + Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } next.ServeHTTP(w, r) @@ -96,14 +105,18 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := r.ParseForm() if err != nil { - httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } redirectURI := r.Form.Get("redirect_uri") sigVal := r.Form.Get("sig") timestamp := r.Form.Get("ts") if !ValidSignature(redirectURI, sigVal, timestamp, sharedSecret) { - httputil.ErrorResponse(w, r, "Cross service signature failed to validate", http.StatusUnauthorized) + httpErr := &httputil.Error{ + Message: "Cross service signature failed to validate", + Code: http.StatusUnauthorized} + httputil.ErrorResponse(w, r, httpErr) return } @@ -117,7 +130,7 @@ func ValidateHost(validHost func(host string) bool) func(next http.Handler) http return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !validHost(r.Host) { - httputil.ErrorResponse(w, r, "Unknown route", http.StatusNotFound) + httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusNotFound}) return } next.ServeHTTP(w, r) diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index de7b15c6b..e817b53c8 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -175,8 +175,8 @@ func TestValidateClientSecret(t *testing.T) { }{ {"simple", "secret", "secret", "secret", http.StatusOK}, {"missing get param, valid header", "secret", "", "secret", http.StatusOK}, - {"missing both", "secret", "", "", http.StatusUnauthorized}, - {"simple bad", "bad-secret", "secret", "", http.StatusUnauthorized}, + {"missing both", "secret", "", "", http.StatusInternalServerError}, + {"simple bad", "bad-secret", "secret", "", http.StatusInternalServerError}, {"malformed, invalid hex digits", "secret", "%zzzzz", "", http.StatusBadRequest}, } diff --git a/internal/templates/templates.go b/internal/templates/templates.go index 5c0a35175..f5989fb21 100644 --- a/internal/templates/templates.go +++ b/internal/templates/templates.go @@ -265,9 +265,11 @@ func New() *template.Template {

{{.Title}}

-

{{.Message}}.

-

Troubleshoot your session.
- {{if .RequestID}} Request {{.RequestID}} {{end}} +

+ {{if .Message}}{{.Message}}
{{end}} + {{if .CanDebug}}Troubleshoot your session.
{{end}} + {{if .RequestID}} Request {{.RequestID}}
{{end}} +

diff --git a/proxy/handlers.go b/proxy/handlers.go index 8beb31928..82e9ae582 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -3,12 +3,12 @@ package proxy // import "github.com/pomerium/pomerium/proxy" import ( "encoding/base64" "fmt" - "github.com/pomerium/pomerium/internal/config" "net/http" "net/url" "strings" "time" + "github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" @@ -74,7 +74,8 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) { localState, err := p.cipher.Marshal(state) if err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: failed to marshal csrf") - httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } p.csrfStore.SetCSRF(w, r, localState) @@ -84,7 +85,8 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) { remoteState, err := p.cipher.Marshal(state) if err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: failed to encrypt cookie") - httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } @@ -96,7 +98,8 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) { if remoteState == localState { p.sessionStore.ClearSession(w, r) log.FromRequest(r).Error().Msg("proxy: encrypted state should not match") - httputil.ErrorResponse(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + httpErr := &httputil.Error{Message: http.StatusText(http.StatusBadRequest), Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } @@ -113,25 +116,26 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) { // other metadata, from the authenticator. // finish the oauth cycle func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) { - err := r.ParseForm() - if err != nil { + if err := r.ParseForm(); err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: failed parsing request form") - httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } - errorString := r.Form.Get("error") - if errorString != "" { - httputil.ErrorResponse(w, r, errorString, http.StatusBadRequest) + + if errorString := r.Form.Get("error"); errorString != "" { + httpErr := &httputil.Error{Message: errorString, Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } // Encrypted CSRF passed from authenticate service remoteStateEncrypted := r.Form.Get("state") remoteStatePlain := new(StateParameter) - err = p.cipher.Unmarshal(remoteStateEncrypted, remoteStatePlain) - if err != nil { + if err := p.cipher.Unmarshal(remoteStateEncrypted, remoteStatePlain); err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: could not unmarshal state") - httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError) + httpErr := &httputil.Error{Message: "Internal error", Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } @@ -139,7 +143,8 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) { c, err := p.csrfStore.GetCSRF(r) if err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: failed parsing csrf cookie") - httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } p.csrfStore.ClearCSRF(w, r) @@ -148,7 +153,8 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) { err = p.cipher.Unmarshal(localStateEncrypted, localStatePlain) if err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: couldn't unmarshal CSRF") - httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError) + httpErr := &httputil.Error{Message: "Internal error", Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } @@ -157,7 +163,8 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) { if remoteStateEncrypted == localStateEncrypted { p.sessionStore.ClearSession(w, r) log.FromRequest(r).Error().Msg("proxy: local and remote state should not match") - httputil.ErrorResponse(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + httpErr := &httputil.Error{Message: http.StatusText(http.StatusBadRequest), Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } @@ -165,7 +172,8 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) { if remoteStatePlain.SessionID != localStatePlain.SessionID { p.sessionStore.ClearSession(w, r) log.FromRequest(r).Error().Msg("proxy: CSRF mismatch") - httputil.ErrorResponse(w, r, "CSRF mismatch", http.StatusBadRequest) + httpErr := &httputil.Error{Message: "CSRF mismatch", Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } @@ -215,7 +223,8 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { return default: log.FromRequest(r).Error().Err(err).Msg("proxy: unexpected error") - httputil.ErrorResponse(w, r, "An unexpected error occurred", http.StatusInternalServerError) + httpErr := &httputil.Error{Message: "An unexpected error occurred", Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } } @@ -223,13 +232,18 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { if err = p.authenticate(w, r, session); err != nil { p.sessionStore.ClearSession(w, r) log.Debug().Err(err).Msg("proxy: user unauthenticated") - httputil.ErrorResponse(w, r, "User unauthenticated", http.StatusForbidden) + httpErr := &httputil.Error{ + Message: "User unauthenticated", + Code: http.StatusForbidden, + CanDebug: true} + httputil.ErrorResponse(w, r, httpErr) return } authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, session) if err != nil || !authorized { log.FromRequest(r).Warn().Err(err).Msg("proxy: user unauthorized") - httputil.ErrorResponse(w, r, "Access unauthorized", http.StatusUnauthorized) + httpErr := &httputil.Error{Code: http.StatusUnauthorized, CanDebug: true} + httputil.ErrorResponse(w, r, httpErr) return } } @@ -237,7 +251,7 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { // We have validated the users request and now proxy their request to the provided upstream. route, ok := p.router(r) if !ok { - httputil.ErrorResponse(w, r, "unknown route to proxy", http.StatusNotFound) + httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusNotFound}) return } route.ServeHTTP(w, r) @@ -250,13 +264,15 @@ func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) { session, err := p.sessionStore.LoadSession(r) if err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: load session failed") - httputil.ErrorResponse(w, r, "", http.StatusBadRequest) + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } if err := p.authenticate(w, r, session); err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: authenticate failed") - httputil.ErrorResponse(w, r, "", http.StatusUnauthorized) + httpErr := &httputil.Error{Code: http.StatusUnauthorized, CanDebug: true} + httputil.ErrorResponse(w, r, httpErr) return } @@ -264,7 +280,8 @@ func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) { isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session) if err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: is admin client") - httputil.ErrorResponse(w, r, "", http.StatusInternalServerError) + httpErr := &httputil.Error{Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } @@ -273,7 +290,8 @@ func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) { csrfCookie, err := p.cipher.Marshal(csrf) if err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: failed to marshal csrf") - httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) + httpErr := &httputil.Error{Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } p.csrfStore.SetCSRF(w, r, csrfCookie) @@ -310,13 +328,15 @@ func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) { func (p *Proxy) Refresh(w http.ResponseWriter, r *http.Request) { session, err := p.sessionStore.LoadSession(r) if err != nil { - httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } iss, err := session.IssuedAt() if err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: couldn't get token's create time") - httputil.ErrorResponse(w, r, "", http.StatusInternalServerError) + httpErr := &httputil.Error{Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } @@ -324,20 +344,23 @@ func (p *Proxy) Refresh(w http.ResponseWriter, r *http.Request) { // trying to DOS the identity provider. if time.Since(iss) < p.refreshCooldown { log.FromRequest(r).Error().Dur("cooldown", p.refreshCooldown).Err(err).Msg("proxy: refresh cooldown") - httputil.ErrorResponse(w, r, - fmt.Sprintf("Session must be %v old before refresh", p.refreshCooldown), - http.StatusBadRequest) + httpErr := &httputil.Error{ + Message: fmt.Sprintf("Session must be %v old before refresh", p.refreshCooldown), + Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } newSession, err := p.AuthenticateClient.Refresh(r.Context(), session) if err != nil { log.FromRequest(r).Warn().Err(err).Msg("proxy: refresh failed") - httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } if err = p.sessionStore.SaveSession(w, r, newSession); err != nil { - httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } http.Redirect(w, r, "/.pomerium", http.StatusFound) @@ -350,27 +373,34 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { if err := r.ParseForm(); err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: impersonate form") - httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } session, err := p.sessionStore.LoadSession(r) if err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: load session") - httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } // authorization check -- is this user an admin? isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session) if err != nil || !isAdmin { log.FromRequest(r).Error().Err(err).Msg("proxy: user must be admin to impersonate") - httputil.ErrorResponse(w, r, "user must be admin to impersonate", http.StatusForbidden) + httpErr := &httputil.Error{ + Message: fmt.Sprintf("%s must be and administrator", session.Email), + Code: http.StatusForbidden, + CanDebug: true} + httputil.ErrorResponse(w, r, httpErr) return } // CSRF check -- did this request originate from our form? c, err := p.csrfStore.GetCSRF(r) if err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: failed parsing csrf cookie") - httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } p.csrfStore.ClearCSRF(w, r) @@ -378,12 +408,14 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) { decryptedCSRF := new(StateParameter) if err = p.cipher.Unmarshal(encryptedCSRF, decryptedCSRF); err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: couldn't unmarshal CSRF") - httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError) + httpErr := &httputil.Error{Message: "Internal error", Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } if decryptedCSRF.SessionID != r.FormValue("csrf") { log.FromRequest(r).Error().Err(err).Msg("proxy: impersonate CSRF mismatch") - httputil.ErrorResponse(w, r, "CSRF mismatch", http.StatusForbidden) + httpErr := &httputil.Error{Message: "CSRF mismatch", Code: http.StatusForbidden} + httputil.ErrorResponse(w, r, httpErr) return } @@ -393,7 +425,8 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) { if err := p.sessionStore.SaveSession(w, r, session); err != nil { log.FromRequest(r).Error().Err(err).Msg("proxy: save session") - httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) + httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusInternalServerError} + httputil.ErrorResponse(w, r, httpErr) return } } @@ -510,7 +543,8 @@ func websocketHandlerFunc(baseHandler http.Handler, timeoutHandler http.Handler, } log.FromRequest(r).Warn().Msg("proxy: attempt to proxy a websocket connection, but websocket support is disabled in the configuration") - httputil.ErrorResponse(w, r, "websockets not supported by proxy", http.StatusBadRequest) + httpErr := &httputil.Error{Message: "websockets not supported by proxy", Code: http.StatusBadRequest} + httputil.ErrorResponse(w, r, httpErr) return } diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index 2fbca421b..7545e93d9 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -451,6 +451,7 @@ func TestProxy_Impersonate(t *testing.T) { tests := []struct { name string + malformed bool options *config.Options method string email string @@ -463,14 +464,15 @@ func TestProxy_Impersonate(t *testing.T) { authorizer clients.Authorizer wantStatus int }{ - {"good", opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, - {"session load error", opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, - {"non admin users rejected", opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden}, - {"non admin users rejected on error", opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden}, - {"csrf from store retrieve failure", opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}, GetError: errors.New("err")}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusBadRequest}, - {"can't decrypt csrf value", opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{UnmarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, - {"decrypted csrf mismatch", opts, http.MethodPost, "user@blah.com", "", "CSRF!", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusForbidden}, - {"save session failure", opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, + {"good", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, + {"session load error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, + {"non admin users rejected", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden}, + {"non admin users rejected on error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden}, + {"csrf from store retrieve failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}, GetError: errors.New("err")}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusBadRequest}, + {"can't decrypt csrf value", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{UnmarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, + {"decrypted csrf mismatch", false, opts, http.MethodPost, "user@blah.com", "", "CSRF!", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusForbidden}, + {"save session failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, + {"malformed", true, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.SessionState{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthenticate{}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusBadRequest}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -487,8 +489,12 @@ func TestProxy_Impersonate(t *testing.T) { postForm.Add("email", tt.email) postForm.Add("group", tt.groups) postForm.Set("csrf", tt.csrf) + uri := &url.URL{Path: "/"} + if tt.malformed { + uri.RawQuery = "email=%zzzzz" + } + r := httptest.NewRequest(tt.method, uri.String(), bytes.NewBufferString(postForm.Encode())) - r := httptest.NewRequest(tt.method, "/", bytes.NewBufferString(postForm.Encode())) r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") w := httptest.NewRecorder() p.Impersonate(w, r)