mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 02:46:30 +02:00
authenticate: handle XHR redirect flow (#387)
- authenticate: add cors preflight check support for sign_in endpoint - internal/httputil: indicate responses that originate from pomerium vs the app - proxy: detect XHR requests and do not redirect on failure. - authenticate: removed default session duration; should be maintained out of band with rpc.
This commit is contained in:
parent
9030bd32cb
commit
00c29f4e77
11 changed files with 128 additions and 35 deletions
|
@ -10,8 +10,9 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/csrf"
|
"github.com/rs/cors"
|
||||||
|
|
||||||
|
"github.com/pomerium/csrf"
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
@ -51,6 +52,14 @@ func (a *Authenticate) Handler() http.Handler {
|
||||||
|
|
||||||
// Proxy service endpoints
|
// Proxy service endpoints
|
||||||
v := r.PathPrefix("/.pomerium").Subrouter()
|
v := r.PathPrefix("/.pomerium").Subrouter()
|
||||||
|
c := cors.New(cors.Options{
|
||||||
|
AllowOriginRequestFunc: func(r *http.Request, _ string) bool {
|
||||||
|
return middleware.ValidateRedirectURI(r, a.sharedKey)
|
||||||
|
},
|
||||||
|
AllowCredentials: true,
|
||||||
|
AllowedHeaders: []string{"*"},
|
||||||
|
})
|
||||||
|
v.Use(c.Handler)
|
||||||
v.Use(middleware.ValidateSignature(a.sharedKey))
|
v.Use(middleware.ValidateSignature(a.sharedKey))
|
||||||
v.Use(sessions.RetrieveSession(a.sessionLoaders...))
|
v.Use(sessions.RetrieveSession(a.sessionLoaders...))
|
||||||
v.Use(a.VerifySession)
|
v.Use(a.VerifySession)
|
||||||
|
@ -73,15 +82,15 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
||||||
if errors.Is(err, sessions.ErrExpired) {
|
if errors.Is(err, sessions.ErrExpired) {
|
||||||
if err := a.refresh(w, r, state); err != nil {
|
if err := a.refresh(w, r, state); err != nil {
|
||||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh")
|
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh")
|
||||||
a.redirectToIdentityProvider(w, r)
|
a.reauthenticateOrFail(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// redirect to restart middleware-chain following refresh
|
// redirect to restart middleware-chain following refresh
|
||||||
http.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound)
|
httputil.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound)
|
||||||
return
|
return
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session")
|
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session")
|
||||||
a.redirectToIdentityProvider(w, r)
|
a.reauthenticateOrFail(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
|
@ -167,7 +176,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
||||||
// build our hmac-d redirect URL with our session, pointing back to the
|
// build our hmac-d redirect URL with our session, pointing back to the
|
||||||
// proxy's callback URL which is responsible for setting our new route-session
|
// proxy's callback URL which is responsible for setting our new route-session
|
||||||
uri := urlutil.SignedRedirectURL(a.sharedKey, callbackURL, redirectURL)
|
uri := urlutil.SignedRedirectURL(a.sharedKey, callbackURL, redirectURL)
|
||||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
httputil.Redirect(w, r, uri.String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignOut signs the user out and attempts to revoke the user's identity session
|
// SignOut signs the user out and attempts to revoke the user's identity session
|
||||||
|
@ -189,16 +198,25 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// redirectToIdentityProvider starts the authenticate process by redirecting the
|
// reauthenticateOrFail starts the authenticate process by redirecting the
|
||||||
// user to their respective identity provider. This function also builds the
|
// user to their respective identity provider. This function also builds the
|
||||||
// 'state' parameter which is encrypted and includes authenticating data
|
// 'state' parameter which is encrypted and includes authenticating data
|
||||||
// for validation.
|
// for validation.
|
||||||
|
// If the request is a `xhr/ajax` request (e.g the `X-Requested-With` header)
|
||||||
|
// is set do not redirect but instead return 401 unauthorized.
|
||||||
|
//
|
||||||
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest
|
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest
|
||||||
// https://tools.ietf.org/html/rfc6749#section-4.2.1
|
// https://tools.ietf.org/html/rfc6749#section-4.2.1
|
||||||
func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
// https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest
|
||||||
|
func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
// If request AJAX/XHR request, return a 401 instead .
|
||||||
|
if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") {
|
||||||
|
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
a.sessionStore.ClearSession(w, r)
|
a.sessionStore.ClearSession(w, r)
|
||||||
redirectURL := a.RedirectURL.ResolveReference(r.URL)
|
redirectURL := a.RedirectURL.ResolveReference(r.URL)
|
||||||
nonce := csrf.Token(r)
|
nonce := csrf.Token(r)
|
||||||
|
@ -207,7 +225,7 @@ func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http
|
||||||
enc := cryptutil.Encrypt(a.cookieCipher, []byte(redirectURL.String()), b)
|
enc := cryptutil.Encrypt(a.cookieCipher, []byte(redirectURL.String()), b)
|
||||||
b = append(b, enc...)
|
b = append(b, enc...)
|
||||||
encodedState := base64.URLEncoding.EncodeToString(b)
|
encodedState := base64.URLEncoding.EncodeToString(b)
|
||||||
http.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound)
|
httputil.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuthCallback handles the callback from the identity provider.
|
// OAuthCallback handles the callback from the identity provider.
|
||||||
|
@ -220,7 +238,7 @@ func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
httputil.ErrorResponse(w, r, fmt.Errorf("oauth callback : %w", err))
|
httputil.ErrorResponse(w, r, fmt.Errorf("oauth callback : %w", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
http.Redirect(w, r, redirect.String(), http.StatusFound)
|
httputil.Redirect(w, r, redirect.String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
|
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
|
||||||
|
|
|
@ -69,6 +69,25 @@ func TestAuthenticate_Handler(t *testing.T) {
|
||||||
if body != expected {
|
if body != expected {
|
||||||
t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
|
t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cors preflight
|
||||||
|
req = httptest.NewRequest(http.MethodOptions, "/.pomerium/sign_in", nil)
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("Access-Control-Request-Method", "GET")
|
||||||
|
req.Header.Set("Access-Control-Request-Headers", "X-Requested-With")
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rr, req)
|
||||||
|
expected = fmt.Sprintf("User-agent: *\nDisallow: /")
|
||||||
|
code := rr.Code
|
||||||
|
if code != http.StatusOK {
|
||||||
|
t.Errorf("bad preflight code")
|
||||||
|
}
|
||||||
|
resp := rr.Result()
|
||||||
|
body = resp.Header.Get("vary")
|
||||||
|
if body == "" {
|
||||||
|
t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthenticate_SignIn(t *testing.T) {
|
func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
|
|
3
go.mod
3
go.mod
|
@ -18,7 +18,8 @@ require (
|
||||||
github.com/pomerium/go-oidc v2.0.0+incompatible
|
github.com/pomerium/go-oidc v2.0.0+incompatible
|
||||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||||
github.com/prometheus/client_golang v0.9.3
|
github.com/prometheus/client_golang v0.9.3
|
||||||
github.com/rs/zerolog v1.14.3
|
github.com/rs/cors v1.7.0
|
||||||
|
github.com/rs/zerolog v1.16.0
|
||||||
github.com/spf13/afero v1.2.2 // indirect
|
github.com/spf13/afero v1.2.2 // indirect
|
||||||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||||
github.com/spf13/pflag v1.0.5 // indirect
|
github.com/spf13/pflag v1.0.5 // indirect
|
||||||
|
|
5
go.sum
5
go.sum
|
@ -173,9 +173,13 @@ github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40T
|
||||||
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
||||||
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
|
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
|
||||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||||
|
github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik=
|
||||||
|
github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU=
|
||||||
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
|
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
|
||||||
github.com/rs/zerolog v1.14.3 h1:4EGfSkR2hJDB0s3oFfrlPqjU1e4WLncergLil3nEKW0=
|
github.com/rs/zerolog v1.14.3 h1:4EGfSkR2hJDB0s3oFfrlPqjU1e4WLncergLil3nEKW0=
|
||||||
github.com/rs/zerolog v1.14.3/go.mod h1:3WXPzbXEEliJ+a6UFE4vhIxV8qR1EML6ngzP9ug4eYg=
|
github.com/rs/zerolog v1.14.3/go.mod h1:3WXPzbXEEliJ+a6UFE4vhIxV8qR1EML6ngzP9ug4eYg=
|
||||||
|
github.com/rs/zerolog v1.16.0 h1:AaELmZdcJHT8m6oZ5py4213cdFK8XGXkB3dFdAQ+P7Q=
|
||||||
|
github.com/rs/zerolog v1.16.0/go.mod h1:9nvC1axdVrAHcu/s9taAVfBuIdTZLVQmKQyvrUjF5+I=
|
||||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||||
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
|
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
|
||||||
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
||||||
|
@ -309,6 +313,7 @@ golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgw
|
||||||
golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||||
golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||||
golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
|
golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
golang.org/x/tools v0.0.0-20191010171213-8abd42400456/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
golang.org/x/tools v0.0.0-20191010171213-8abd42400456/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
|
|
9
internal/httputil/constants.go
Normal file
9
internal/httputil/constants.go
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
|
||||||
|
const (
|
||||||
|
// HeaderPomeriumResponse is set when pomerium itself creates a response,
|
||||||
|
// as opposed to the downstream application and can be used to distinguish
|
||||||
|
// between an application error, and a pomerium related error when debugging.
|
||||||
|
// Especially useful when working with single page apps (SPA).
|
||||||
|
HeaderPomeriumResponse = "x-pomerium-intercepted-response"
|
||||||
|
)
|
|
@ -50,7 +50,7 @@ func (e *httpError) Debugable() bool {
|
||||||
// ErrorResponse renders an error page given an error. If the error is a
|
// ErrorResponse renders an error page given an error. If the error is a
|
||||||
// http error from this package, a user friendly message is set, http status code,
|
// http error from this package, a user friendly message is set, http status code,
|
||||||
// the ability to debug are also set.
|
// the ability to debug are also set.
|
||||||
func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) {
|
func ErrorResponse(w http.ResponseWriter, r *http.Request, e error) {
|
||||||
statusCode := http.StatusInternalServerError // default status code to return
|
statusCode := http.StatusInternalServerError // default status code to return
|
||||||
errorString := e.Error()
|
errorString := e.Error()
|
||||||
var canDebug bool
|
var canDebug bool
|
||||||
|
@ -63,6 +63,9 @@ func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) {
|
||||||
errorString = httpError.Message
|
errorString = httpError.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// indicate to clients that the error originates from Pomerium, not the app
|
||||||
|
w.Header().Set(HeaderPomeriumResponse, "true")
|
||||||
|
|
||||||
log.FromRequest(r).Error().Err(e).Str("http-message", errorString).Int("http-code", statusCode).Msg("http-error")
|
log.FromRequest(r).Error().Err(e).Str("http-message", errorString).Int("http-code", statusCode).Msg("http-error")
|
||||||
|
|
||||||
if id, ok := log.IDFromRequest(r); ok {
|
if id, ok := log.IDFromRequest(r); ok {
|
||||||
|
@ -73,9 +76,9 @@ func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) {
|
||||||
Error string `json:"error"`
|
Error string `json:"error"`
|
||||||
}
|
}
|
||||||
response.Error = errorString
|
response.Error = errorString
|
||||||
writeJSONResponse(rw, statusCode, response)
|
writeJSONResponse(w, statusCode, response)
|
||||||
} else {
|
} else {
|
||||||
rw.WriteHeader(statusCode)
|
w.WriteHeader(statusCode)
|
||||||
t := struct {
|
t := struct {
|
||||||
Code int
|
Code int
|
||||||
Title string
|
Title string
|
||||||
|
@ -89,17 +92,17 @@ func ErrorResponse(rw http.ResponseWriter, r *http.Request, e error) {
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
CanDebug: canDebug,
|
CanDebug: canDebug,
|
||||||
}
|
}
|
||||||
templates.New().ExecuteTemplate(rw, "error.html", t)
|
templates.New().ExecuteTemplate(w, "error.html", t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeJSONResponse is a helper that sets the application/json header and writes a response.
|
// writeJSONResponse is a helper that sets the application/json header and writes a response.
|
||||||
func writeJSONResponse(rw http.ResponseWriter, code int, response interface{}) {
|
func writeJSONResponse(w http.ResponseWriter, code int, response interface{}) {
|
||||||
rw.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
rw.WriteHeader(code)
|
w.WriteHeader(code)
|
||||||
|
|
||||||
err := json.NewEncoder(rw).Encode(response)
|
err := json.NewEncoder(w).Encode(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
io.WriteString(rw, err.Error())
|
io.WriteString(w, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,3 +17,10 @@ func HealthCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(http.StatusText(http.StatusOK)))
|
w.Write([]byte(http.StatusText(http.StatusOK)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Redirect wraps the std libs's redirect method indicating that pomerium is
|
||||||
|
// the origin of the response.
|
||||||
|
func Redirect(w http.ResponseWriter, r *http.Request, url string, code int) {
|
||||||
|
w.Header().Set(HeaderPomeriumResponse, "true")
|
||||||
|
http.Redirect(w, r, url, code)
|
||||||
|
}
|
||||||
|
|
|
@ -35,3 +35,34 @@ func TestHealthCheck(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRedirect(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
|
||||||
|
url string
|
||||||
|
code int
|
||||||
|
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{"good", http.MethodGet, "https://pomerium.io", http.StatusFound, http.StatusFound},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
r := httptest.NewRequest(tt.method, "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
Redirect(w, r, tt.url, tt.code)
|
||||||
|
if w.Code != tt.wantStatus {
|
||||||
|
t.Errorf("code differs. got %d want %d body: %s", w.Code, tt.wantStatus, w.Body.String())
|
||||||
|
}
|
||||||
|
if w.Result().Header.Get(HeaderPomeriumResponse) == "" {
|
||||||
|
t.Errorf("pomerium header not found")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -34,25 +34,25 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature")
|
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
if !ValidateRedirectURI(r, sharedSecret) {
|
||||||
err := r.ParseForm()
|
|
||||||
if err != nil {
|
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("couldn't parse form", http.StatusBadRequest, err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
redirectURI := r.Form.Get("redirect_uri")
|
|
||||||
sigVal := r.Form.Get("sig")
|
|
||||||
timestamp := r.Form.Get("ts")
|
|
||||||
if !ValidSignature(redirectURI, sigVal, timestamp, sharedSecret) {
|
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, nil))
|
httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, nil))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateRedirectURI takes a request and parses `redirect_uri`, `sig`, `ts`
|
||||||
|
// and validates the supplied signature (`sig`)'s HMAC for validity.
|
||||||
|
func ValidateRedirectURI(r *http.Request, key string) bool {
|
||||||
|
return ValidSignature(
|
||||||
|
r.FormValue("redirect_uri"),
|
||||||
|
r.FormValue("sig"),
|
||||||
|
r.FormValue("ts"),
|
||||||
|
key)
|
||||||
|
}
|
||||||
|
|
||||||
// Healthcheck endpoint middleware useful to setting up a path like
|
// Healthcheck endpoint middleware useful to setting up a path like
|
||||||
// `/ping` that load balancers or uptime testing external services
|
// `/ping` that load balancers or uptime testing external services
|
||||||
// can make a request before hitting any routes. It's also convenient
|
// can make a request before hitting any routes. It's also convenient
|
||||||
|
|
|
@ -67,7 +67,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL)
|
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL)
|
||||||
p.sessionStore.ClearSession(w, r)
|
p.sessionStore.ClearSession(w, r)
|
||||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
httputil.Redirect(w, r, uri.String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserDashboard lets users investigate, and refresh their current session.
|
// UserDashboard lets users investigate, and refresh their current session.
|
||||||
|
@ -117,7 +117,7 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
|
||||||
q.Add("impersonate_group", r.FormValue("group"))
|
q.Add("impersonate_group", r.FormValue("group"))
|
||||||
redirectURL.RawQuery = q.Encode()
|
redirectURL.RawQuery = q.Encode()
|
||||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, redirectURL).String()
|
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, redirectURL).String()
|
||||||
http.Redirect(w, r, uri, http.StatusFound)
|
httputil.Redirect(w, r, uri, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) registerFwdAuthHandlers() http.Handler {
|
func (p *Proxy) registerFwdAuthHandlers() http.Handler {
|
||||||
|
@ -198,7 +198,7 @@ func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
redirectURL.RawQuery = q.Encode()
|
redirectURL.RawQuery = q.Encode()
|
||||||
|
|
||||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProgrammaticLogin returns a signed url that can be used to login
|
// ProgrammaticLogin returns a signed url that can be used to login
|
||||||
|
|
|
@ -51,7 +51,7 @@ func (p *Proxy) authenticate(errOnFailure bool, w http.ResponseWriter, r *http.R
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
|
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
|
||||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
httputil.Redirect(w, r, uri.String(), http.StatusFound)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// add pomerium's headers to the downstream request
|
// add pomerium's headers to the downstream request
|
||||||
|
|
Loading…
Add table
Reference in a new issue