Add configurable JWT claim headers (#596)

This commit is contained in:
Travis Groth 2020-04-09 23:41:55 -04:00 committed by GitHub
parent b08ecc624a
commit 789068e27a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 118 additions and 17 deletions

View file

@ -155,25 +155,50 @@ func SetResponseHeaders(headers map[string]string) func(next http.Handler) http.
}
}
func (p *Proxy) userDetailsLoggerMiddleware(next http.Handler) http.Handler {
func (p *Proxy) jwtClaimMiddleware(next http.Handler) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
if jwt, err := sessions.FromContext(r.Context()); err == nil {
var s sessions.State
if err := p.encoder.Unmarshal([]byte(jwt), &s); err == nil {
var jwtClaims map[string]interface{}
if err := p.encoder.Unmarshal([]byte(jwt), &jwtClaims); err == nil {
formattedJWTClaims := make(map[string]string)
// reformat claims into something resembling map[string]string
for claim, value := range jwtClaims {
var formattedClaim string
if cv, ok := value.([]interface{}); ok {
elements := make([]string, len(cv))
for i, v := range cv {
elements[i] = fmt.Sprintf("%v", v)
}
formattedClaim = strings.Join(elements, ",")
} else {
formattedClaim = fmt.Sprintf("%v", value)
}
formattedJWTClaims[claim] = formattedClaim
}
// log group, email, user claims
l := log.Ctx(r.Context())
l.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Strs("groups", s.Groups)
})
l.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("email", s.Email)
})
l.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("user-id", s.User)
})
for _, claimName := range []string{"groups", "email", "user"} {
l.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(claimName, fmt.Sprintf("%v", formattedJWTClaims[claimName]))
})
}
// set headers for any claims specified by config
for _, claimName := range p.jwtClaimHeaders {
if _, ok := formattedJWTClaims[claimName]; ok {
headerName := fmt.Sprintf("x-pomerium-claim-%s", claimName)
r.Header.Set(headerName, formattedJWTClaims[claimName])
}
}
}
}
next.ServeHTTP(w, r)
return nil
})
}