httputil : wrap handlers for additional context (#413)

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2019-12-06 11:07:45 -08:00 committed by GitHub
parent 487fc655d6
commit b3d3159185
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 495 additions and 463 deletions

View file

@ -32,12 +32,12 @@ func (a *Authenticate) Handler() http.Handler {
csrf.UnsafePaths([]string{callbackPath}), // enforce CSRF on "safe" handler csrf.UnsafePaths([]string{callbackPath}), // enforce CSRF on "safe" handler
csrf.FormValueName("state"), // rfc6749 section-10.12 csrf.FormValueName("state"), // rfc6749 section-10.12
csrf.CookieName(fmt.Sprintf("%s_csrf", a.cookieOptions.Name)), csrf.CookieName(fmt.Sprintf("%s_csrf", a.cookieOptions.Name)),
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)), csrf.ErrorHandler(httputil.HandlerFunc(httputil.CSRFFailureHandler)),
)) ))
r.HandleFunc("/robots.txt", a.RobotsTxt).Methods(http.MethodGet) r.Path("/robots.txt").HandlerFunc(a.RobotsTxt).Methods(http.MethodGet)
// Identity Provider (IdP) endpoints // Identity Provider (IdP) endpoints
r.HandleFunc("/oauth2/callback", a.OAuthCallback).Methods(http.MethodGet) r.Path("/oauth2/callback").Handler(httputil.HandlerFunc(a.OAuthCallback)).Methods(http.MethodGet)
// Proxy service endpoints // Proxy service endpoints
v := r.PathPrefix("/.pomerium").Subrouter() v := r.PathPrefix("/.pomerium").Subrouter()
@ -56,13 +56,13 @@ func (a *Authenticate) Handler() http.Handler {
v.Use(middleware.ValidateSignature(a.sharedKey)) v.Use(middleware.ValidateSignature(a.sharedKey))
v.Use(sessions.RetrieveSession(a.sessionLoaders...)) v.Use(sessions.RetrieveSession(a.sessionLoaders...))
v.Use(a.VerifySession) v.Use(a.VerifySession)
v.HandleFunc("/sign_in", a.SignIn) v.Path("/sign_in").Handler(httputil.HandlerFunc(a.SignIn))
v.HandleFunc("/sign_out", a.SignOut) v.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut))
// programmatic access api endpoint // programmatic access api endpoint
api := r.PathPrefix("/api").Subrouter() api := r.PathPrefix("/api").Subrouter()
api.Use(sessions.RetrieveSession(a.sessionLoaders...)) api.Use(sessions.RetrieveSession(a.sessionLoaders...))
api.HandleFunc("/v1/refresh", a.RefreshAPI) api.Path("/v1/refresh").Handler(httputil.HandlerFunc(a.RefreshAPI))
return r return r
} }
@ -70,23 +70,22 @@ func (a *Authenticate) Handler() http.Handler {
// VerifySession is the middleware used to enforce a valid authentication // VerifySession is the middleware used to enforce a valid authentication
// session state is attached to the users's request context. // session state is attached to the users's request context.
func (a *Authenticate) VerifySession(next http.Handler) http.Handler { func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
state, err := sessions.FromContext(r.Context()) state, err := sessions.FromContext(r.Context())
if errors.Is(err, sessions.ErrExpired) { if errors.Is(err, sessions.ErrExpired) {
if err := a.refresh(w, r, state); err != nil { if err := a.refresh(w, r, state); err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh") log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh")
a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
return
} }
// redirect to restart middleware-chain following refresh // redirect to restart middleware-chain following refresh
httputil.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound) httputil.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound)
return return nil
} else if err != nil { } else if err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session") log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session")
a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
return
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return nil
}) })
} }
@ -109,11 +108,10 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
} }
// SignIn handles to authenticating a user. // SignIn handles to authenticating a user.
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
jwtAudience := []string{a.RedirectURL.Hostname(), redirectURL.Hostname()} jwtAudience := []string{a.RedirectURL.Hostname(), redirectURL.Hostname()}
@ -123,8 +121,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
if callbackStr := r.FormValue(urlutil.QueryCallbackURI); callbackStr != "" { if callbackStr := r.FormValue(urlutil.QueryCallbackURI); callbackStr != "" {
callbackURL, err = urlutil.ParseAndValidateURL(callbackStr) callbackURL, err = urlutil.ParseAndValidateURL(callbackStr)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
jwtAudience = append(jwtAudience, callbackURL.Hostname()) jwtAudience = append(jwtAudience, callbackURL.Hostname())
} else { } else {
@ -141,16 +138,14 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
s, err := sessions.FromContext(r.Context()) s, err := sessions.FromContext(r.Context())
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
// user impersonation // user impersonation
if impersonate := r.FormValue(urlutil.QueryImpersonateAction); impersonate != "" { if impersonate := r.FormValue(urlutil.QueryImpersonateAction); impersonate != "" {
s.SetImpersonation(r.FormValue(urlutil.QueryImpersonateEmail), r.FormValue(urlutil.QueryImpersonateGroups)) s.SetImpersonation(r.FormValue(urlutil.QueryImpersonateEmail), r.FormValue(urlutil.QueryImpersonateGroups))
if err := a.sessionStore.SaveSession(w, r, s); err != nil { if err := a.sessionStore.SaveSession(w, r, s); err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
} }
@ -162,8 +157,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
newSession.Programmatic = true newSession.Programmatic = true
encSession, err := a.encryptedEncoder.Marshal(newSession) encSession, err := a.encryptedEncoder.Marshal(newSession)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
callbackParams.Set(urlutil.QueryRefreshToken, string(encSession)) callbackParams.Set(urlutil.QueryRefreshToken, string(encSession))
callbackParams.Set(urlutil.QueryIsProgrammatic, "true") callbackParams.Set(urlutil.QueryIsProgrammatic, "true")
@ -172,8 +167,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
// sign the route session, as a JWT // sign the route session, as a JWT
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession()) signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession())
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
// encrypt our route-based token JWT avoiding any accidental logging // encrypt our route-based token JWT avoiding any accidental logging
@ -190,28 +184,28 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
// proxy's callback URL which is responsible for setting our new route-session // proxy's callback URL which is responsible for setting our new route-session
uri := urlutil.NewSignedURL(a.sharedKey, callbackURL) uri := urlutil.NewSignedURL(a.sharedKey, callbackURL)
httputil.Redirect(w, r, uri.String(), http.StatusFound) httputil.Redirect(w, r, uri.String(), http.StatusFound)
return nil
} }
// SignOut signs the user out and attempts to revoke the user's identity session // SignOut signs the user out and attempts to revoke the user's identity session
// Handles both GET and POST. // Handles both GET and POST.
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
session, err := sessions.FromContext(r.Context()) session, err := sessions.FromContext(r.Context())
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
a.sessionStore.ClearSession(w, r) a.sessionStore.ClearSession(w, r)
err = a.provider.Revoke(r.Context(), session.AccessToken) err = a.provider.Revoke(r.Context(), session.AccessToken)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("could not revoke user session", http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound) httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
return nil
} }
// reauthenticateOrFail starts the authenticate process by redirecting the // reauthenticateOrFail starts the authenticate process by redirecting the
@ -224,11 +218,10 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest // https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest
// https://tools.ietf.org/html/rfc6749#section-4.2.1 // https://tools.ietf.org/html/rfc6749#section-4.2.1
// https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest // https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest
func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) { func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Request, err error) error {
// If request AJAX/XHR request, return a 401 instead . // If request AJAX/XHR request, return a 401 instead .
if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") { if reqType := r.Header.Get("X-Requested-With"); strings.EqualFold(reqType, "XmlHttpRequest") {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err)) return httputil.NewError(http.StatusUnauthorized, err)
return
} }
a.sessionStore.ClearSession(w, r) a.sessionStore.ClearSession(w, r)
redirectURL := a.RedirectURL.ResolveReference(r.URL) redirectURL := a.RedirectURL.ResolveReference(r.URL)
@ -239,19 +232,20 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
b = append(b, enc...) b = append(b, enc...)
encodedState := base64.URLEncoding.EncodeToString(b) encodedState := base64.URLEncoding.EncodeToString(b)
httputil.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound) httputil.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound)
return nil
} }
// OAuthCallback handles the callback from the identity provider. // OAuthCallback handles the callback from the identity provider.
// //
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowSteps // https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowSteps
// https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse // https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) error {
redirect, err := a.getOAuthCallback(w, r) redirect, err := a.getOAuthCallback(w, r)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, fmt.Errorf("oauth callback : %w", err)) return fmt.Errorf("oauth callback : %w", err)
return
} }
httputil.Redirect(w, r, redirect.String(), http.StatusFound) httputil.Redirect(w, r, redirect.String(), http.StatusFound)
return nil
} }
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) { func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
@ -259,12 +253,12 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
// //
// first, check if the identity provider returned an error // first, check if the identity provider returned an error
if idpError := r.FormValue("error"); idpError != "" { if idpError := r.FormValue("error"); idpError != "" {
return nil, httputil.Error(idpError, http.StatusBadRequest, fmt.Errorf("identity provider: %v", idpError)) return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("identity provider: %v", idpError))
} }
// fail if no session redemption code is returned // fail if no session redemption code is returned
code := r.FormValue("code") code := r.FormValue("code")
if code == "" { if code == "" {
return nil, httputil.Error("identity provider returned empty code", http.StatusBadRequest, nil) return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("identity provider returned empty code"))
} }
// Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5 // Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5
@ -277,20 +271,19 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
// state includes a csrf nonce (validated by middleware) and redirect uri // state includes a csrf nonce (validated by middleware) and redirect uri
bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state")) bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state"))
if err != nil { if err != nil {
return nil, httputil.Error("malformed state", http.StatusBadRequest, err) return nil, httputil.NewError(http.StatusBadRequest, err)
} }
// split state into concat'd components // split state into concat'd components
// (nonce|timestamp|redirect_url|encrypted_data(redirect_url)+mac(nonce,ts)) // (nonce|timestamp|redirect_url|encrypted_data(redirect_url)+mac(nonce,ts))
statePayload := strings.SplitN(string(bytes), "|", 3) statePayload := strings.SplitN(string(bytes), "|", 3)
if len(statePayload) != 3 { if len(statePayload) != 3 {
return nil, httputil.Error("'state' is malformed", http.StatusBadRequest, return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("state malformed, size: %d", len(statePayload)))
fmt.Errorf("state malformed, size: %d", len(statePayload)))
} }
// verify that the returned timestamp is valid // verify that the returned timestamp is valid
if err := cryptutil.ValidTimestamp(statePayload[1]); err != nil { if err := cryptutil.ValidTimestamp(statePayload[1]); err != nil {
return nil, httputil.Error(err.Error(), http.StatusBadRequest, err) return nil, httputil.NewError(http.StatusBadRequest, err)
} }
// Use our AEAD construct to enforce secrecy and authenticity: // Use our AEAD construct to enforce secrecy and authenticity:
@ -299,12 +292,12 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
b := []byte(fmt.Sprint(statePayload[0], "|", statePayload[1], "|")) b := []byte(fmt.Sprint(statePayload[0], "|", statePayload[1], "|"))
redirectString, err := cryptutil.Decrypt(a.cookieCipher, []byte(statePayload[2]), b) redirectString, err := cryptutil.Decrypt(a.cookieCipher, []byte(statePayload[2]), b)
if err != nil { if err != nil {
return nil, httputil.Error("'state' has invalid hmac", http.StatusBadRequest, err) return nil, httputil.NewError(http.StatusBadRequest, err)
} }
redirectURL, err := urlutil.ParseAndValidateURL(string(redirectString)) redirectURL, err := urlutil.ParseAndValidateURL(string(redirectString))
if err != nil { if err != nil {
return nil, httputil.Error("'state' has invalid redirect uri", http.StatusBadRequest, err) return nil, httputil.NewError(http.StatusBadRequest, err)
} }
// OK. Looks good so let's persist our user session // OK. Looks good so let's persist our user session
@ -317,29 +310,25 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
// RefreshAPI loads a global state, and attempts to refresh the session's access // RefreshAPI loads a global state, and attempts to refresh the session's access
// tokens and state with the identity provider. If successful, a new signed JWT // tokens and state with the identity provider. If successful, a new signed JWT
// and refresh token (`refresh_token`) are returned as JSON // and refresh token (`refresh_token`) are returned as JSON
func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error {
s, err := sessions.FromContext(r.Context()) s, err := sessions.FromContext(r.Context())
if err != nil && !errors.Is(err, sessions.ErrExpired) { if err != nil && !errors.Is(err, sessions.ErrExpired) {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
newSession, err := a.provider.Refresh(r.Context(), s) newSession, err := a.provider.Refresh(r.Context(), s)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusInternalServerError, err)) return err
return
} }
newSession = newSession.NewSession(s.Issuer, s.Audience) newSession = newSession.NewSession(s.Issuer, s.Audience)
encSession, err := a.encryptedEncoder.Marshal(newSession) encSession, err := a.encryptedEncoder.Marshal(newSession)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusInternalServerError, err)) return err
return
} }
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession()) signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession())
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusInternalServerError, err)) return err
return
} }
var response struct { var response struct {
JWT string `json:"jwt"` JWT string `json:"jwt"`
@ -350,9 +339,9 @@ func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) {
jsonResponse, err := json.Marshal(&response) jsonResponse, err := json.Marshal(&response)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Write(jsonResponse) w.Write(jsonResponse)
return nil
} }

View file

@ -11,6 +11,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/encoding/mock"
@ -154,8 +156,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
r = r.WithContext(ctx) r = r.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
httputil.HandlerFunc(a.SignIn).ServeHTTP(w, r)
a.SignIn(w, r)
if status := w.Code; status != tt.wantCode { if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v %s", status, tt.wantCode, uri) t.Errorf("handler returned wrong status code: got %v want %v %s", status, tt.wantCode, uri)
t.Errorf("\n%+v", w.Body) t.Errorf("\n%+v", w.Body)
@ -186,9 +187,9 @@ func TestAuthenticate_SignOut(t *testing.T) {
wantBody string wantBody string
}{ }{
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""}, {"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""},
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"error\":\"could not revoke user session\"}\n"}, {"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"},
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"error\":\"Bad Request\"}\n"}, {"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"},
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"error\":\"malformed redirect_uri\"}\n"}, {"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -211,8 +212,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
r.Header.Set("Accept", "application/json") r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
httputil.HandlerFunc(a.SignOut).ServeHTTP(w, r)
a.SignOut(w, r)
if status := w.Code; status != tt.wantCode { if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode) t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
} }
@ -299,8 +299,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
r := httptest.NewRequest(tt.method, u.String(), nil) r := httptest.NewRequest(tt.method, u.String(), nil)
r.Header.Set("Accept", "application/json") r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
httputil.HandlerFunc(a.OAuthCallback).ServeHTTP(w, r)
a.OAuthCallback(w, r)
if w.Result().StatusCode != tt.wantCode { if w.Result().StatusCode != tt.wantCode {
t.Errorf("Authenticate.OAuthCallback() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantCode, w.Body.String()) t.Errorf("Authenticate.OAuthCallback() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantCode, w.Body.String())
return return
@ -366,7 +365,6 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
got.ServeHTTP(w, r) got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus { if status := w.Code; status != tt.wantStatus {
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
} }
}) })
} }
@ -417,7 +415,7 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
r.Header.Set("Accept", "application/json") r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
a.RefreshAPI(w, r) httputil.HandlerFunc(a.RefreshAPI).ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus { if status := w.Code; status != tt.wantStatus {
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())

View file

@ -10,6 +10,8 @@
<div id="main"> <div id="main">
<div id="info-box"> <div id="info-box">
<div class="card"> <div class="card">
<div class="card-header">
<h2>Current user</h2>
{{if .Session.Picture }} {{if .Session.Picture }}
<img class="icon" src="{{.Session.Picture}}" alt="user image" /> <img class="icon" src="{{.Session.Picture}}" alt="user image" />
{{else}} {{else}}
@ -17,14 +19,12 @@
class="icon" class="icon"
src="/.pomerium/assets/img/account_circle-24px.svg" src="/.pomerium/assets/img/account_circle-24px.svg"
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"
width="24"
height="24"
/> />
{{end}} {{end}}
</div>
<form method="POST" action="/.pomerium/sign_out"> <form method="POST" action="/.pomerium/sign_out">
<section> <section>
<h2>Current user</h2>
<p class="message">Your current session details.</p> <p class="message">Your current session details.</p>
<fieldset> <fieldset>
{{if .Session.Name}} {{if .Session.Name}}
@ -189,11 +189,23 @@
<button class="button full" type="submit">Sign Out</button> <button class="button full" type="submit">Sign Out</button>
</div> </div>
</form> </form>
</div>
</div>
{{if .IsAdmin}} {{if .IsAdmin}}
<div id="info-box">
<div class="card">
<div class="card-header">
<h2>Sign-in-as</h2>
<img
class="icon"
src="/.pomerium/assets/img/supervised_user_circle-24px.svg"
xmlns="http://www.w3.org/2000/svg"
/>
</div>
<form method="POST" action="/.pomerium/impersonate"> <form method="POST" action="/.pomerium/impersonate">
<section> <section>
<h2>Sign-in-as</h2>
<p class="message"> <p class="message">
Administrators can temporarily impersonate another user. Administrators can temporarily impersonate another user.
</p> </p>
@ -235,7 +247,6 @@
{{ end }} {{ end }}
</div> </div>
</div> </div>
{{template "footer.html"}}
</div> </div>
</body> </body>
</html> </html>

View file

@ -2,34 +2,57 @@
<!DOCTYPE html> <!DOCTYPE html>
<html lang="en" charset="utf-8"> <html lang="en" charset="utf-8">
<head> <head>
<title>{{.Code}} - {{.Title}}</title> <title>{{.Status}} - {{.StatusText}}</title>
{{template "header.html"}} {{template "header.html"}}
</head> </head>
<body> <body>
<div id="main"> <div id="main">
<div id="info-box"> <div id="info-box">
<div class="card"> <div class="card">
<div class="card-header">
<img <img
class="icon" class="icon"
src="/.pomerium/assets/img/error-24px.svg" src="/.pomerium/assets/img/error-24px.svg"
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"
width="24"
height="24"
/> />
<h1 class="title">{{.Title}}</h1> <h2>{{.StatusText}}</h2>
<h2>{{.Status}}</h2>
</div>
<section> <section>
<p class="message"> <div class="message">
{{if .Message}}{{.Message}}{{end}} {{if .CanDebug}}Troubleshoot <div class="text-monospace">{{.Error}}</div>
your </div>
<a href="/.pomerium/">session</a>.{{end}} {{if .RequestID}} {{if .CanDebug}}
Request {{.RequestID}}{{end}} <div class="message">
</p> If you should have access, contact your administrator and provide
them with your
<a href="/.pomerium/">request details</a>.
</div>
{{end}} {{if.RetryURL}}
<div class="message">
If you believe the error is temporary, you can
<a href="{{.RetryURL}}">retry</a> the request.
</div>
{{end}}
</section> </section>
<div class="card-footer">
<a href="https://www.pomerium.io">
<img
src="/.pomerium/assets/img/pomerium_circle_96.svg"
xmlns="http://www.w3.org/2000/svg"
class="icon"
/>
</a>
<div class="text-right text-muted small">
{{.RequestID}} <br />
Pomerium {{.Version}}
</div>
</div>
</div> </div>
</div> </div>
{{template "footer.html"}}
</div> </div>
</body> </body>
</html> </html>
{{end}} {{end}}

View file

@ -1,12 +0,0 @@
{{define "footer.html"}}
<footer>
<a href="https://www.pomerium.io">
<img
class="powered-by-pomerium"
src="/.pomerium/assets/img/pomerium.svg"
xmlns="http://www.w3.org/2000/svg"
height="25"
/>
</a>
</footer>
{{end}}

View file

@ -8,5 +8,4 @@
type="text/css" type="text/css"
href="/.pomerium/assets/style/main.css" href="/.pomerium/assets/style/main.css"
/> />
{{end}} {{end}}

View file

@ -1 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24"><path d="M0 0h24v24H0z" fill="none"/><path fill="#6e43e8" d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm1 15h-2v-2h2v2zm0-4h-2V7h2v6z"/></svg> <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24"><path d="M0 0h24v24H0z" fill="none"/><path fill="#333333" d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm1 15h-2v-2h2v2zm0-4h-2V7h2v6z"/></svg>

Before

Width:  |  Height:  |  Size: 249 B

After

Width:  |  Height:  |  Size: 249 B

Before After
Before After

View file

@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="96.03" height="96.03" viewBox="0 0 900 900">
<image id="Vector_Smart_Object" data-name="Vector Smart Object" width="900" height="900" xlink:href="data:img/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAN6ElEQVR4nO2dCXBUVRaGv3QCSCCsQoAh7IQ4oCjgXo67jugM6rhbQKkjMzVFuYzbOC6l1uiglIowipbKuMEoq6AoIjOCsomgCAGFsMkmQkICCQGy9dTpezt56XT3W7vTSfqv6kpVp9+9557z7r3n/Oe8+1LGDPWToGgHnAL0BboBfYAsIBNoA2QAzbXoFUAxcBg4AOwGtgE7gR3Aev19wiEtgQQShV4CXKAVP0Qr2iraR/ndUWANsA5YCnwO5MdnWNFR3zOgEzAcuAE420SJXqIU+BqYDnwE7IlTv3VQXwa4CLgFuAloVR8CGFAOzASmAvPj3bkvzv2J0hcD/wXuSADlC5oBNwMfA98AY+LZebwMcB3wnb7Lzo9Tn04wDHgN2ATcFo8OY22AM4EvgBnAqTHuy0tkA1P0jLg4lh3FygDpwIvASu3VNFTIjFikjZEZizHEwgC/1VP4nhi0XV+4TY/pZq/799oAzwGfAt09bjcR0BaYBrzppSxeGaCdXusf8Ki9RMbtwFqglxcyemEAiVpzG/habxeDgQ1ejNmtAS7Q7uWv3ArSAJGuZ/21bkR3Y4ArtQDxDuYSDbPcxAxOlXe5jhyTUBA3dXS8DPAbYEFS8XXwFnC93YvsGiAHWOJMviaB6XapFjsGaK259CSi439Aj1gYYCHQMal8U/i0c2L5x1YwXidMkrAGSZ++75UBZE27P6l427jRCndkZgDJGc+tH/kbBd7We2dEmBngRU1CJeEMkm17z6kBxOUcm1S8a4zQOfCwiGaA/yT+2BoMpkYSNJIBrmxgKcRERxfgL+FkjGSASU1WVbHDuHAthzPA74HeDXSQiYyMcHtquMKstTrh4Aj+qsatxRR35Ps+oKvxi9Da0GF2lZ+Sov4e3AelhyG1mWshExd+qCiHFulwYjc1Tps3XBe9wswLfhFqgLvttJaaBiVFkL8XhlwIg8+HzF6Qlkglvx6iqgr274LcZbBqAbQ9Edp1UkaxgfuNBjAuQa10iXeKlbZ8qXAoH44dgdGPw4iwe3zjxedTYcqj4PdDhy5QWWFrqFm6hL7WJnytVeULyo9BSSE8OKXpKV9w6a3w6DQ4WqI+1jUXQHUK02gAyyk1Wfv2boOr7oSzhtvquFFh4Nlw04Pw83bw2TNANUkXXIKa66dLWli5urIcDhXAS0ugW1/z3xf+Aj/vgJKD0LwldOwKWQPcmeKnH9TGX34c2nSAbv3UX6eQvWxPnhpXWnNonwm9B5o3duQwjD0HqirhBHu13oFlKLhdXmFV+YKifBh8nrny9+2AGS/Ad19AcSGUl0FqKrRoCb1PhuG3w7kjbAnN4umw8F3YvgGOl6qNsVlztSEOuwyuu0d5KFZRuB9mvgirPoOiA1BxXM3w5idAz5PgslFwcRRSuVUbGHIRLJoGXewZQCrGJwQNcImdK4+VQtc+0X+z9EOYeBcU7oMuvZW3IC6rX8+g3KWw+jO4cgzcNdG8T9nkxt8Ji6ZC63Zq45O/MvOr/HCsBGZOgCUz4Z6X4cwrzNv8fgk8dwf88hNk9lRG9GkZqyog71tYswhWfAwPTYFmEW7RE7vb9oTQlePVe8Apdq4UIaP5+qL8J65Xyuk/BNIzlNck1/h8aiAye7JyYNYk+Ocokw798PgfYOHb0GcQdOml7nppS9qUWdWqLWQPgYoyeOwa5SZGw/ql8PffKS8ue6gyZqpBRlmGOveEPqfAFx/AIyMiezoBOeztAYIz0AaQiXO63aurIgQguzfDs7dDp+5qrY8ktFyf1gxyhsKCd2H685H7ev0RWDYPBgyDlNTIwY/0Jf227wzjRqslMBwOHYBnRkHL1urOjySj9CMGEQN98xm8+qA9XZhA1pBsn45+WzpqIgzeeEQtMbKJVVZG/6340HLn9MqBaeOU8UKxeQ3Mngh9T1a/x+SRNlFmx25Qcgjeezr8b959GvL3QOceFvx3v1Jw38EwbzJsWG5VE5Zwmhign1etbVih7pTu2dYDE1GqbGTiTXz8et3/z31FGUm8J7/F5wml76xsWDEffvqx9v/EZfxyFnTvbyN48qslT5bOOa9YvMYaevj0Q9CeYPXnalCy3tuBzJTOWWpdPnqk5kLxSnJXQKcs5ebZgXgxxQdheUhGe9lc1e4J6fbak1kgy9uPqxT14hH6+7yknjevhtbtlbLkbrXzkY1agruta2va27EBDhdA8xb226vUfvmerbVl3J2n3GAnMsrGLO709vWeGaCvuKE9vWhJApmNX6sNruyo/etlgLuLlMIGnau+27YOtu5XLp59JwMKCtUda0TeGjhQBCnbHTQoMhVC3ndw+uXOrg9BxzT9tLpryN0xdoJSlng3tuFX+8DAs2ouPPUCeOplyHD4/HzZMWjdtmazF9zyMFxVoPYUJ5BlbeA5XmgsgIw0r8pOZE2NFjE6gfjgfWxFKOawG3nHGOk+s8KhJGKKFj7DkS8xhXhHkj8oLfauF1mypE3P0qB+1Z7EEHFC7HNXK+fD6oWwbb1SmPjTwiMNOgcuHam8HzuQNViILwmIxKeXPUc8r36DFf8zxMFz7WsXw9efQN5aleMQN1qCNKGbhfdv19lrrdRADFAWi4bFfXxprKIQhCtp01H55kcqlSso/Mq812DMOGvEGZoJfesJ2L1FcTeywYqyxK/fuEIFcuddA3dNsmZYobInjIUlM5QhhYwLuKhVyrjCaX30Gox6DC651bVKwqEiTZ+d44JJr4sDe+BvV8DOTdB7kFKScZkIsJg+2L8THr0GHngdLhsZvc05k+Bf9yp+Kfs0qKyqoSWE0+mQCeXlii2V+OHZT5VCI0ECvoeHw7qv1EYvM9PI6Qi517UXFOyDZ0ZCwc9wo/c14mWyB3i64sld9eSN6g7qf5qmoEPWaHELJRDK7AGZWTD+j4oajoTlHylqWygO4XkCHFMILSHKEzZTiLPtufDUzdGphnGjVJSdc4YqLqhDqOlgTmhvyQtMfgAWz/BGRwaU+rw+uuv98bBhGfQaaM61yAAzOqhMltzdx8MEcLIpvnyvWpNlWTGjJKRPuaOFx58dob7vkzfhy9lq3zCTMZjpktkw+T6VhfMQxT59sJ0nkJINGVyPAeZMqHGAolyJehe+U/f/899QtLIsPVb5IJlx3frAgrd0wtyAsuPw4Stq9lmFzA5hd2UZmvuqpwYoFAM4DMrrQjwJoXnT25rTxkaIwiRjtjLkyWOJZJfPU0SdnbIPWeJk05fN/qs5tf8n7e3apDwbq+wqemZJ7kBkPHbEwgXWsN0XrE/xAkJHy+Zll7kURYhCtq6HnQb6eOcPala1dhKr+5WTvfnb2l//+I2j7FUA6W2gYK9yqT3CJp8+X9M1pE5o7xZlACeQjVCWC6lMCELcTfFWUmzS22ijtsyA/JDbSxQoa7qduz8IcadFnt151n5vATt9+jxN+ynlEJSWQJmuenAEne2SNToICYpk6ju9Y4UglHpVoxcmkXiai9hf2hLm1yOsEwMUaSO4gp+aigLHCFW0JP9dnqrpx9ndHgkeFh7LWaW5wea+dy1Yiu3yvMjteNhm6OxxOpu8bkOfp1ddlpI8/yH+kBMZqw0wTx+AnUT8MMtoAPFsVyWVHzdI7LWFkOrotxv3mBMK1edIGA3wgVUJhboNrQRu21FtmA6rxKrbbWM4j0VY0woXZLkEhMLEGstkJN6wSpOElbGsbqwjHJVNOauPvjQmZIQV/RC42uxqKaSSCNVYcXBE+9vNnCTkUcaTgUglXLBmRyqgW7Vz2J6uDRI6Y+NKFUShaegWLuoARR6J1o1j37PFlpwSm1cXy4Q+JXm+Pt08KoJRqyg9CBmg8O8B+tmh3y1BXHFRDSsqs0xoCJuP/9SSqaICDufXxCcyU6Vqw+lMlbHLuI0kn9w48rEo5636ANgAwj2mmmdarqg5/lDOJ9VtgjNFF0xp5YgCA8kcl4GUUTGBSN2FH5+iZTQa0IacRboMqFqicHHdk6bN6DobUbjx4xp+NZhgm4HHQD2IYmvJ6DKICtQY+Wq3aUPO50Pd/Uhv0PgFiGEqukmiQpcAHTcOPhKz8demrq0Y4MlQ5WPyDhmp1u/fwAaZqCiMVPgQjdu7pUmpKLa4I1Lr0QywOnlokyeQs1bnRGrIjN3+U5Kkc43rojVgZoBiffxiEs5wt/YoHRtAMBsIUzCShAnkCGPTJ6CtJthGa68oCWso1K9oNIWdDOeFXiTvmwguDOfzh4MdA+yNdv5lEtW4wU6O3W6Of6kVuroJ48/6rYGW4aTIYq7T13U0cjyk30NpC06rXN7R79NKQuFh/RI723BDIv9bP9xh6Zz8Rgw5sG2y0+G5rfOSPPKlcoRQU9F2CG5yo3w8egeYvG30ZGCjB201FOzVBy5ZLmSIBK8qHbdoI0Q9K7+R4BPg117VUXl5xq1kSUfqIxm9qx9OHAgpeZ8+Wd6z5+picciwvNBsgBfTM4Egr+g9CXjBa5Fidcrzfr1BXa1fhNxQsUuXkQwPlhJ6jVgfsz1Xvwrlbl0P31BwEHhMv1t+Wixljtc55xP1IXX3J/iM2KWT5yLrP+LhXsfzoPkyXReTo4/uTaQXgn6l87ZSkPaE1w+vR0O0qoh4IEen7GSvGBrnvnP1cxHiLKyrn+HXvwGM6Kc3uzO1MVyeLl0H23ShwWrt1eR63L4jJNKrFsTLMKbwJNiRjxwqKGuynFQtxetSHC71zfIJyi9VqlLSKx8pGZaMlChc6lzl+FZ5+tj1c3CeA/g/z2BJ+JKJOlwAAAAASUVORK5CYII="/>
</svg>

After

Width:  |  Height:  |  Size: 5 KiB

View file

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24"><path d="M11.99 2c-5.52 0-10 4.48-10 10s4.48 10 10 10 10-4.48 10-10-4.48-10-10-10zm3.61 6.34c1.07 0 1.93.86 1.93 1.93 0 1.07-.86 1.93-1.93 1.93-1.07 0-1.93-.86-1.93-1.93-.01-1.07.86-1.93 1.93-1.93zm-6-1.58c1.3 0 2.36 1.06 2.36 2.36 0 1.3-1.06 2.36-2.36 2.36s-2.36-1.06-2.36-2.36c0-1.31 1.05-2.36 2.36-2.36zm0 9.13v3.75c-2.4-.75-4.3-2.6-5.14-4.96 1.05-1.12 3.67-1.69 5.14-1.69.53 0 1.2.08 1.9.22-1.64.87-1.9 2.02-1.9 2.68zM11.99 20c-.27 0-.53-.01-.79-.04v-4.07c0-1.42 2.94-2.13 4.4-2.13 1.07 0 2.92.39 3.84 1.15-1.17 2.97-4.06 5.09-7.45 5.09z"/><path fill="none" d="M0 0h24v24H0z"/></svg>

After

Width:  |  Height:  |  Size: 670 B

View file

@ -7,35 +7,50 @@
box-sizing: border-box; box-sizing: border-box;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto,
"Helvetica Neue", sans-serif; "Helvetica Neue", sans-serif;
font-size: 15px; font-size: 1rem;
line-height: 1.4em; line-height: 1.4em;
} }
.primary {
color: #6e43e8;
}
.light {
/* a571ff */
color: rgb(165, 113, 255);
}
.dark {
/* #422D66 */
color: rgb(66, 45, 102);
}
.text-monospace {
font-size: 0.85rem;
font-family: SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono",
"Courier New", monospace !important;
}
body { body {
display: flex; display: flex;
flex-direction: row; flex-direction: row;
align-items: center; align-items: center;
background: #f8f8ff; background: rgba(165, 113, 255, 0.05);
} color: rgb(33, 37, 41);
#main {
width: 100%;
height: 100vh;
text-align: center;
display: flex;
flex-direction: column;
justify-content: space-between; justify-content: space-between;
} }
#info-box { #main {
max-width: 480px; display: flex;
width: 480px; flex-direction: column;
margin-top: 200px;
margin-right: auto;
margin-bottom: 0px;
margin-left: auto;
justify-content: center; justify-content: center;
flex-grow: 1; width: 100%;
min-height: 100vh;
}
#info-box {
justify-content: center;
align-items: center;
display: flex;
padding-bottom: 2.2rem;
} }
section { section {
@ -43,46 +58,49 @@ section {
flex-direction: column; flex-direction: column;
position: relative; position: relative;
text-align: left; text-align: left;
min-height: 12em;
} }
h1 { h1 {
font-size: 36px; font-size: 2.5rem;
font-weight: 400; font-weight: 400;
text-align: center; text-align: center;
letter-spacing: 0.3px; letter-spacing: 0.3px;
text-transform: uppercase; text-transform: uppercase;
color: #32325d; color: rgb(110, 67, 232);
} }
h1.title { h1.title {
text-align: center; text-align: center;
background: #f8f8ff; padding: 0.75rem 1.25rem;
margin: 15px 0;
} }
h2 { h2 {
margin: 15px 0; text-align: left;
color: #32325d; color: #333333;
text-transform: uppercase; text-transform: uppercase;
letter-spacing: 0.3px; letter-spacing: 0.3px;
font-size: 18px; font-size: 1.25rem;
font-weight: 650; font-weight: 650;
padding-top: 20px;
} }
.card { .card {
margin: 0 -30px; border-radius: 0.5rem;
padding: 20px 30px 30px;
border-radius: 4px; border-radius: 4px;
border: 1px solid #e8e8fb; border: 1px solid rgba(0, 0, 0, 0.125);
background-color: #f8f8ff; flex-grow: 1;
flex-shrink: 1;
margin: 0 -30px;
max-width: 40%;
min-width: 480px;
padding: 1.25rem 1.25rem;
box-shadow: 0 0.125rem 0.25rem rgba(0, 0, 0, 0.075);
} }
fieldset { fieldset {
margin-bottom: 20px; margin-bottom: 20px;
background: #fcfcff; background: #fcfcff5d;
box-shadow: 0 1px 3px 0 rgba(50, 50, 93, 0.15), box-shadow: 0 0.125rem 0.25rem rgba(0, 0, 0, 0.075);
0 4px 6px 0 rgba(112, 157, 199, 0.15);
border-radius: 4px; border-radius: 4px;
border: none; border: none;
font-size: 0; font-size: 0;
@ -100,7 +118,7 @@ fieldset label {
} }
fieldset label:not(:last-child) { fieldset label:not(:last-child) {
border-bottom: 1px solid #f0f5fa; border-bottom: 1px solid rgba(0, 0, 0, 0.1);
} }
fieldset label span { fieldset label span {
@ -109,55 +127,24 @@ fieldset label span {
text-align: right; text-align: right;
} }
#group { img.icon {
display: flex; width: auto;
align-items: center; height: 36px;
}
#group::before {
display: inline-flex;
content: "";
height: 15px;
background-position: -1000px -1000px;
background-repeat: no-repeat;
}
.icon {
display: inline-table;
margin-top: -72px;
text-align: center;
width: 75px;
height: auto;
border-radius: 50%; border-radius: 50%;
} }
.icon svg { .message {
fill: #6e43e8; padding: 2.55rem 0.75rem;
background: red;
}
.logo {
padding-bottom: 20px;
padding-top: 20px;
width: 115px;
height: auto;
}
p.message {
margin-top: 10px;
margin-bottom: 10px;
padding-bottom: 20px;
} }
.field { .field {
flex: 1; flex: 1;
padding: 0 15px; padding: 0 15px;
background: transparent;
font-weight: 400; font-weight: 400;
color: #31325f; color: rgb(66, 45, 102);
background: #fcfcff5d;
outline: none; outline: none;
cursor: text; cursor: text;
white-space: nowrap; white-space: nowrap;
overflow: hidden; overflow: hidden;
text-overflow: ellipsis; text-overflow: ellipsis;
@ -177,7 +164,6 @@ fieldset .select::after {
input { input {
border-style: none; border-style: none;
outline: none; outline: none;
color: #313b3f;
} }
select { select {
@ -207,7 +193,6 @@ select {
border-radius: 4px; border-radius: 4px;
border: 0; border: 0;
font-weight: 700; font-weight: 700;
width: 50%;
height: 40px; height: 40px;
outline: none; outline: none;
cursor: pointer; cursor: pointer;
@ -234,6 +219,49 @@ select {
background: #5735b5; background: #5735b5;
} }
.powered-by-pomerium { .footer-icon {
align-items: center; display: inline-table;
margin-top: -12px;
height: 24px;
width: auto;
vertical-align: top;
}
.text-muted {
color: #6c757d !important;
font-size: 0.75rem;
}
.card-footer {
display: flex;
justify-content: space-between;
align-items: left;
margin-top: 0px;
margin-right: -1.25rem;
margin-bottom: -1.25rem;
margin-left: -1.25rem;
padding: 0.75rem 1.25rem;
background-color: rgba(0, 0, 0, 0.03);
border-top: 1px solid rgba(0, 0, 0, 0.125);
border-bottom-right-radius: 0.5rem;
border-bottom-left-radius: 0.5rem;
}
.card-header {
display: flex;
align-items: center;
justify-content: space-between;
margin-top: 0px;
margin-right: -1.25rem;
margin-top: -1.25rem;
margin-left: -1.25rem;
padding: 0.75rem 1.25rem;
background-color: rgba(0, 0, 0, 0.03);
border-bottom: 1px solid rgba(0, 0, 0, 0.125);
border-top-right-radius: 0.5rem;
border-top-left-radius: 0.5rem;
}
.text-right {
text-align: right;
} }

File diff suppressed because one or more lines are too long

View file

@ -2,110 +2,91 @@ package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt"
"html/template" "html/template"
"io"
"net/http" "net/http"
"github.com/pomerium/pomerium/internal/frontend" "github.com/pomerium/pomerium/internal/frontend"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/internal/version"
) )
// Error formats creates a HTTP error with code, user friendly (and safe) error var errorTemplate = template.Must(frontend.NewTemplates())
// message. If nil or empty, HTTP status code defaults to 500 and message var fullVersion = version.FullVersion()
// defaults to the text of the status code.
func Error(message string, code int, err error) error {
if code == 0 {
code = http.StatusInternalServerError
}
if message == "" {
message = http.StatusText(code)
}
return &httpError{Message: message, Code: code, Err: err}
}
type httpError struct { // HTTPError contains an HTTP status code and wrapped error.
// Message to present to the end user. type HTTPError struct {
Message string
// HTTP status codes as registered with IANA. // HTTP status codes as registered with IANA.
Code int Status int
// Err is the wrapped error
Err error // the cause Err error
} }
func (e *httpError) Error() string { // NewError returns an error that contains a HTTP status and error.
s := fmt.Sprintf("%d %s: %s", e.Code, http.StatusText(e.Code), e.Message) func NewError(status int, err error) error {
if e.Err != nil { return &HTTPError{Status: status, Err: err}
return s + ": " + e.Err.Error()
}
return s
}
func (e *httpError) Unwrap() error { return e.Err }
// Timeout reports whether this error represents a user debuggable error.
func (e *httpError) Debugable() bool {
return e.Code == http.StatusUnauthorized || e.Code == http.StatusForbidden
} }
// ErrorResponse renders an error page given an error. If the error is a // Error implements the `error` interface.
// http error from this package, a user friendly message is set, http status code, func (e *HTTPError) Error() string {
// the ability to debug are also set. return http.StatusText(e.Status) + ": " + e.Err.Error()
func ErrorResponse(w http.ResponseWriter, r *http.Request, e error) {
statusCode := http.StatusInternalServerError // default status code to return
errorString := e.Error()
var canDebug bool
var requestID string
var httpError *httpError
// if this is an HTTPError, we can add some additional useful information
if errors.As(e, &httpError) {
canDebug = httpError.Debugable()
statusCode = httpError.Code
errorString = httpError.Message
} }
// Unwrap implements the `error` Unwrap interface.
func (e *HTTPError) Unwrap() error { return e.Err }
// Debugable reports whether this error represents a user debuggable error.
func (e *HTTPError) Debugable() bool {
return e.Status == http.StatusUnauthorized || e.Status == http.StatusForbidden
}
// RetryURL returns the requests intended destination, if any.
func (e *HTTPError) RetryURL(r *http.Request) string {
return r.FormValue(urlutil.QueryRedirectURI)
}
type errResponse struct {
Status int
Error string
StatusText string `json:"-"`
RequestID string `json:",omitempty"`
CanDebug bool `json:"-"`
RetryURL string `json:"-"`
Version string `json:"-"`
}
// ErrorResponse replies to the request with the specified error message and HTTP code.
// It does not otherwise end the request; the caller should ensure no further
// writes are done to w.
func (e *HTTPError) ErrorResponse(w http.ResponseWriter, r *http.Request) {
// indicate to clients that the error originates from Pomerium, not the app // indicate to clients that the error originates from Pomerium, not the app
w.Header().Set(HeaderPomeriumResponse, "true") w.Header().Set(HeaderPomeriumResponse, "true")
w.WriteHeader(e.Status)
log.FromRequest(r).Error().Err(e).Str("http-message", errorString).Int("http-code", statusCode).Msg("http-error") log.FromRequest(r).Info().Err(e).Msg("httputil: ErrorResponse")
var requestID string
if id, ok := log.IDFromRequest(r); ok { if id, ok := log.IDFromRequest(r); ok {
requestID = id requestID = id
} }
if r.Header.Get("Accept") == "application/json" { response := errResponse{
var response struct { Status: e.Status,
Error string `json:"error"` StatusText: http.StatusText(e.Status),
} Error: e.Error(),
response.Error = errorString
writeJSONResponse(w, statusCode, response)
} else {
w.WriteHeader(statusCode)
w.Header().Set("Content-Type", "text/html")
t := struct {
Code int
Title string
Message string
RequestID string
CanDebug bool
}{
Code: statusCode,
Title: http.StatusText(statusCode),
Message: errorString,
RequestID: requestID, RequestID: requestID,
CanDebug: canDebug, CanDebug: e.Debugable(),
} RetryURL: e.RetryURL(r),
template.Must(frontend.NewTemplates()).ExecuteTemplate(w, "error.html", t) Version: fullVersion,
}
} }
// writeJSONResponse is a helper that sets the application/json header and writes a response. if r.Header.Get("Accept") == "application/json" {
func writeJSONResponse(w http.ResponseWriter, code int, response interface{}) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
err := json.NewEncoder(w).Encode(response) err := json.NewEncoder(w).Encode(response)
if err != nil { if err != nil {
io.WriteString(w, err.Error()) http.Error(w, err.Error(), http.StatusInternalServerError)
}
} else {
w.Header().Set("Content-Type", "text/html; charset=UTF-8")
errorTemplate.ExecuteTemplate(w, "error.html", response)
} }
} }

View file

@ -9,68 +9,67 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
) )
func TestErrorResponse(t *testing.T) { func TestHTTPError_ErrorResponse(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
rw http.ResponseWriter Status int
r *http.Request Err error
e *httpError reqType string
wantStatus int
wantBody string
}{ }{
{"good", httptest.NewRecorder(), &http.Request{Method: http.MethodGet}, &httpError{Code: http.StatusBadRequest, Message: "missing id token"}}, {"404 json", http.StatusNotFound, errors.New("route not known"), "application/json", http.StatusNotFound, "{\"Status\":404,\"Error\":\"Not Found: route not known\"}\n"},
{"good json", httptest.NewRecorder(), &http.Request{Method: http.MethodGet, Header: http.Header{"Accept": []string{"application/json"}}}, &httpError{Code: http.StatusBadRequest, Message: "missing id token"}}, {"404 html", http.StatusNotFound, errors.New("route not known"), "", http.StatusNotFound, ""},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ErrorResponse(tt.rw, tt.r, tt.e) fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := NewError(tt.Status, tt.Err)
var e *HTTPError
if errors.As(err, &e) {
e.ErrorResponse(w, r)
} else {
http.Error(w, "coulnd't convert error type", http.StatusTeapot)
}
})
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", tt.reqType)
w := httptest.NewRecorder()
fn(w, r)
if diff := cmp.Diff(tt.wantStatus, w.Code); diff != "" {
t.Errorf("ErrorResponse status:\n %s", diff)
}
if tt.reqType == "application/json" {
if diff := cmp.Diff(tt.wantBody, w.Body.String()); diff != "" {
t.Errorf("ErrorResponse status:\n %s", diff)
}
}
}) })
} }
} }
func TestError_Error(t *testing.T) { func TestNewError(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
Message string status int
Code int
InnerErr error
want string
}{
{"good", "short and stout", http.StatusTeapot, nil, "418 I'm a teapot: short and stout"},
{"nested error", "short and stout", http.StatusTeapot, errors.New("another error"), "418 I'm a teapot: short and stout: another error"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := httpError{
Message: tt.Message,
Code: tt.Code,
Err: tt.InnerErr,
}
got := h.Error()
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("Error.Error() = %s", diff)
}
})
}
}
func Test_httpError_Error(t *testing.T) {
tests := []struct {
name string
message string
code int
err error err error
want string wantErr bool
}{ }{
{"good", "foobar", 200, nil, "200 OK: foobar"}, {"good", 404, errors.New("error"), true},
{"no code", "foobar", 0, nil, "500 Internal Server Error: foobar"},
{"no message or code", "", 0, nil, "500 Internal Server Error: Internal Server Error"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
e := Error(tt.message, tt.code, tt.err) err := NewError(tt.status, tt.err)
if got := e.Error(); got != tt.want { if (err != nil) != tt.wantErr {
t.Errorf("httpError.Error() = %v, want %v", got, tt.want) t.Errorf("NewError() error = %v, wantErr %v", err, tt.wantErr)
} }
if err != nil && !errors.Is(err, tt.err) {
t.Errorf("NewError() unwrap fail = %v, wantErr %v", err, tt.wantErr)
}
}) })
} }
} }

View file

@ -1,6 +1,8 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil" package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import ( import (
"errors"
"fmt"
"net/http" "net/http"
) )
@ -14,7 +16,7 @@ func HealthCheck(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
if r.Method == http.MethodGet { if r.Method == http.MethodGet {
w.Write([]byte(http.StatusText(http.StatusOK))) fmt.Fprintln(w, http.StatusText(http.StatusOK))
} }
} }
@ -24,3 +26,22 @@ func Redirect(w http.ResponseWriter, r *http.Request, url string, code int) {
w.Header().Set(HeaderPomeriumResponse, "true") w.Header().Set(HeaderPomeriumResponse, "true")
http.Redirect(w, r, url, code) http.Redirect(w, r, url, code)
} }
// The HandlerFunc type is an adapter to allow the use of
// ordinary functions as HTTP handlers. If f is a function
// with the appropriate signature, HandlerFunc(f) is a
// Handler that calls f.
//
// adapted from std library to suppport error wrapping
type HandlerFunc func(http.ResponseWriter, *http.Request) error
// ServeHTTP calls f(w, r) error.
func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err := f(w, r); err != nil {
var e *HTTPError
if !errors.As(err, &e) {
e = &HTTPError{http.StatusInternalServerError, err}
}
e.ErrorResponse(w, r)
}
}

View file

@ -1,9 +1,12 @@
package httputil package httputil
import ( import (
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/google/go-cmp/cmp"
) )
func TestHealthCheck(t *testing.T) { func TestHealthCheck(t *testing.T) {
@ -66,3 +69,26 @@ func TestRedirect(t *testing.T) {
}) })
} }
} }
func TestHandlerFunc_ServeHTTP(t *testing.T) {
tests := []struct {
name string
f HandlerFunc
wantBody string
}{
{"good http error", func(w http.ResponseWriter, r *http.Request) error { return NewError(404, errors.New("404")) }, "{\"Status\":404,\"Error\":\"Not Found: 404\"}\n"},
{"good std error", func(w http.ResponseWriter, r *http.Request) error { return errors.New("404") }, "{\"Status\":500,\"Error\":\"Internal Server Error: 404\"}\n"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
tt.f.ServeHTTP(w, r)
if diff := cmp.Diff(tt.wantBody, w.Body.String()); diff != "" {
t.Errorf("ErrorResponse status:\n %s", diff)
}
})
}
}

View file

@ -14,6 +14,9 @@ func NewRouter() *mux.Router {
// CSRFFailureHandler sets a HTTP 403 Forbidden status and writes the // CSRFFailureHandler sets a HTTP 403 Forbidden status and writes the
// CSRF failure reason to the response. // CSRF failure reason to the response.
func CSRFFailureHandler(w http.ResponseWriter, r *http.Request) { func CSRFFailureHandler(w http.ResponseWriter, r *http.Request) error {
ErrorResponse(w, r, Error("CSRF Failure", http.StatusForbidden, csrf.FailureReason(r))) if err := csrf.FailureReason(r); err != nil {
return NewError(http.StatusBadRequest, csrf.FailureReason(r))
}
return nil
} }

View file

@ -1,43 +1,12 @@
package httputil package httputil
import ( import (
"net/http"
"net/http/httptest"
"reflect" "reflect"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"github.com/gorilla/mux" "github.com/gorilla/mux"
) )
func TestCSRFFailureHandler(t *testing.T) {
tests := []struct {
name string
wantBody string
wantStatus int
}{
{"basic csrf failure", "{\"error\":\"CSRF Failure\"}\n", http.StatusForbidden},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
CSRFFailureHandler(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %s", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %s", diff)
}
})
}
}
func TestNewRouter(t *testing.T) { func TestNewRouter(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View file

@ -28,14 +28,14 @@ func SetHeaders(headers map[string]string) func(next http.Handler) http.Handler
// the correspdoning client secret key // the correspdoning client secret key
func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler { func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature") ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature")
defer span.End() defer span.End()
if err := ValidateRequestURL(r, sharedSecret); err != nil { if err := ValidateRequestURL(r, sharedSecret); err != nil {
httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
return nil
}) })
} }
} }

View file

@ -170,7 +170,7 @@ func TestValidateSignature(t *testing.T) {
wantBody string wantBody string
}{ }{
{"good", "secret", "secret", http.StatusOK, http.StatusText(http.StatusOK)}, {"good", "secret", "secret", http.StatusOK, http.StatusText(http.StatusOK)},
{"secret mistmatch", "secret", "hunter42", http.StatusBadRequest, "{\"error\":\"invalid signature\"}\n"}, {"secret mistmatch", "secret", "hunter42", http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: internal/urlutil: hmac failed\"}\n"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -65,12 +65,12 @@ func (su *SignedURL) Validate() error {
issued, err := newNumericDateFromString(params.Get(QueryHmacIssued)) issued, err := newNumericDateFromString(params.Get(QueryHmacIssued))
if err != nil { if err != nil {
return fmt.Errorf("internal/urlutil: issued %w", err) return err
} }
expiry, err := newNumericDateFromString(params.Get(QueryHmacExpiry)) expiry, err := newNumericDateFromString(params.Get(QueryHmacExpiry))
if err != nil { if err != nil {
return fmt.Errorf("internal/urlutil: expiry %w", err) return err
} }
if expiry != nil && now.Add(-DefaultLeeway).After(expiry.Time()) { if expiry != nil && now.Add(-DefaultLeeway).After(expiry.Time()) {
@ -86,7 +86,7 @@ func (su *SignedURL) Validate() error {
sig, sig,
su.key) su.key)
if !validHMAC { if !validHMAC {
return fmt.Errorf("internal/urlutil: hmac failed %s", su.uri.String()) return fmt.Errorf("internal/urlutil: hmac failed")
} }
return nil return nil
} }

View file

@ -52,7 +52,7 @@ func ValidateURL(u *url.URL) error {
return fmt.Errorf("nil url") return fmt.Errorf("nil url")
} }
if u.Scheme == "" { if u.Scheme == "" {
return fmt.Errorf("%s url does contain a valid scheme. Did you mean https://%s?", u.String(), u.String()) return fmt.Errorf("%s url does contain a valid scheme", u.String())
} }
if u.Host == "" { if u.Host == "" {
return fmt.Errorf("%s url does contain a valid hostname", u.String()) return fmt.Errorf("%s url does contain a valid hostname", u.String())

View file

@ -16,13 +16,13 @@ func (p *Proxy) registerFwdAuthHandlers() http.Handler {
r.StrictSlash(true) r.StrictSlash(true)
r.Use(sessions.RetrieveSession(p.sessionStore)) r.Use(sessions.RetrieveSession(p.sessionStore))
r.Handle("/verify", http.HandlerFunc(p.nginxCallback)). r.Handle("/verify", httputil.HandlerFunc(p.nginxCallback)).
Queries("uri", "{uri}", urlutil.QuerySessionEncrypted, "", urlutil.QueryRedirectURI, "") Queries("uri", "{uri}", urlutil.QuerySessionEncrypted, "", urlutil.QueryRedirectURI, "")
r.Handle("/", http.HandlerFunc(p.postSessionSetNOP)). r.Handle("/", httputil.HandlerFunc(p.postSessionSetNOP)).
Queries("uri", "{uri}", Queries("uri", "{uri}",
urlutil.QuerySessionEncrypted, "", urlutil.QuerySessionEncrypted, "",
urlutil.QueryRedirectURI, "") urlutil.QueryRedirectURI, "")
r.Handle("/", http.HandlerFunc(p.traefikCallback)). r.Handle("/", httputil.HandlerFunc(p.traefikCallback)).
HeadersRegexp(httputil.HeaderForwardedURI, urlutil.QuerySessionEncrypted) HeadersRegexp(httputil.HeaderForwardedURI, urlutil.QuerySessionEncrypted)
r.Handle("/", p.Verify(false)).Queries("uri", "{uri}") r.Handle("/", p.Verify(false)).Queries("uri", "{uri}")
r.Handle("/verify", p.Verify(true)).Queries("uri", "{uri}") r.Handle("/verify", p.Verify(true)).Queries("uri", "{uri}")
@ -31,37 +31,39 @@ func (p *Proxy) registerFwdAuthHandlers() http.Handler {
} }
// postSessionSetNOP after successfully setting the // postSessionSetNOP after successfully setting the
func (p *Proxy) postSessionSetNOP(w http.ResponseWriter, r *http.Request) { func (p *Proxy) postSessionSetNOP(w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
httputil.Redirect(w, r, r.FormValue(urlutil.QueryRedirectURI), http.StatusFound) httputil.Redirect(w, r, r.FormValue(urlutil.QueryRedirectURI), http.StatusFound)
return nil
} }
func (p *Proxy) nginxCallback(w http.ResponseWriter, r *http.Request) { func (p *Proxy) nginxCallback(w http.ResponseWriter, r *http.Request) error {
encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted) encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted)
if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil { if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return return httputil.NewError(http.StatusBadRequest, err)
} }
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return nil
} }
func (p *Proxy) traefikCallback(w http.ResponseWriter, r *http.Request) { func (p *Proxy) traefikCallback(w http.ResponseWriter, r *http.Request) error {
forwardedURL, err := url.Parse(r.Header.Get(httputil.HeaderForwardedURI)) forwardedURL, err := url.Parse(r.Header.Get(httputil.HeaderForwardedURI))
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
q := forwardedURL.Query() q := forwardedURL.Query()
redirectURLString := q.Get(urlutil.QueryRedirectURI) redirectURLString := q.Get(urlutil.QueryRedirectURI)
encryptedSession := q.Get(urlutil.QuerySessionEncrypted) encryptedSession := q.Get(urlutil.QuerySessionEncrypted)
if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil { if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return return httputil.NewError(http.StatusBadRequest, err)
} }
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
httputil.Redirect(w, r, redirectURLString, http.StatusFound) httputil.Redirect(w, r, redirectURLString, http.StatusFound)
return nil
} }
// Verify checks a user's credentials for an arbitrary host. If the user // Verify checks a user's credentials for an arbitrary host. If the user
@ -70,18 +72,16 @@ func (p *Proxy) traefikCallback(w http.ResponseWriter, r *http.Request) {
// will be redirected to the authenticate service to sign in with their identity // will be redirected to the authenticate service to sign in with their identity
// provider. If the user is unauthorized, a `401` error is returned. // provider. If the user is unauthorized, a `401` error is returned.
func (p *Proxy) Verify(verifyOnly bool) http.Handler { func (p *Proxy) Verify(verifyOnly bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
uri, err := urlutil.ParseAndValidateURL(r.FormValue("uri")) uri, err := urlutil.ParseAndValidateURL(r.FormValue("uri"))
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("bad verification uri", http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
s, err := sessions.FromContext(r.Context()) s, err := sessions.FromContext(r.Context())
if errors.Is(err, sessions.ErrNoSessionFound) || errors.Is(err, sessions.ErrExpired) { if errors.Is(err, sessions.ErrNoSessionFound) || errors.Is(err, sessions.ErrExpired) {
if verifyOnly { if verifyOnly {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err)) return httputil.NewError(http.StatusUnauthorized, err)
return
} }
authN := *p.authenticateSigninURL authN := *p.authenticateSigninURL
q := authN.Query() q := authN.Query()
@ -90,25 +90,24 @@ func (p *Proxy) Verify(verifyOnly bool) http.Handler {
q.Set(urlutil.QueryForwardAuth, urlutil.StripPort(r.Host)) // add fwd auth to trusted audience q.Set(urlutil.QueryForwardAuth, urlutil.StripPort(r.Host)) // add fwd auth to trusted audience
authN.RawQuery = q.Encode() authN.RawQuery = q.Encode()
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &authN).String(), http.StatusFound) httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &authN).String(), http.StatusFound)
return return nil
} else if err != nil { } else if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err)) return httputil.NewError(http.StatusUnauthorized, err)
return
} }
// depending on the configuration of the fronting proxy, the request Host // depending on the configuration of the fronting proxy, the request Host
// and/or `X-Forwarded-Host` may be untrustd or change so we reverify // and/or `X-Forwarded-Host` may be untrustd or change so we reverify
// the session's validity against the supplied uri // the session's validity against the supplied uri
if err := s.Verify(uri.Hostname()); err != nil { if err := s.Verify(uri.Hostname()); err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err)) return httputil.NewError(http.StatusUnauthorized, err)
return
} }
p.addPomeriumHeaders(w, r) p.addPomeriumHeaders(w, r)
if err := p.authorize(uri.Host, w, r); err != nil { if err := p.authorize(uri.Host, r); err != nil {
return return err
} }
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "Access to %s is allowed.", uri.Host) fmt.Fprintf(w, "Access to %s is allowed.", uri.Host)
return nil
}) })
} }

View file

@ -42,19 +42,19 @@ func TestProxy_ForwardAuth(t *testing.T) {
}{ }{
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, "Access to some.domain.example is allowed."}, {"good redirect not required", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, "Access to some.domain.example is allowed."},
{"good verify only, no redirect", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""}, {"good verify only, no redirect", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: sessions.ErrInvalidAudience}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, invalid audience claim (aud)\"}\n"}, {"good redirect not required", opts, nil, http.MethodGet, nil, nil, "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: sessions.ErrInvalidAudience}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
{"bad naked domain uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"}, {"bad naked domain uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
{"bad naked domain uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"}, {"bad naked domain uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
{"bad empty verification uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"}, {"bad empty verification uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
{"bad empty verification uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"}, {"bad empty verification uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
{"not authorized", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"}, {"not authorized", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
{"not authorized verify endpoint", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"}, {"not authorized verify endpoint", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
{"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""}, {"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
{"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"}, {"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
{"not authorized because of error", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"error\":\"authz error\"}\n"}, {"not authorized because of error", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"Status\":500,\"Error\":\"Internal Server Error: authz error\"}\n"},
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"}, {"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
{"not authorized, bad audience request uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"not.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, invalid audience claim (aud)\"}\n"}, {"not authorized, bad audience request uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"not.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
{"not authorized, bad audience verify uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://fwdauth.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"some.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, invalid audience claim (aud)\"}\n"}, {"not authorized, bad audience verify uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://fwdauth.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"some.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
// traefik // traefik
{"good traefik callback", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, {"good traefik callback", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"bad traefik callback bad session", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString + "garbage"}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, {"bad traefik callback bad session", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString + "garbage"}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},

View file

@ -29,12 +29,12 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
p.cookieSecret, p.cookieSecret,
csrf.Secure(p.cookieOptions.Secure), csrf.Secure(p.cookieOptions.Secure),
csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieOptions.Name)), csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieOptions.Name)),
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)), csrf.ErrorHandler(httputil.HandlerFunc(httputil.CSRFFailureHandler)),
)) ))
// dashboard endpoints can be used by user's to view, or modify their session // dashboard endpoints can be used by user's to view, or modify their session
h.HandleFunc("/", p.UserDashboard).Methods(http.MethodGet) h.Path("/").Handler(httputil.HandlerFunc(p.UserDashboard)).Methods(http.MethodGet)
h.HandleFunc("/impersonate", p.Impersonate).Methods(http.MethodPost) h.Path("/impersonate").Handler(httputil.HandlerFunc(p.Impersonate)).Methods(http.MethodPost)
h.HandleFunc("/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost) h.Path("/sign_out").HandlerFunc(p.SignOut).Methods(http.MethodGet, http.MethodPost)
// Authenticate service callback handlers and middleware // Authenticate service callback handlers and middleware
// callback used to set route-scoped session and redirect back to destination // callback used to set route-scoped session and redirect back to destination
@ -42,14 +42,16 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
c := r.PathPrefix(dashboardURL + "/callback").Subrouter() c := r.PathPrefix(dashboardURL + "/callback").Subrouter()
c.Use(middleware.ValidateSignature(p.SharedKey)) c.Use(middleware.ValidateSignature(p.SharedKey))
c.Path("/").HandlerFunc(p.ProgrammaticCallback).Methods(http.MethodGet). c.Path("/").
Handler(httputil.HandlerFunc(p.ProgrammaticCallback)).
Methods(http.MethodGet).
Queries(urlutil.QueryIsProgrammatic, "true") Queries(urlutil.QueryIsProgrammatic, "true")
c.Path("/").HandlerFunc(p.Callback).Methods(http.MethodGet) c.Path("/").Handler(httputil.HandlerFunc(p.Callback)).Methods(http.MethodGet)
// Programmatic API handlers and middleware // Programmatic API handlers and middleware
a := r.PathPrefix(dashboardURL + "/api").Subrouter() a := r.PathPrefix(dashboardURL + "/api").Subrouter()
// login api handler generates a user-navigable login url to authenticate // login api handler generates a user-navigable login url to authenticate
a.HandleFunc("/v1/login", p.ProgrammaticLogin). a.Path("/v1/login").Handler(httputil.HandlerFunc(p.ProgrammaticLogin)).
Queries(urlutil.QueryRedirectURI, ""). Queries(urlutil.QueryRedirectURI, "").
Methods(http.MethodGet) Methods(http.MethodGet)
@ -84,17 +86,15 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
// UserDashboard lets users investigate, and refresh their current session. // UserDashboard lets users investigate, and refresh their current session.
// It also contains certain administrative actions like user impersonation. // It also contains certain administrative actions like user impersonation.
// Nota bene: This endpoint does authentication, not authorization. // Nota bene: This endpoint does authentication, not authorization.
func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) { func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) error {
session, err := sessions.FromContext(r.Context()) session, err := sessions.FromContext(r.Context())
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err) return err
return
} }
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session) isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err) return err
return
} }
p.templates.ExecuteTemplate(w, "dashboard.html", map[string]interface{}{ p.templates.ExecuteTemplate(w, "dashboard.html", map[string]interface{}{
@ -105,23 +105,23 @@ func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
"ImpersonateEmail": urlutil.QueryImpersonateEmail, "ImpersonateEmail": urlutil.QueryImpersonateEmail,
"ImpersonateGroups": urlutil.QueryImpersonateGroups, "ImpersonateGroups": urlutil.QueryImpersonateGroups,
}) })
return nil
} }
// Impersonate takes the result of a form and adds user impersonation details // Impersonate takes the result of a form and adds user impersonation details
// to the user's current user sessions state if the user is currently an // to the user's current user sessions state if the user is currently an
// administrative user. Requests are redirected back to the user dashboard. // administrative user. Requests are redirected back to the user dashboard.
func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) { func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) error {
session, err := sessions.FromContext(r.Context()) session, err := sessions.FromContext(r.Context())
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err) return err
return
} }
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session) isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
if err != nil || !isAdmin { if err != nil {
errStr := fmt.Sprintf("%s is not an administrator", session.RequestEmail()) return err
httpErr := httputil.Error(errStr, http.StatusForbidden, err) }
httputil.ErrorResponse(w, r, httpErr) if !isAdmin {
return return httputil.NewError(http.StatusForbidden, fmt.Errorf("%s is not an administrator", session.RequestEmail()))
} }
// OK to impersonation // OK to impersonation
redirectURL := urlutil.GetAbsoluteURL(r) redirectURL := urlutil.GetAbsoluteURL(r)
@ -134,20 +134,20 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
q.Set(urlutil.QueryImpersonateGroups, r.FormValue(urlutil.QueryImpersonateGroups)) q.Set(urlutil.QueryImpersonateGroups, r.FormValue(urlutil.QueryImpersonateGroups))
signinURL.RawQuery = q.Encode() signinURL.RawQuery = q.Encode()
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound) httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
return nil
} }
// Callback handles the result of a successful call to the authenticate service // Callback handles the result of a successful call to the authenticate service
// and is responsible setting returned per-route session. // and is responsible setting returned per-route session.
func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) { func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error {
redirectURLString := r.FormValue(urlutil.QueryRedirectURI) redirectURLString := r.FormValue(urlutil.QueryRedirectURI)
encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted) encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted)
if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil { if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
httputil.Redirect(w, r, redirectURLString, http.StatusFound) httputil.Redirect(w, r, redirectURLString, http.StatusFound)
return nil
} }
// saveCallbackSession takes an encrypted per-route session token, and decrypts // saveCallbackSession takes an encrypted per-route session token, and decrypts
@ -172,11 +172,10 @@ func (p *Proxy) saveCallbackSession(w http.ResponseWriter, r *http.Request, enct
// ProgrammaticLogin returns a signed url that can be used to login // ProgrammaticLogin returns a signed url that can be used to login
// using the authenticate service. // using the authenticate service.
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) { func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error {
redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect uri", http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
signinURL := *p.authenticateSigninURL signinURL := *p.authenticateSigninURL
callbackURI := urlutil.GetAbsoluteURL(r) callbackURI := urlutil.GetAbsoluteURL(r)
@ -191,31 +190,30 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(response)) w.Write([]byte(response))
return nil
} }
// ProgrammaticCallback handles a successful call to the authenticate service. // ProgrammaticCallback handles a successful call to the authenticate service.
// In addition to returning the individual route session (JWT) it also returns // In addition to returning the individual route session (JWT) it also returns
// the refresh token. // the refresh token.
func (p *Proxy) ProgrammaticCallback(w http.ResponseWriter, r *http.Request) { func (p *Proxy) ProgrammaticCallback(w http.ResponseWriter, r *http.Request) error {
redirectURLString := r.FormValue(urlutil.QueryRedirectURI) redirectURLString := r.FormValue(urlutil.QueryRedirectURI)
encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted) encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted)
redirectURL, err := urlutil.ParseAndValidateURL(redirectURLString) redirectURL, err := urlutil.ParseAndValidateURL(redirectURLString)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect uri", http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
rawJWT, err := p.saveCallbackSession(w, r, encryptedSession) rawJWT, err := p.saveCallbackSession(w, r, encryptedSession)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) return httputil.NewError(http.StatusBadRequest, err)
return
} }
q := redirectURL.Query() q := redirectURL.Query()
q.Set(urlutil.QueryPomeriumJWT, string(rawJWT)) q.Set(urlutil.QueryPomeriumJWT, string(rawJWT))
q.Set(urlutil.QueryRefreshToken, r.FormValue(urlutil.QueryRefreshToken)) q.Set(urlutil.QueryRefreshToken, r.FormValue(urlutil.QueryRefreshToken))
redirectURL.RawQuery = q.Encode() redirectURL.RawQuery = q.Encode()
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound) httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
return nil
} }

View file

@ -103,7 +103,7 @@ func TestProxy_UserDashboard(t *testing.T) {
r.Header.Set("Accept", "application/json") r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
p.UserDashboard(w, r) httputil.HandlerFunc(p.UserDashboard).ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus { if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus) t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", opts) t.Errorf("\n%+v", opts)
@ -139,7 +139,7 @@ func TestProxy_Impersonate(t *testing.T) {
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, {"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, {"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden}, {"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden}, {"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusInternalServerError},
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, {"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
} }
for _, tt := range tests { for _, tt := range tests {
@ -165,7 +165,7 @@ func TestProxy_Impersonate(t *testing.T) {
r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
w := httptest.NewRecorder() w := httptest.NewRecorder()
p.Impersonate(w, r) httputil.HandlerFunc(p.Impersonate).ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus { if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus) t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", opts) t.Errorf("\n%+v", opts)
@ -289,7 +289,7 @@ func TestProxy_Callback(t *testing.T) {
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
p.Callback(w, r) httputil.HandlerFunc(p.Callback).ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus { if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus) t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", w.Body.String()) t.Errorf("\n%+v", w.Body.String())
@ -326,7 +326,7 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
{"good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusOK, ""}, {"good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusOK, ""},
{"good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusOK, ""}, {"good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusOK, ""},
{"router miss, bad redirect_uri query", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{"bad_redirect_uri": "http://localhost"}, http.StatusNotFound, ""}, {"router miss, bad redirect_uri query", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{"bad_redirect_uri": "http://localhost"}, http.StatusNotFound, ""},
{"bad redirect_uri missing scheme", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "localhost"}, http.StatusBadRequest, "{\"error\":\"malformed redirect uri\"}\n"}, {"bad redirect_uri missing scheme", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "localhost"}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: localhost url does contain a valid scheme\"}\n"},
{"bad http method", opts, http.MethodPost, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusMethodNotAllowed, ""}, {"bad http method", opts, http.MethodPost, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusMethodNotAllowed, ""},
} }
for _, tt := range tests { for _, tt := range tests {
@ -430,7 +430,7 @@ func TestProxy_ProgrammaticCallback(t *testing.T) {
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
p.ProgrammaticCallback(w, r) httputil.HandlerFunc(p.ProgrammaticCallback).ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus { if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus) t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", w.Body.String()) t.Errorf("\n%+v", w.Body.String())

View file

@ -26,7 +26,7 @@ const (
// AuthenticateSession is middleware to enforce a valid authentication // AuthenticateSession is middleware to enforce a valid authentication
// session state is retrieved from the users's request context. // session state is retrieved from the users's request context.
func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler { func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession") ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession")
defer span.End() defer span.End()
@ -34,18 +34,17 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
log.FromRequest(r).Debug().Err(err).Msg("proxy: authenticate session") log.FromRequest(r).Debug().Err(err).Msg("proxy: authenticate session")
p.sessionStore.ClearSession(w, r) p.sessionStore.ClearSession(w, r)
if s != nil && s.Programmatic { if s != nil && s.Programmatic {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err)) return httputil.NewError(http.StatusUnauthorized, err)
return
} }
signinURL := *p.authenticateSigninURL signinURL := *p.authenticateSigninURL
q := signinURL.Query() q := signinURL.Query()
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String()) q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String())
signinURL.RawQuery = q.Encode() signinURL.RawQuery = q.Encode()
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound) httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
return
} }
p.addPomeriumHeaders(w, r) p.addPomeriumHeaders(w, r)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
return nil
}) })
} }
@ -65,31 +64,28 @@ func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) {
// AuthorizeSession is middleware to enforce a user is authorized for a request // AuthorizeSession is middleware to enforce a user is authorized for a request
// session state is retrieved from the users's request context. // session state is retrieved from the users's request context.
func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler { func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthorizeSession") ctx, span := trace.StartSpan(r.Context(), "proxy.AuthorizeSession")
defer span.End() defer span.End()
if err := p.authorize(r.Host, w, r.WithContext(ctx)); err != nil { if err := p.authorize(r.Host, r.WithContext(ctx)); err != nil {
log.FromRequest(r).Debug().Err(err).Msg("proxy: AuthorizeSession") log.FromRequest(r).Debug().Err(err).Msg("proxy: AuthorizeSession")
return return err
} }
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
return nil
}) })
} }
func (p *Proxy) authorize(host string, w http.ResponseWriter, r *http.Request) error { func (p *Proxy) authorize(host string, r *http.Request) error {
s, err := sessions.FromContext(r.Context()) s, err := sessions.FromContext(r.Context())
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusUnauthorized, err)) return httputil.NewError(http.StatusUnauthorized, err)
return err
} }
authorized, err := p.AuthorizeClient.Authorize(r.Context(), host, s) authorized, err := p.AuthorizeClient.Authorize(r.Context(), host, s)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err)
return err return err
} else if !authorized { } else if !authorized {
err = fmt.Errorf("%s is not authorized for %s", s.RequestEmail(), host) return httputil.NewError(http.StatusUnauthorized, fmt.Errorf("%s is not authorized for %s", s.RequestEmail(), host))
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
return err
} }
return nil return nil
} }
@ -98,13 +94,12 @@ func (p *Proxy) authorize(host string, w http.ResponseWriter, r *http.Request) e
// email, and group. Session state is retrieved from the users's request context // email, and group. Session state is retrieved from the users's request context
func (p *Proxy) SignRequest(signer encoding.Marshaler) func(next http.Handler) http.Handler { func (p *Proxy) SignRequest(signer encoding.Marshaler) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
ctx, span := trace.StartSpan(r.Context(), "proxy.SignRequest") ctx, span := trace.StartSpan(r.Context(), "proxy.SignRequest")
defer span.End() defer span.End()
s, err := sessions.FromContext(r.Context()) s, err := sessions.FromContext(r.Context())
if err != nil { if err != nil {
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err)) return httputil.NewError(http.StatusForbidden, err)
return
} }
newSession := s.NewSession(r.Host, []string{r.Host}) newSession := s.NewSession(r.Host, []string{r.Host})
jwt, err := signer.Marshal(newSession.RouteSession()) jwt, err := signer.Marshal(newSession.RouteSession())
@ -115,6 +110,7 @@ func (p *Proxy) SignRequest(signer encoding.Marshaler) func(next http.Handler) h
w.Header().Set(HeaderJWT, string(jwt)) w.Header().Set(HeaderJWT, string(jwt))
} }
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
return nil
}) })
} }
} }

View file

@ -176,8 +176,8 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error {
log.Warn().Msg("proxy: configuration has no policies") log.Warn().Msg("proxy: configuration has no policies")
} }
r := httputil.NewRouter() r := httputil.NewRouter()
r.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.NotFoundHandler = httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s route unknown", r.Host), http.StatusNotFound, nil)) return httputil.NewError(http.StatusNotFound, fmt.Errorf("%s route unknown", r.Host))
}) })
r.SkipClean(true) r.SkipClean(true)
r.StrictSlash(true) r.StrictSlash(true)