Update JWT headers to only be in responses from forward auth endpoint (#642)

This commit is contained in:
Travis Groth 2020-05-04 07:26:37 -04:00 committed by GitHub
parent f7ee08b05a
commit b2e3b22f14
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 96 additions and 51 deletions

View file

@ -19,6 +19,7 @@ func (p *Proxy) registerFwdAuthHandlers() http.Handler {
r := httputil.NewRouter()
r.StrictSlash(true)
r.Use(sessions.RetrieveSession(p.sessionStore))
r.Use(p.jwtClaimMiddleware(true))
// NGNIX's forward-auth capabilities are split across two settings:
// `auth-url` and `auth-signin` which correspond to `verify` and `auth-url`
@ -117,7 +118,8 @@ func (p *Proxy) Verify(verifyOnly bool) http.Handler {
}
originalRequest := p.getOriginalRequest(r, uri)
if err := p.authorize(w, originalRequest); err != nil {
authz, err := p.authorize(w, originalRequest)
if err != nil {
// no session, so redirect
if _, err := sessions.FromContext(r.Context()); err != nil {
if verifyOnly {
@ -132,10 +134,11 @@ func (p *Proxy) Verify(verifyOnly bool) http.Handler {
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &authN).String(), http.StatusFound)
return nil
}
return err
}
w.Header().Set(httputil.HeaderPomeriumJWTAssertion, authz.GetSignedJwt())
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "Access to %s is allowed.", uri.Host)

View file

@ -9,8 +9,10 @@ import (
"net/http"
"strings"
"github.com/gorilla/mux"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/internal/grpc/authorize"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
@ -92,36 +94,42 @@ func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthorizeSession")
defer span.End()
if err := p.authorize(w, r); err != nil {
authz, err := p.authorize(w, r)
if err != nil {
return err
}
r.Header.Set(httputil.HeaderPomeriumJWTAssertion, authz.GetSignedJwt())
next.ServeHTTP(w, r.WithContext(ctx))
return nil
})
}
func (p *Proxy) authorize(w http.ResponseWriter, r *http.Request) error {
func (p *Proxy) authorize(w http.ResponseWriter, r *http.Request) (*authorize.IsAuthorizedReply, error) {
ctx, span := trace.StartSpan(r.Context(), "proxy.authorize")
defer span.End()
jwt, _ := sessions.FromContext(ctx)
authz, err := p.AuthorizeClient.Authorize(ctx, jwt, r)
if err != nil {
return httputil.NewError(http.StatusInternalServerError, err)
return nil, httputil.NewError(http.StatusInternalServerError, err)
}
if authz.GetSessionExpired() {
newJwt, err := p.refresh(ctx, jwt)
if err != nil {
p.sessionStore.ClearSession(w, r)
log.FromRequest(r).Warn().Err(err).Msg("proxy: refresh failed")
return p.redirectToSignin(w, r)
return nil, p.redirectToSignin(w, r)
}
if err = p.sessionStore.SaveSession(w, r, newJwt); err != nil {
return httputil.NewError(http.StatusUnauthorized, err)
return nil, httputil.NewError(http.StatusUnauthorized, err)
}
authz, err = p.AuthorizeClient.Authorize(ctx, newJwt, r)
if err != nil {
return httputil.NewError(http.StatusUnauthorized, err)
return nil, httputil.NewError(http.StatusUnauthorized, err)
}
}
if !authz.GetAllow() {
@ -130,12 +138,11 @@ func (p *Proxy) authorize(w http.ResponseWriter, r *http.Request) error {
Bool("allow", authz.GetAllow()).
Bool("expired", authz.GetSessionExpired()).
Msg("proxy/authorize: deny")
return httputil.NewError(http.StatusForbidden, errors.New("request denied"))
return nil, httputil.NewError(http.StatusForbidden, errors.New("request denied"))
}
r.Header.Set(httputil.HeaderPomeriumJWTAssertion, authz.GetSignedJwt())
w.Header().Set(httputil.HeaderPomeriumJWTAssertion, authz.GetSignedJwt())
return nil
return authz, nil
}
// SetResponseHeaders sets a map of response headers.
@ -152,50 +159,78 @@ func SetResponseHeaders(headers map[string]string) func(next http.Handler) http.
}
}
func (p *Proxy) jwtClaimMiddleware(next http.Handler) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
if jwt, err := sessions.FromContext(r.Context()); err == nil {
var jwtClaims map[string]interface{}
if err := p.encoder.Unmarshal([]byte(jwt), &jwtClaims); err == nil {
formattedJWTClaims := make(map[string]string)
// jwtClaimMiddleware logs and propagates JWT claim information via request headers
//
// if returnJWTInfo is set to true, it will also return JWT claim information in the response
func (p *Proxy) jwtClaimMiddleware(returnJWTInfo bool) mux.MiddlewareFunc {
// reformat claims into something resembling map[string]string
for claim, value := range jwtClaims {
var formattedClaim string
if cv, ok := value.([]interface{}); ok {
elements := make([]string, len(cv))
return func(next http.Handler) http.Handler {
for i, v := range cv {
elements[i] = fmt.Sprintf("%v", v)
}
formattedClaim = strings.Join(elements, ",")
} else {
formattedClaim = fmt.Sprintf("%v", value)
}
formattedJWTClaims[claim] = formattedClaim
}
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
defer next.ServeHTTP(w, r)
// log group, email, user claims
l := log.Ctx(r.Context())
for _, claimName := range []string{"groups", "email", "user"} {
jwt, err := sessions.FromContext(r.Context())
if err != nil {
log.Error().Err(err).Msg("proxy: could not locate session from context")
return nil // best effort decoding
}
l.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(claimName, fmt.Sprintf("%v", formattedJWTClaims[claimName]))
})
formattedJWTClaims, err := p.getFormatedJWTClaims([]byte(jwt))
if err != nil {
log.Error().Err(err).Msg("proxy: failed to format jwt claims")
return nil // best effort formatting
}
}
// log group, email, user claims
l := log.Ctx(r.Context())
for _, claimName := range []string{"groups", "email", "user"} {
// set headers for any claims specified by config
for _, claimName := range p.jwtClaimHeaders {
if _, ok := formattedJWTClaims[claimName]; ok {
l.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(claimName, fmt.Sprintf("%v", formattedJWTClaims[claimName]))
})
headerName := fmt.Sprintf("x-pomerium-claim-%s", claimName)
r.Header.Set(headerName, formattedJWTClaims[claimName])
}
// set headers for any claims specified by config
for _, claimName := range p.jwtClaimHeaders {
if _, ok := formattedJWTClaims[claimName]; ok {
headerName := fmt.Sprintf("x-pomerium-claim-%s", claimName)
r.Header.Set(headerName, formattedJWTClaims[claimName])
if returnJWTInfo {
w.Header().Add(headerName, formattedJWTClaims[claimName])
}
}
}
}
next.ServeHTTP(w, r)
return nil
})
return nil
})
}
}
// getFormatJWTClaims reformats jwtClaims into something resembling map[string]string
func (p *Proxy) getFormatedJWTClaims(jwt []byte) (map[string]string, error) {
formattedJWTClaims := make(map[string]string)
var jwtClaims map[string]interface{}
if err := p.encoder.Unmarshal(jwt, &jwtClaims); err != nil {
return formattedJWTClaims, err
}
for claim, value := range jwtClaims {
var formattedClaim string
if cv, ok := value.([]interface{}); ok {
elements := make([]string, len(cv))
for i, v := range cv {
elements[i] = fmt.Sprintf("%v", v)
}
formattedClaim = strings.Join(elements, ",")
} else {
formattedClaim = fmt.Sprintf("%v", value)
}
formattedJWTClaims[claim] = formattedClaim
}
return formattedJWTClaims, nil
}

View file

@ -73,7 +73,7 @@ func TestProxy_AuthenticateSession(t *testing.T) {
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
got := a.jwtClaimMiddleware(a.AuthenticateSession(fn))
got := a.jwtClaimMiddleware(false)(a.AuthenticateSession(fn))
got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("AuthenticateSession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
@ -113,7 +113,7 @@ func Test_jwtClaimMiddleware(t *testing.T) {
ctx = sessions.NewContext(ctx, string(state), nil)
r = r.WithContext(ctx)
w := httptest.NewRecorder()
proxyHandler := a.jwtClaimMiddleware(handler)
proxyHandler := a.jwtClaimMiddleware(true)(handler)
proxyHandler.ServeHTTP(w, r)
t.Run("email claim", func(t *testing.T) {
@ -130,6 +130,13 @@ func Test_jwtClaimMiddleware(t *testing.T) {
}
})
t.Run("email response claim", func(t *testing.T) {
emailHeader := w.Header().Get("x-pomerium-claim-email")
if emailHeader != email {
t.Errorf("did not find claim email in response, want=%q, got=%q", email, emailHeader)
}
})
t.Run("missing claim", func(t *testing.T) {
absentHeader := r.Header.Get("x-pomerium-claim-missing")
if absentHeader != "" {

View file

@ -274,7 +274,7 @@ func (p *Proxy) reverseProxyHandler(r *mux.Router, policy config.Policy) *mux.Ro
// 7. Strip the user session cookie from the downstream request
rp.Use(middleware.StripCookie(p.cookieOptions.Name))
// 8 . Add claim details to the request logger context and headers
rp.Use(p.jwtClaimMiddleware)
rp.Use(p.jwtClaimMiddleware(false))
return r
}