mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-24 03:59:49 +02:00
authenticate: refactor middleware, logging, and tests (#30)
- Abstract remaining middleware from authenticate into internal. - Use middleware chaining in authenticate. - Standardize naming of Request and ResponseWriter to match std lib. - Add healthcheck / ping as a middleware. - Internalized wraped_writer package adapted from goji/middleware. - Fixed indirection issue with reverse proxy map.
This commit is contained in:
parent
b9c298d278
commit
7e1d1a7896
21 changed files with 768 additions and 397 deletions
|
@ -1,4 +1,3 @@
|
|||
// Package middleware provides a standard set of middleware implementations for pomerium.
|
||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
import (
|
||||
|
@ -15,91 +14,76 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
)
|
||||
|
||||
// SetHeadersOld ensures that every response includes some basic security headers
|
||||
func SetHeadersOld(h http.Handler, securityHeaders map[string]string) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
for key, val := range securityHeaders {
|
||||
rw.Header().Set(key, val)
|
||||
}
|
||||
h.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
|
||||
// SetHeaders ensures that every response includes some basic security headers
|
||||
func SetHeaders(securityHeaders map[string]string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for key, val := range securityHeaders {
|
||||
rw.Header().Set(key, val)
|
||||
w.Header().Set(key, val)
|
||||
}
|
||||
next.ServeHTTP(rw, req)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WithMethods writes an error response if the method of the request is not included.
|
||||
func WithMethods(f http.HandlerFunc, methods ...string) http.HandlerFunc {
|
||||
methodMap := make(map[string]struct{}, len(methods))
|
||||
for _, m := range methods {
|
||||
methodMap[m] = struct{}{}
|
||||
}
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
if _, ok := methodMap[req.Method]; !ok {
|
||||
httputil.ErrorResponse(rw, req, fmt.Sprintf("method %s not allowed", req.Method), http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
f(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateClientSecret checks the request header for the client secret and returns
|
||||
// an error if it does not match the proxy client secret
|
||||
func ValidateClientSecret(f http.HandlerFunc, sharedSecret string) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
clientSecret := req.Form.Get("shared_secret")
|
||||
// check the request header for the client secret
|
||||
if clientSecret == "" {
|
||||
clientSecret = req.Header.Get("X-Client-Secret")
|
||||
}
|
||||
func ValidateClientSecret(sharedSecret string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
clientSecret := r.Form.Get("shared_secret")
|
||||
// check the request header for the client secret
|
||||
if clientSecret == "" {
|
||||
clientSecret = r.Header.Get("X-Client-Secret")
|
||||
}
|
||||
|
||||
if clientSecret != sharedSecret {
|
||||
httputil.ErrorResponse(rw, req, "Invalid client secret", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
f(rw, req)
|
||||
if clientSecret != sharedSecret {
|
||||
httputil.ErrorResponse(w, r, "Invalid client secret", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateRedirectURI checks the redirect uri in the query parameters and ensures that
|
||||
// the url's domain is one in the list of proxy root domains.
|
||||
func ValidateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
redirectURI := req.Form.Get("redirect_uri")
|
||||
if !validRedirectURI(redirectURI, proxyRootDomains) {
|
||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
f(rw, req)
|
||||
// the its domain is in the list of proxy root domains.
|
||||
func ValidateRedirectURI(proxyRootDomains []string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
redirectURI := r.Form.Get("redirect_uri")
|
||||
if !ValidRedirectURI(redirectURI, proxyRootDomains) {
|
||||
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func validRedirectURI(uri string, rootDomains []string) bool {
|
||||
// ValidRedirectURI checks if a URL's domain is one in the list of proxy root domains.
|
||||
func ValidRedirectURI(uri string, rootDomains []string) bool {
|
||||
if uri == "" || len(rootDomains) == 0 {
|
||||
return false
|
||||
}
|
||||
redirectURL, err := url.Parse(uri)
|
||||
if uri == "" || err != nil || redirectURL.Host == "" {
|
||||
if err != nil || redirectURL.Host == "" {
|
||||
return false
|
||||
}
|
||||
for _, domain := range rootDomains {
|
||||
if domain == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasSuffix(redirectURL.Hostname(), domain) {
|
||||
return true
|
||||
}
|
||||
|
@ -109,35 +93,36 @@ func validRedirectURI(uri string, rootDomains []string) bool {
|
|||
|
||||
// ValidateSignature ensures the request is valid and has been signed with
|
||||
// the correspdoning client secret key
|
||||
func ValidateSignature(f http.HandlerFunc, sharedSecret string) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
redirectURI := req.Form.Get("redirect_uri")
|
||||
sigVal := req.Form.Get("sig")
|
||||
timestamp := req.Form.Get("ts")
|
||||
if !validSignature(redirectURI, sigVal, timestamp, sharedSecret) {
|
||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
redirectURI := r.Form.Get("redirect_uri")
|
||||
sigVal := r.Form.Get("sig")
|
||||
timestamp := r.Form.Get("ts")
|
||||
if !ValidSignature(redirectURI, sigVal, timestamp, sharedSecret) {
|
||||
httputil.ErrorResponse(w, r, "Cross service signature failed to validate", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
f(rw, req)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateHost ensures that each request's host is valid
|
||||
func ValidateHost(mux map[string]*http.Handler) func(next http.Handler) http.Handler {
|
||||
func ValidateHost(mux map[string]http.Handler) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if _, ok := mux[req.Host]; !ok {
|
||||
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if _, ok := mux[r.Host]; !ok {
|
||||
httputil.ErrorResponse(w, r, "Unknown host to route", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(rw, req)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -145,25 +130,47 @@ func ValidateHost(mux map[string]*http.Handler) func(next http.Handler) http.Han
|
|||
// RequireHTTPS reroutes a HTTP request to HTTPS
|
||||
// todo(bdd) : this is unreliable unless behind another reverser proxy
|
||||
// todo(bdd) : header age seems extreme
|
||||
func RequireHTTPS(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("Strict-Transport-Security", "max-age=31536000")
|
||||
func RequireHTTPS(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Strict-Transport-Security", "max-age=31536000")
|
||||
// todo(bdd) : scheme and x-forwarded-proto cannot be trusted if not behind another load balancer
|
||||
if (req.URL.Scheme == "http" && req.Header.Get("X-Forwarded-Proto") == "http") || &req.TLS == nil {
|
||||
if (r.URL.Scheme == "http" && r.Header.Get("X-Forwarded-Proto") == "http") || &r.TLS == nil {
|
||||
dest := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: req.Host,
|
||||
Path: req.URL.Path,
|
||||
RawQuery: req.URL.RawQuery,
|
||||
Host: r.Host,
|
||||
Path: r.URL.Path,
|
||||
RawQuery: r.URL.RawQuery,
|
||||
}
|
||||
http.Redirect(rw, req, dest.String(), http.StatusMovedPermanently)
|
||||
http.Redirect(w, r, dest.String(), http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(rw, req)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
||||
// Healthcheck endpoint middleware useful to setting up a path like
|
||||
// `/ping` that load balancers or uptime testing external services
|
||||
// can make a request before hitting any routes. It's also convenient
|
||||
// to place this above ACL middlewares as well.
|
||||
func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler {
|
||||
f := func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "GET" && strings.EqualFold(r.URL.Path, endpoint) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(msg))
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
// ValidSignature checks to see if a signature is valid. Compares hmac of
|
||||
// redirect uri, timestamp, and secret and signature.
|
||||
func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
||||
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
|
||||
return false
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue