From 00c29f4e77e36cb6f44c028361785b399aabb616 Mon Sep 17 00:00:00 2001 From: Bobby DeSimone Date: Thu, 14 Nov 2019 19:37:31 -0800 Subject: [PATCH] authenticate: handle XHR redirect flow (#387) - authenticate: add cors preflight check support for sign_in endpoint - internal/httputil: indicate responses that originate from pomerium vs the app - proxy: detect XHR requests and do not redirect on failure. - authenticate: removed default session duration; should be maintained out of band with rpc. --- authenticate/handlers.go | 38 ++++++++++++++++++++++-------- authenticate/handlers_test.go | 19 +++++++++++++++ go.mod | 3 ++- go.sum | 5 ++++ internal/httputil/constants.go | 9 +++++++ internal/httputil/errors.go | 21 ++++++++++------- internal/httputil/handlers.go | 7 ++++++ internal/httputil/handlers_test.go | 31 ++++++++++++++++++++++++ internal/middleware/middleware.go | 22 ++++++++--------- proxy/handlers.go | 6 ++--- proxy/middleware.go | 2 +- 11 files changed, 128 insertions(+), 35 deletions(-) create mode 100644 internal/httputil/constants.go diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 4506aa4a0..acfe51190 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -10,8 +10,9 @@ import ( "strings" "time" - "github.com/pomerium/csrf" + "github.com/rs/cors" + "github.com/pomerium/csrf" "github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" @@ -51,6 +52,14 @@ func (a *Authenticate) Handler() http.Handler { // Proxy service endpoints v := r.PathPrefix("/.pomerium").Subrouter() + c := cors.New(cors.Options{ + AllowOriginRequestFunc: func(r *http.Request, _ string) bool { + return middleware.ValidateRedirectURI(r, a.sharedKey) + }, + AllowCredentials: true, + AllowedHeaders: []string{"*"}, + }) + v.Use(c.Handler) v.Use(middleware.ValidateSignature(a.sharedKey)) v.Use(sessions.RetrieveSession(a.sessionLoaders...)) v.Use(a.VerifySession) @@ -73,15 +82,15 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { if errors.Is(err, sessions.ErrExpired) { if err := a.refresh(w, r, state); err != nil { log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh") - a.redirectToIdentityProvider(w, r) + a.reauthenticateOrFail(w, r, err) return } // redirect to restart middleware-chain following refresh - http.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound) + httputil.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound) return } else if err != nil { log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session") - a.redirectToIdentityProvider(w, r) + a.reauthenticateOrFail(w, r, err) return } next.ServeHTTP(w, r) @@ -167,7 +176,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) { // build our hmac-d redirect URL with our session, pointing back to the // proxy's callback URL which is responsible for setting our new route-session uri := urlutil.SignedRedirectURL(a.sharedKey, callbackURL, redirectURL) - http.Redirect(w, r, uri.String(), http.StatusFound) + httputil.Redirect(w, r, uri.String(), http.StatusFound) } // SignOut signs the user out and attempts to revoke the user's identity session @@ -189,16 +198,25 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err)) return } - http.Redirect(w, r, redirectURL.String(), http.StatusFound) + httputil.Redirect(w, r, redirectURL.String(), http.StatusFound) } -// redirectToIdentityProvider starts the authenticate process by redirecting the +// reauthenticateOrFail starts the authenticate process by redirecting the // user to their respective identity provider. This function also builds the // 'state' parameter which is encrypted and includes authenticating data // for validation. +// If the request is a `xhr/ajax` request (e.g the `X-Requested-With` header) +// is set do not redirect but instead return 401 unauthorized. +// // https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest // https://tools.ietf.org/html/rfc6749#section-4.2.1 -func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http.Request) { +// https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest +func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) { + // If request AJAX/XHR request, return a 401 instead . + if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") { + httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err)) + return + } a.sessionStore.ClearSession(w, r) redirectURL := a.RedirectURL.ResolveReference(r.URL) nonce := csrf.Token(r) @@ -207,7 +225,7 @@ func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http enc := cryptutil.Encrypt(a.cookieCipher, []byte(redirectURL.String()), b) b = append(b, enc...) encodedState := base64.URLEncoding.EncodeToString(b) - http.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound) + httputil.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound) } // OAuthCallback handles the callback from the identity provider. @@ -220,7 +238,7 @@ func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) { httputil.ErrorResponse(w, r, fmt.Errorf("oauth callback : %w", err)) return } - http.Redirect(w, r, redirect.String(), http.StatusFound) + httputil.Redirect(w, r, redirect.String(), http.StatusFound) } func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) { diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 38485569b..c856b4f99 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -69,6 +69,25 @@ func TestAuthenticate_Handler(t *testing.T) { if body != expected { t.Errorf("handler returned unexpected body: got %v want %v", body, expected) } + + // cors preflight + req = httptest.NewRequest(http.MethodOptions, "/.pomerium/sign_in", nil) + req.Header.Set("Accept", "application/json") + req.Header.Set("Access-Control-Request-Method", "GET") + req.Header.Set("Access-Control-Request-Headers", "X-Requested-With") + rr = httptest.NewRecorder() + h.ServeHTTP(rr, req) + expected = fmt.Sprintf("User-agent: *\nDisallow: /") + code := rr.Code + if code != http.StatusOK { + t.Errorf("bad preflight code") + } + resp := rr.Result() + body = resp.Header.Get("vary") + if body == "" { + t.Errorf("handler returned unexpected body: got %v want %v", body, expected) + } + } func TestAuthenticate_SignIn(t *testing.T) { diff --git a/go.mod b/go.mod index 33006ee69..a9d813c94 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,8 @@ require ( github.com/pomerium/go-oidc v2.0.0+incompatible github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/prometheus/client_golang v0.9.3 - github.com/rs/zerolog v1.14.3 + github.com/rs/cors v1.7.0 + github.com/rs/zerolog v1.16.0 github.com/spf13/afero v1.2.2 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect diff --git a/go.sum b/go.sum index 05ad589c1..956d38bd8 100644 --- a/go.sum +++ b/go.sum @@ -173,9 +173,13 @@ github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40T github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik= +github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.14.3 h1:4EGfSkR2hJDB0s3oFfrlPqjU1e4WLncergLil3nEKW0= github.com/rs/zerolog v1.14.3/go.mod h1:3WXPzbXEEliJ+a6UFE4vhIxV8qR1EML6ngzP9ug4eYg= +github.com/rs/zerolog v1.16.0 h1:AaELmZdcJHT8m6oZ5py4213cdFK8XGXkB3dFdAQ+P7Q= +github.com/rs/zerolog v1.16.0/go.mod h1:9nvC1axdVrAHcu/s9taAVfBuIdTZLVQmKQyvrUjF5+I= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= @@ -309,6 +313,7 @@ golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgw golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191010171213-8abd42400456/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= diff --git a/internal/httputil/constants.go b/internal/httputil/constants.go new file mode 100644 index 000000000..fa92b4adc --- /dev/null +++ b/internal/httputil/constants.go @@ -0,0 +1,9 @@ +package httputil // import "github.com/pomerium/pomerium/internal/httputil" + +const ( + // HeaderPomeriumResponse is set when pomerium itself creates a response, + // as opposed to the downstream application and can be used to distinguish + // between an application error, and a pomerium related error when debugging. + // Especially useful when working with single page apps (SPA). + HeaderPomeriumResponse = "x-pomerium-intercepted-response" +) diff --git a/internal/httputil/errors.go b/internal/httputil/errors.go index 27c17f414..cc4a9547b 100644 --- a/internal/httputil/errors.go +++ b/internal/httputil/errors.go @@ -50,7 +50,7 @@ func (e *httpError) Debugable() bool { // ErrorResponse renders an error page given an error. If the error is a // http error from this package, a user friendly message is set, http status code, // the ability to debug are also set. -func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) { +func ErrorResponse(w http.ResponseWriter, r *http.Request, e error) { statusCode := http.StatusInternalServerError // default status code to return errorString := e.Error() var canDebug bool @@ -63,6 +63,9 @@ func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) { errorString = httpError.Message } + // indicate to clients that the error originates from Pomerium, not the app + w.Header().Set(HeaderPomeriumResponse, "true") + log.FromRequest(r).Error().Err(e).Str("http-message", errorString).Int("http-code", statusCode).Msg("http-error") if id, ok := log.IDFromRequest(r); ok { @@ -73,9 +76,9 @@ func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) { Error string `json:"error"` } response.Error = errorString - writeJSONResponse(rw, statusCode, response) + writeJSONResponse(w, statusCode, response) } else { - rw.WriteHeader(statusCode) + w.WriteHeader(statusCode) t := struct { Code int Title string @@ -89,17 +92,17 @@ func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) { RequestID: requestID, CanDebug: canDebug, } - templates.New().ExecuteTemplate(rw, "error.html", t) + templates.New().ExecuteTemplate(w, "error.html", t) } } // writeJSONResponse is a helper that sets the application/json header and writes a response. -func writeJSONResponse(rw http.ResponseWriter, code int, response interface{}) { - rw.Header().Set("Content-Type", "application/json") - rw.WriteHeader(code) +func writeJSONResponse(w http.ResponseWriter, code int, response interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) - err := json.NewEncoder(rw).Encode(response) + err := json.NewEncoder(w).Encode(response) if err != nil { - io.WriteString(rw, err.Error()) + io.WriteString(w, err.Error()) } } diff --git a/internal/httputil/handlers.go b/internal/httputil/handlers.go index da66743d7..48fd73463 100644 --- a/internal/httputil/handlers.go +++ b/internal/httputil/handlers.go @@ -17,3 +17,10 @@ func HealthCheck(w http.ResponseWriter, r *http.Request) { w.Write([]byte(http.StatusText(http.StatusOK))) } } + +// Redirect wraps the std libs's redirect method indicating that pomerium is +// the origin of the response. +func Redirect(w http.ResponseWriter, r *http.Request, url string, code int) { + w.Header().Set(HeaderPomeriumResponse, "true") + http.Redirect(w, r, url, code) +} diff --git a/internal/httputil/handlers_test.go b/internal/httputil/handlers_test.go index b2caead62..87e794859 100644 --- a/internal/httputil/handlers_test.go +++ b/internal/httputil/handlers_test.go @@ -35,3 +35,34 @@ func TestHealthCheck(t *testing.T) { }) } } + +func TestRedirect(t *testing.T) { + t.Parallel() + tests := []struct { + name string + method string + + url string + code int + + wantStatus int + }{ + {"good", http.MethodGet, "https://pomerium.io", http.StatusFound, http.StatusFound}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + r := httptest.NewRequest(tt.method, "/", nil) + w := httptest.NewRecorder() + + Redirect(w, r, tt.url, tt.code) + if w.Code != tt.wantStatus { + t.Errorf("code differs. got %d want %d body: %s", w.Code, tt.wantStatus, w.Body.String()) + } + if w.Result().Header.Get(HeaderPomeriumResponse) == "" { + t.Errorf("pomerium header not found") + } + }) + } +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 44c1254b5..d6dcf8374 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -34,25 +34,25 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature") defer span.End() - - err := r.ParseForm() - if err != nil { - httputil.ErrorResponse(w, r, httputil.Error("couldn't parse form", http.StatusBadRequest, err)) - return - } - redirectURI := r.Form.Get("redirect_uri") - sigVal := r.Form.Get("sig") - timestamp := r.Form.Get("ts") - if !ValidSignature(redirectURI, sigVal, timestamp, sharedSecret) { + if !ValidateRedirectURI(r, sharedSecret) { httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, nil)) return } - next.ServeHTTP(w, r.WithContext(ctx)) }) } } +// ValidateRedirectURI takes a request and parses `redirect_uri`, `sig`, `ts` +// and validates the supplied signature (`sig`)'s HMAC for validity. +func ValidateRedirectURI(r *http.Request, key string) bool { + return ValidSignature( + r.FormValue("redirect_uri"), + r.FormValue("sig"), + r.FormValue("ts"), + key) +} + // Healthcheck endpoint middleware useful to setting up a path like // `/ping` that load balancers or uptime testing external services // can make a request before hitting any routes. It's also convenient diff --git a/proxy/handlers.go b/proxy/handlers.go index e37eca584..1fa296a5d 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -67,7 +67,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) { } uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL) p.sessionStore.ClearSession(w, r) - http.Redirect(w, r, uri.String(), http.StatusFound) + httputil.Redirect(w, r, uri.String(), http.StatusFound) } // UserDashboard lets users investigate, and refresh their current session. @@ -117,7 +117,7 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) { q.Add("impersonate_group", r.FormValue("group")) redirectURL.RawQuery = q.Encode() uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, redirectURL).String() - http.Redirect(w, r, uri, http.StatusFound) + httputil.Redirect(w, r, uri, http.StatusFound) } func (p *Proxy) registerFwdAuthHandlers() http.Handler { @@ -198,7 +198,7 @@ func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) { } redirectURL.RawQuery = q.Encode() - http.Redirect(w, r, redirectURL.String(), http.StatusFound) + httputil.Redirect(w, r, redirectURL.String(), http.StatusFound) } // ProgrammaticLogin returns a signed url that can be used to login diff --git a/proxy/middleware.go b/proxy/middleware.go index 365765d92..6b569cb9a 100644 --- a/proxy/middleware.go +++ b/proxy/middleware.go @@ -51,7 +51,7 @@ func (p *Proxy) authenticate(errOnFailure bool, w http.ResponseWriter, r *http.R return err } uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r)) - http.Redirect(w, r, uri.String(), http.StatusFound) + httputil.Redirect(w, r, uri.String(), http.StatusFound) return err } // add pomerium's headers to the downstream request