mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
authenticate: handle XHR redirect flow (#387)
- authenticate: add cors preflight check support for sign_in endpoint - internal/httputil: indicate responses that originate from pomerium vs the app - proxy: detect XHR requests and do not redirect on failure. - authenticate: removed default session duration; should be maintained out of band with rpc.
This commit is contained in:
parent
9030bd32cb
commit
00c29f4e77
11 changed files with 128 additions and 35 deletions
|
@ -10,8 +10,9 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/csrf"
|
||||
"github.com/rs/cors"
|
||||
|
||||
"github.com/pomerium/csrf"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
@ -51,6 +52,14 @@ func (a *Authenticate) Handler() http.Handler {
|
|||
|
||||
// Proxy service endpoints
|
||||
v := r.PathPrefix("/.pomerium").Subrouter()
|
||||
c := cors.New(cors.Options{
|
||||
AllowOriginRequestFunc: func(r *http.Request, _ string) bool {
|
||||
return middleware.ValidateRedirectURI(r, a.sharedKey)
|
||||
},
|
||||
AllowCredentials: true,
|
||||
AllowedHeaders: []string{"*"},
|
||||
})
|
||||
v.Use(c.Handler)
|
||||
v.Use(middleware.ValidateSignature(a.sharedKey))
|
||||
v.Use(sessions.RetrieveSession(a.sessionLoaders...))
|
||||
v.Use(a.VerifySession)
|
||||
|
@ -73,15 +82,15 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
|||
if errors.Is(err, sessions.ErrExpired) {
|
||||
if err := a.refresh(w, r, state); err != nil {
|
||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh")
|
||||
a.redirectToIdentityProvider(w, r)
|
||||
a.reauthenticateOrFail(w, r, err)
|
||||
return
|
||||
}
|
||||
// redirect to restart middleware-chain following refresh
|
||||
http.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound)
|
||||
httputil.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound)
|
||||
return
|
||||
} else if err != nil {
|
||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session")
|
||||
a.redirectToIdentityProvider(w, r)
|
||||
a.reauthenticateOrFail(w, r, err)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
|
@ -167,7 +176,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
|||
// build our hmac-d redirect URL with our session, pointing back to the
|
||||
// proxy's callback URL which is responsible for setting our new route-session
|
||||
uri := urlutil.SignedRedirectURL(a.sharedKey, callbackURL, redirectURL)
|
||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
httputil.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// SignOut signs the user out and attempts to revoke the user's identity session
|
||||
|
@ -189,16 +198,25 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
|||
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// redirectToIdentityProvider starts the authenticate process by redirecting the
|
||||
// reauthenticateOrFail starts the authenticate process by redirecting the
|
||||
// user to their respective identity provider. This function also builds the
|
||||
// 'state' parameter which is encrypted and includes authenticating data
|
||||
// for validation.
|
||||
// If the request is a `xhr/ajax` request (e.g the `X-Requested-With` header)
|
||||
// is set do not redirect but instead return 401 unauthorized.
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest
|
||||
// https://tools.ietf.org/html/rfc6749#section-4.2.1
|
||||
func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
// https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest
|
||||
func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) {
|
||||
// If request AJAX/XHR request, return a 401 instead .
|
||||
if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") {
|
||||
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
|
||||
return
|
||||
}
|
||||
a.sessionStore.ClearSession(w, r)
|
||||
redirectURL := a.RedirectURL.ResolveReference(r.URL)
|
||||
nonce := csrf.Token(r)
|
||||
|
@ -207,7 +225,7 @@ func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http
|
|||
enc := cryptutil.Encrypt(a.cookieCipher, []byte(redirectURL.String()), b)
|
||||
b = append(b, enc...)
|
||||
encodedState := base64.URLEncoding.EncodeToString(b)
|
||||
http.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound)
|
||||
httputil.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound)
|
||||
}
|
||||
|
||||
// OAuthCallback handles the callback from the identity provider.
|
||||
|
@ -220,7 +238,7 @@ func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
|||
httputil.ErrorResponse(w, r, fmt.Errorf("oauth callback : %w", err))
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, redirect.String(), http.StatusFound)
|
||||
httputil.Redirect(w, r, redirect.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
|
||||
|
|
|
@ -69,6 +69,25 @@ func TestAuthenticate_Handler(t *testing.T) {
|
|||
if body != expected {
|
||||
t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
|
||||
}
|
||||
|
||||
// cors preflight
|
||||
req = httptest.NewRequest(http.MethodOptions, "/.pomerium/sign_in", nil)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Access-Control-Request-Method", "GET")
|
||||
req.Header.Set("Access-Control-Request-Headers", "X-Requested-With")
|
||||
rr = httptest.NewRecorder()
|
||||
h.ServeHTTP(rr, req)
|
||||
expected = fmt.Sprintf("User-agent: *\nDisallow: /")
|
||||
code := rr.Code
|
||||
if code != http.StatusOK {
|
||||
t.Errorf("bad preflight code")
|
||||
}
|
||||
resp := rr.Result()
|
||||
body = resp.Header.Get("vary")
|
||||
if body == "" {
|
||||
t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAuthenticate_SignIn(t *testing.T) {
|
||||
|
|
3
go.mod
3
go.mod
|
@ -18,7 +18,8 @@ require (
|
|||
github.com/pomerium/go-oidc v2.0.0+incompatible
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||
github.com/prometheus/client_golang v0.9.3
|
||||
github.com/rs/zerolog v1.14.3
|
||||
github.com/rs/cors v1.7.0
|
||||
github.com/rs/zerolog v1.16.0
|
||||
github.com/spf13/afero v1.2.2 // indirect
|
||||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
|
|
5
go.sum
5
go.sum
|
@ -173,9 +173,13 @@ github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40T
|
|||
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
||||
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik=
|
||||
github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU=
|
||||
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
|
||||
github.com/rs/zerolog v1.14.3 h1:4EGfSkR2hJDB0s3oFfrlPqjU1e4WLncergLil3nEKW0=
|
||||
github.com/rs/zerolog v1.14.3/go.mod h1:3WXPzbXEEliJ+a6UFE4vhIxV8qR1EML6ngzP9ug4eYg=
|
||||
github.com/rs/zerolog v1.16.0 h1:AaELmZdcJHT8m6oZ5py4213cdFK8XGXkB3dFdAQ+P7Q=
|
||||
github.com/rs/zerolog v1.16.0/go.mod h1:9nvC1axdVrAHcu/s9taAVfBuIdTZLVQmKQyvrUjF5+I=
|
||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
|
||||
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
||||
|
@ -309,6 +313,7 @@ golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgw
|
|||
golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||
golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||
golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20191010171213-8abd42400456/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
|
|
9
internal/httputil/constants.go
Normal file
9
internal/httputil/constants.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
||||
|
||||
const (
|
||||
// HeaderPomeriumResponse is set when pomerium itself creates a response,
|
||||
// as opposed to the downstream application and can be used to distinguish
|
||||
// between an application error, and a pomerium related error when debugging.
|
||||
// Especially useful when working with single page apps (SPA).
|
||||
HeaderPomeriumResponse = "x-pomerium-intercepted-response"
|
||||
)
|
|
@ -50,7 +50,7 @@ func (e *httpError) Debugable() bool {
|
|||
// ErrorResponse renders an error page given an error. If the error is a
|
||||
// http error from this package, a user friendly message is set, http status code,
|
||||
// the ability to debug are also set.
|
||||
func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) {
|
||||
func ErrorResponse(w http.ResponseWriter, r *http.Request, e error) {
|
||||
statusCode := http.StatusInternalServerError // default status code to return
|
||||
errorString := e.Error()
|
||||
var canDebug bool
|
||||
|
@ -63,6 +63,9 @@ func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) {
|
|||
errorString = httpError.Message
|
||||
}
|
||||
|
||||
// indicate to clients that the error originates from Pomerium, not the app
|
||||
w.Header().Set(HeaderPomeriumResponse, "true")
|
||||
|
||||
log.FromRequest(r).Error().Err(e).Str("http-message", errorString).Int("http-code", statusCode).Msg("http-error")
|
||||
|
||||
if id, ok := log.IDFromRequest(r); ok {
|
||||
|
@ -73,9 +76,9 @@ func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) {
|
|||
Error string `json:"error"`
|
||||
}
|
||||
response.Error = errorString
|
||||
writeJSONResponse(rw, statusCode, response)
|
||||
writeJSONResponse(w, statusCode, response)
|
||||
} else {
|
||||
rw.WriteHeader(statusCode)
|
||||
w.WriteHeader(statusCode)
|
||||
t := struct {
|
||||
Code int
|
||||
Title string
|
||||
|
@ -89,17 +92,17 @@ func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) {
|
|||
RequestID: requestID,
|
||||
CanDebug: canDebug,
|
||||
}
|
||||
templates.New().ExecuteTemplate(rw, "error.html", t)
|
||||
templates.New().ExecuteTemplate(w, "error.html", t)
|
||||
}
|
||||
}
|
||||
|
||||
// writeJSONResponse is a helper that sets the application/json header and writes a response.
|
||||
func writeJSONResponse(rw http.ResponseWriter, code int, response interface{}) {
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(code)
|
||||
func writeJSONResponse(w http.ResponseWriter, code int, response interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
|
||||
err := json.NewEncoder(rw).Encode(response)
|
||||
err := json.NewEncoder(w).Encode(response)
|
||||
if err != nil {
|
||||
io.WriteString(rw, err.Error())
|
||||
io.WriteString(w, err.Error())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,3 +17,10 @@ func HealthCheck(w http.ResponseWriter, r *http.Request) {
|
|||
w.Write([]byte(http.StatusText(http.StatusOK)))
|
||||
}
|
||||
}
|
||||
|
||||
// Redirect wraps the std libs's redirect method indicating that pomerium is
|
||||
// the origin of the response.
|
||||
func Redirect(w http.ResponseWriter, r *http.Request, url string, code int) {
|
||||
w.Header().Set(HeaderPomeriumResponse, "true")
|
||||
http.Redirect(w, r, url, code)
|
||||
}
|
||||
|
|
|
@ -35,3 +35,34 @@ func TestHealthCheck(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirect(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
|
||||
url string
|
||||
code int
|
||||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", http.MethodGet, "https://pomerium.io", http.StatusFound, http.StatusFound},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
r := httptest.NewRequest(tt.method, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
Redirect(w, r, tt.url, tt.code)
|
||||
if w.Code != tt.wantStatus {
|
||||
t.Errorf("code differs. got %d want %d body: %s", w.Code, tt.wantStatus, w.Body.String())
|
||||
}
|
||||
if w.Result().Header.Get(HeaderPomeriumResponse) == "" {
|
||||
t.Errorf("pomerium header not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,25 +34,25 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler
|
|||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature")
|
||||
defer span.End()
|
||||
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("couldn't parse form", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
redirectURI := r.Form.Get("redirect_uri")
|
||||
sigVal := r.Form.Get("sig")
|
||||
timestamp := r.Form.Get("ts")
|
||||
if !ValidSignature(redirectURI, sigVal, timestamp, sharedSecret) {
|
||||
if !ValidateRedirectURI(r, sharedSecret) {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, nil))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateRedirectURI takes a request and parses `redirect_uri`, `sig`, `ts`
|
||||
// and validates the supplied signature (`sig`)'s HMAC for validity.
|
||||
func ValidateRedirectURI(r *http.Request, key string) bool {
|
||||
return ValidSignature(
|
||||
r.FormValue("redirect_uri"),
|
||||
r.FormValue("sig"),
|
||||
r.FormValue("ts"),
|
||||
key)
|
||||
}
|
||||
|
||||
// Healthcheck endpoint middleware useful to setting up a path like
|
||||
// `/ping` that load balancers or uptime testing external services
|
||||
// can make a request before hitting any routes. It's also convenient
|
||||
|
|
|
@ -67,7 +67,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL)
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
httputil.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// UserDashboard lets users investigate, and refresh their current session.
|
||||
|
@ -117,7 +117,7 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
|
|||
q.Add("impersonate_group", r.FormValue("group"))
|
||||
redirectURL.RawQuery = q.Encode()
|
||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, redirectURL).String()
|
||||
http.Redirect(w, r, uri, http.StatusFound)
|
||||
httputil.Redirect(w, r, uri, http.StatusFound)
|
||||
}
|
||||
|
||||
func (p *Proxy) registerFwdAuthHandlers() http.Handler {
|
||||
|
@ -198,7 +198,7 @@ func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
redirectURL.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// ProgrammaticLogin returns a signed url that can be used to login
|
||||
|
|
|
@ -51,7 +51,7 @@ func (p *Proxy) authenticate(errOnFailure bool, w http.ResponseWriter, r *http.R
|
|||
return err
|
||||
}
|
||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
|
||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
httputil.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
return err
|
||||
}
|
||||
// add pomerium's headers to the downstream request
|
||||
|
|
Loading…
Add table
Reference in a new issue