From b2e3b22f14c15059689fd44d9b4c87c0b6e066e7 Mon Sep 17 00:00:00 2001 From: Travis Groth Date: Mon, 4 May 2020 07:26:37 -0400 Subject: [PATCH] Update JWT headers to only be in responses from forward auth endpoint (#642) --- proxy/forward_auth.go | 7 ++- proxy/middleware.go | 127 +++++++++++++++++++++++++-------------- proxy/middleware_test.go | 11 +++- proxy/proxy.go | 2 +- 4 files changed, 96 insertions(+), 51 deletions(-) diff --git a/proxy/forward_auth.go b/proxy/forward_auth.go index d7fe72226..ab41bc221 100644 --- a/proxy/forward_auth.go +++ b/proxy/forward_auth.go @@ -19,6 +19,7 @@ func (p *Proxy) registerFwdAuthHandlers() http.Handler { r := httputil.NewRouter() r.StrictSlash(true) r.Use(sessions.RetrieveSession(p.sessionStore)) + r.Use(p.jwtClaimMiddleware(true)) // NGNIX's forward-auth capabilities are split across two settings: // `auth-url` and `auth-signin` which correspond to `verify` and `auth-url` @@ -117,7 +118,8 @@ func (p *Proxy) Verify(verifyOnly bool) http.Handler { } originalRequest := p.getOriginalRequest(r, uri) - if err := p.authorize(w, originalRequest); err != nil { + authz, err := p.authorize(w, originalRequest) + if err != nil { // no session, so redirect if _, err := sessions.FromContext(r.Context()); err != nil { if verifyOnly { @@ -132,10 +134,11 @@ func (p *Proxy) Verify(verifyOnly bool) http.Handler { httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &authN).String(), http.StatusFound) return nil } - return err } + w.Header().Set(httputil.HeaderPomeriumJWTAssertion, authz.GetSignedJwt()) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.WriteHeader(http.StatusOK) fmt.Fprintf(w, "Access to %s is allowed.", uri.Host) diff --git a/proxy/middleware.go b/proxy/middleware.go index 555477b50..9fb18a425 100644 --- a/proxy/middleware.go +++ b/proxy/middleware.go @@ -9,8 +9,10 @@ import ( "net/http" "strings" + "github.com/gorilla/mux" "github.com/rs/zerolog" + "github.com/pomerium/pomerium/internal/grpc/authorize" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/sessions" @@ -92,36 +94,42 @@ func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { ctx, span := trace.StartSpan(r.Context(), "proxy.AuthorizeSession") defer span.End() - if err := p.authorize(w, r); err != nil { + authz, err := p.authorize(w, r) + if err != nil { return err } + + r.Header.Set(httputil.HeaderPomeriumJWTAssertion, authz.GetSignedJwt()) next.ServeHTTP(w, r.WithContext(ctx)) return nil }) } -func (p *Proxy) authorize(w http.ResponseWriter, r *http.Request) error { +func (p *Proxy) authorize(w http.ResponseWriter, r *http.Request) (*authorize.IsAuthorizedReply, error) { ctx, span := trace.StartSpan(r.Context(), "proxy.authorize") defer span.End() + jwt, _ := sessions.FromContext(ctx) + authz, err := p.AuthorizeClient.Authorize(ctx, jwt, r) if err != nil { - return httputil.NewError(http.StatusInternalServerError, err) + return nil, httputil.NewError(http.StatusInternalServerError, err) } + if authz.GetSessionExpired() { newJwt, err := p.refresh(ctx, jwt) if err != nil { p.sessionStore.ClearSession(w, r) log.FromRequest(r).Warn().Err(err).Msg("proxy: refresh failed") - return p.redirectToSignin(w, r) + return nil, p.redirectToSignin(w, r) } if err = p.sessionStore.SaveSession(w, r, newJwt); err != nil { - return httputil.NewError(http.StatusUnauthorized, err) + return nil, httputil.NewError(http.StatusUnauthorized, err) } authz, err = p.AuthorizeClient.Authorize(ctx, newJwt, r) if err != nil { - return httputil.NewError(http.StatusUnauthorized, err) + return nil, httputil.NewError(http.StatusUnauthorized, err) } } if !authz.GetAllow() { @@ -130,12 +138,11 @@ func (p *Proxy) authorize(w http.ResponseWriter, r *http.Request) error { Bool("allow", authz.GetAllow()). Bool("expired", authz.GetSessionExpired()). Msg("proxy/authorize: deny") - return httputil.NewError(http.StatusForbidden, errors.New("request denied")) + return nil, httputil.NewError(http.StatusForbidden, errors.New("request denied")) } - r.Header.Set(httputil.HeaderPomeriumJWTAssertion, authz.GetSignedJwt()) - w.Header().Set(httputil.HeaderPomeriumJWTAssertion, authz.GetSignedJwt()) - return nil + return authz, nil + } // SetResponseHeaders sets a map of response headers. @@ -152,50 +159,78 @@ func SetResponseHeaders(headers map[string]string) func(next http.Handler) http. } } -func (p *Proxy) jwtClaimMiddleware(next http.Handler) http.Handler { - return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { - if jwt, err := sessions.FromContext(r.Context()); err == nil { - var jwtClaims map[string]interface{} - if err := p.encoder.Unmarshal([]byte(jwt), &jwtClaims); err == nil { - formattedJWTClaims := make(map[string]string) +// jwtClaimMiddleware logs and propagates JWT claim information via request headers +// +// if returnJWTInfo is set to true, it will also return JWT claim information in the response +func (p *Proxy) jwtClaimMiddleware(returnJWTInfo bool) mux.MiddlewareFunc { - // reformat claims into something resembling map[string]string - for claim, value := range jwtClaims { - var formattedClaim string - if cv, ok := value.([]interface{}); ok { - elements := make([]string, len(cv)) + return func(next http.Handler) http.Handler { - for i, v := range cv { - elements[i] = fmt.Sprintf("%v", v) - } - formattedClaim = strings.Join(elements, ",") - } else { - formattedClaim = fmt.Sprintf("%v", value) - } - formattedJWTClaims[claim] = formattedClaim - } + return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + defer next.ServeHTTP(w, r) - // log group, email, user claims - l := log.Ctx(r.Context()) - for _, claimName := range []string{"groups", "email", "user"} { + jwt, err := sessions.FromContext(r.Context()) + if err != nil { + log.Error().Err(err).Msg("proxy: could not locate session from context") + return nil // best effort decoding + } - l.UpdateContext(func(c zerolog.Context) zerolog.Context { - return c.Str(claimName, fmt.Sprintf("%v", formattedJWTClaims[claimName])) - }) + formattedJWTClaims, err := p.getFormatedJWTClaims([]byte(jwt)) + if err != nil { + log.Error().Err(err).Msg("proxy: failed to format jwt claims") + return nil // best effort formatting + } - } + // log group, email, user claims + l := log.Ctx(r.Context()) + for _, claimName := range []string{"groups", "email", "user"} { - // set headers for any claims specified by config - for _, claimName := range p.jwtClaimHeaders { - if _, ok := formattedJWTClaims[claimName]; ok { + l.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(claimName, fmt.Sprintf("%v", formattedJWTClaims[claimName])) + }) - headerName := fmt.Sprintf("x-pomerium-claim-%s", claimName) - r.Header.Set(headerName, formattedJWTClaims[claimName]) + } + + // set headers for any claims specified by config + for _, claimName := range p.jwtClaimHeaders { + if _, ok := formattedJWTClaims[claimName]; ok { + + headerName := fmt.Sprintf("x-pomerium-claim-%s", claimName) + r.Header.Set(headerName, formattedJWTClaims[claimName]) + if returnJWTInfo { + w.Header().Add(headerName, formattedJWTClaims[claimName]) } } } - } - next.ServeHTTP(w, r) - return nil - }) + + return nil + }) + } +} + +// getFormatJWTClaims reformats jwtClaims into something resembling map[string]string +func (p *Proxy) getFormatedJWTClaims(jwt []byte) (map[string]string, error) { + formattedJWTClaims := make(map[string]string) + + var jwtClaims map[string]interface{} + if err := p.encoder.Unmarshal(jwt, &jwtClaims); err != nil { + return formattedJWTClaims, err + } + + for claim, value := range jwtClaims { + var formattedClaim string + if cv, ok := value.([]interface{}); ok { + elements := make([]string, len(cv)) + + for i, v := range cv { + elements[i] = fmt.Sprintf("%v", v) + } + formattedClaim = strings.Join(elements, ",") + } else { + formattedClaim = fmt.Sprintf("%v", value) + } + formattedJWTClaims[claim] = formattedClaim + } + + return formattedJWTClaims, nil } diff --git a/proxy/middleware_test.go b/proxy/middleware_test.go index 174f04768..9fd6a550d 100644 --- a/proxy/middleware_test.go +++ b/proxy/middleware_test.go @@ -73,7 +73,7 @@ func TestProxy_AuthenticateSession(t *testing.T) { r = r.WithContext(ctx) r.Header.Set("Accept", "application/json") w := httptest.NewRecorder() - got := a.jwtClaimMiddleware(a.AuthenticateSession(fn)) + got := a.jwtClaimMiddleware(false)(a.AuthenticateSession(fn)) got.ServeHTTP(w, r) if status := w.Code; status != tt.wantStatus { t.Errorf("AuthenticateSession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) @@ -113,7 +113,7 @@ func Test_jwtClaimMiddleware(t *testing.T) { ctx = sessions.NewContext(ctx, string(state), nil) r = r.WithContext(ctx) w := httptest.NewRecorder() - proxyHandler := a.jwtClaimMiddleware(handler) + proxyHandler := a.jwtClaimMiddleware(true)(handler) proxyHandler.ServeHTTP(w, r) t.Run("email claim", func(t *testing.T) { @@ -130,6 +130,13 @@ func Test_jwtClaimMiddleware(t *testing.T) { } }) + t.Run("email response claim", func(t *testing.T) { + emailHeader := w.Header().Get("x-pomerium-claim-email") + if emailHeader != email { + t.Errorf("did not find claim email in response, want=%q, got=%q", email, emailHeader) + } + }) + t.Run("missing claim", func(t *testing.T) { absentHeader := r.Header.Get("x-pomerium-claim-missing") if absentHeader != "" { diff --git a/proxy/proxy.go b/proxy/proxy.go index 43339000a..34e057c90 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -274,7 +274,7 @@ func (p *Proxy) reverseProxyHandler(r *mux.Router, policy config.Policy) *mux.Ro // 7. Strip the user session cookie from the downstream request rp.Use(middleware.StripCookie(p.cookieOptions.Name)) // 8 . Add claim details to the request logger context and headers - rp.Use(p.jwtClaimMiddleware) + rp.Use(p.jwtClaimMiddleware(false)) return r }