mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
Update JWT headers to only be in responses from forward auth endpoint (#642)
This commit is contained in:
parent
f7ee08b05a
commit
b2e3b22f14
4 changed files with 96 additions and 51 deletions
|
@ -19,6 +19,7 @@ func (p *Proxy) registerFwdAuthHandlers() http.Handler {
|
|||
r := httputil.NewRouter()
|
||||
r.StrictSlash(true)
|
||||
r.Use(sessions.RetrieveSession(p.sessionStore))
|
||||
r.Use(p.jwtClaimMiddleware(true))
|
||||
|
||||
// NGNIX's forward-auth capabilities are split across two settings:
|
||||
// `auth-url` and `auth-signin` which correspond to `verify` and `auth-url`
|
||||
|
@ -117,7 +118,8 @@ func (p *Proxy) Verify(verifyOnly bool) http.Handler {
|
|||
}
|
||||
originalRequest := p.getOriginalRequest(r, uri)
|
||||
|
||||
if err := p.authorize(w, originalRequest); err != nil {
|
||||
authz, err := p.authorize(w, originalRequest)
|
||||
if err != nil {
|
||||
// no session, so redirect
|
||||
if _, err := sessions.FromContext(r.Context()); err != nil {
|
||||
if verifyOnly {
|
||||
|
@ -132,10 +134,11 @@ func (p *Proxy) Verify(verifyOnly bool) http.Handler {
|
|||
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &authN).String(), http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
w.Header().Set(httputil.HeaderPomeriumJWTAssertion, authz.GetSignedJwt())
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, "Access to %s is allowed.", uri.Host)
|
||||
|
|
|
@ -9,8 +9,10 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/grpc/authorize"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
|
@ -92,36 +94,42 @@ func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
|
|||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthorizeSession")
|
||||
defer span.End()
|
||||
if err := p.authorize(w, r); err != nil {
|
||||
authz, err := p.authorize(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Header.Set(httputil.HeaderPomeriumJWTAssertion, authz.GetSignedJwt())
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Proxy) authorize(w http.ResponseWriter, r *http.Request) error {
|
||||
func (p *Proxy) authorize(w http.ResponseWriter, r *http.Request) (*authorize.IsAuthorizedReply, error) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "proxy.authorize")
|
||||
defer span.End()
|
||||
|
||||
jwt, _ := sessions.FromContext(ctx)
|
||||
|
||||
authz, err := p.AuthorizeClient.Authorize(ctx, jwt, r)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, err)
|
||||
return nil, httputil.NewError(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
if authz.GetSessionExpired() {
|
||||
newJwt, err := p.refresh(ctx, jwt)
|
||||
if err != nil {
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
log.FromRequest(r).Warn().Err(err).Msg("proxy: refresh failed")
|
||||
return p.redirectToSignin(w, r)
|
||||
return nil, p.redirectToSignin(w, r)
|
||||
}
|
||||
if err = p.sessionStore.SaveSession(w, r, newJwt); err != nil {
|
||||
return httputil.NewError(http.StatusUnauthorized, err)
|
||||
return nil, httputil.NewError(http.StatusUnauthorized, err)
|
||||
}
|
||||
|
||||
authz, err = p.AuthorizeClient.Authorize(ctx, newJwt, r)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusUnauthorized, err)
|
||||
return nil, httputil.NewError(http.StatusUnauthorized, err)
|
||||
}
|
||||
}
|
||||
if !authz.GetAllow() {
|
||||
|
@ -130,12 +138,11 @@ func (p *Proxy) authorize(w http.ResponseWriter, r *http.Request) error {
|
|||
Bool("allow", authz.GetAllow()).
|
||||
Bool("expired", authz.GetSessionExpired()).
|
||||
Msg("proxy/authorize: deny")
|
||||
return httputil.NewError(http.StatusForbidden, errors.New("request denied"))
|
||||
return nil, httputil.NewError(http.StatusForbidden, errors.New("request denied"))
|
||||
}
|
||||
|
||||
r.Header.Set(httputil.HeaderPomeriumJWTAssertion, authz.GetSignedJwt())
|
||||
w.Header().Set(httputil.HeaderPomeriumJWTAssertion, authz.GetSignedJwt())
|
||||
return nil
|
||||
return authz, nil
|
||||
|
||||
}
|
||||
|
||||
// SetResponseHeaders sets a map of response headers.
|
||||
|
@ -152,50 +159,78 @@ func SetResponseHeaders(headers map[string]string) func(next http.Handler) http.
|
|||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) jwtClaimMiddleware(next http.Handler) http.Handler {
|
||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
if jwt, err := sessions.FromContext(r.Context()); err == nil {
|
||||
var jwtClaims map[string]interface{}
|
||||
if err := p.encoder.Unmarshal([]byte(jwt), &jwtClaims); err == nil {
|
||||
formattedJWTClaims := make(map[string]string)
|
||||
// jwtClaimMiddleware logs and propagates JWT claim information via request headers
|
||||
//
|
||||
// if returnJWTInfo is set to true, it will also return JWT claim information in the response
|
||||
func (p *Proxy) jwtClaimMiddleware(returnJWTInfo bool) mux.MiddlewareFunc {
|
||||
|
||||
// reformat claims into something resembling map[string]string
|
||||
for claim, value := range jwtClaims {
|
||||
var formattedClaim string
|
||||
if cv, ok := value.([]interface{}); ok {
|
||||
elements := make([]string, len(cv))
|
||||
return func(next http.Handler) http.Handler {
|
||||
|
||||
for i, v := range cv {
|
||||
elements[i] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
formattedClaim = strings.Join(elements, ",")
|
||||
} else {
|
||||
formattedClaim = fmt.Sprintf("%v", value)
|
||||
}
|
||||
formattedJWTClaims[claim] = formattedClaim
|
||||
}
|
||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
defer next.ServeHTTP(w, r)
|
||||
|
||||
// log group, email, user claims
|
||||
l := log.Ctx(r.Context())
|
||||
for _, claimName := range []string{"groups", "email", "user"} {
|
||||
jwt, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("proxy: could not locate session from context")
|
||||
return nil // best effort decoding
|
||||
}
|
||||
|
||||
l.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(claimName, fmt.Sprintf("%v", formattedJWTClaims[claimName]))
|
||||
})
|
||||
formattedJWTClaims, err := p.getFormatedJWTClaims([]byte(jwt))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("proxy: failed to format jwt claims")
|
||||
return nil // best effort formatting
|
||||
}
|
||||
|
||||
}
|
||||
// log group, email, user claims
|
||||
l := log.Ctx(r.Context())
|
||||
for _, claimName := range []string{"groups", "email", "user"} {
|
||||
|
||||
// set headers for any claims specified by config
|
||||
for _, claimName := range p.jwtClaimHeaders {
|
||||
if _, ok := formattedJWTClaims[claimName]; ok {
|
||||
l.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(claimName, fmt.Sprintf("%v", formattedJWTClaims[claimName]))
|
||||
})
|
||||
|
||||
headerName := fmt.Sprintf("x-pomerium-claim-%s", claimName)
|
||||
r.Header.Set(headerName, formattedJWTClaims[claimName])
|
||||
}
|
||||
|
||||
// set headers for any claims specified by config
|
||||
for _, claimName := range p.jwtClaimHeaders {
|
||||
if _, ok := formattedJWTClaims[claimName]; ok {
|
||||
|
||||
headerName := fmt.Sprintf("x-pomerium-claim-%s", claimName)
|
||||
r.Header.Set(headerName, formattedJWTClaims[claimName])
|
||||
if returnJWTInfo {
|
||||
w.Header().Add(headerName, formattedJWTClaims[claimName])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// getFormatJWTClaims reformats jwtClaims into something resembling map[string]string
|
||||
func (p *Proxy) getFormatedJWTClaims(jwt []byte) (map[string]string, error) {
|
||||
formattedJWTClaims := make(map[string]string)
|
||||
|
||||
var jwtClaims map[string]interface{}
|
||||
if err := p.encoder.Unmarshal(jwt, &jwtClaims); err != nil {
|
||||
return formattedJWTClaims, err
|
||||
}
|
||||
|
||||
for claim, value := range jwtClaims {
|
||||
var formattedClaim string
|
||||
if cv, ok := value.([]interface{}); ok {
|
||||
elements := make([]string, len(cv))
|
||||
|
||||
for i, v := range cv {
|
||||
elements[i] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
formattedClaim = strings.Join(elements, ",")
|
||||
} else {
|
||||
formattedClaim = fmt.Sprintf("%v", value)
|
||||
}
|
||||
formattedJWTClaims[claim] = formattedClaim
|
||||
}
|
||||
|
||||
return formattedJWTClaims, nil
|
||||
}
|
||||
|
|
|
@ -73,7 +73,7 @@ func TestProxy_AuthenticateSession(t *testing.T) {
|
|||
r = r.WithContext(ctx)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
got := a.jwtClaimMiddleware(a.AuthenticateSession(fn))
|
||||
got := a.jwtClaimMiddleware(false)(a.AuthenticateSession(fn))
|
||||
got.ServeHTTP(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("AuthenticateSession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||
|
@ -113,7 +113,7 @@ func Test_jwtClaimMiddleware(t *testing.T) {
|
|||
ctx = sessions.NewContext(ctx, string(state), nil)
|
||||
r = r.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
proxyHandler := a.jwtClaimMiddleware(handler)
|
||||
proxyHandler := a.jwtClaimMiddleware(true)(handler)
|
||||
proxyHandler.ServeHTTP(w, r)
|
||||
|
||||
t.Run("email claim", func(t *testing.T) {
|
||||
|
@ -130,6 +130,13 @@ func Test_jwtClaimMiddleware(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("email response claim", func(t *testing.T) {
|
||||
emailHeader := w.Header().Get("x-pomerium-claim-email")
|
||||
if emailHeader != email {
|
||||
t.Errorf("did not find claim email in response, want=%q, got=%q", email, emailHeader)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing claim", func(t *testing.T) {
|
||||
absentHeader := r.Header.Get("x-pomerium-claim-missing")
|
||||
if absentHeader != "" {
|
||||
|
|
|
@ -274,7 +274,7 @@ func (p *Proxy) reverseProxyHandler(r *mux.Router, policy config.Policy) *mux.Ro
|
|||
// 7. Strip the user session cookie from the downstream request
|
||||
rp.Use(middleware.StripCookie(p.cookieOptions.Name))
|
||||
// 8 . Add claim details to the request logger context and headers
|
||||
rp.Use(p.jwtClaimMiddleware)
|
||||
rp.Use(p.jwtClaimMiddleware(false))
|
||||
|
||||
return r
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue