pomerium/internal/middleware/middleware.go
Bobby DeSimone 90ab756de1
Added gif to the readme.
Simplified, and de-duplicated many of the configuration settings.
Removed configuration settings that could be deduced from other settings.
Added some basic documentation.
Removed the (duplicate?) user email domain validation check in proxy.
Removed the ClientID middleware check.
Added a shared key option to be used as a PSK instead of using the IDPs ClientID and ClientSecret.
Removed the CookieSecure setting as we only support secure.
Added a letsencrypt script to generate a wildcard certificate.
Removed the argument in proxy's constructor that allowed arbitrary fucntions to be passed in as validators.
Updated proxy's authenticator client to match the server implementation of just using a PSK.
Moved debug-mode logging into the log package.
Removed unused approval prompt setting.
Fixed a bug where identity provider urls were hardcoded.
Removed a bunch of unit tests. There have been so many changes many of these tests don't make sense and will need to be re-thought.
2019-01-04 18:25:03 -08:00

185 lines
5.6 KiB
Go

package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
)
// SetHeaders ensures that every response includes some basic security headers
func SetHeaders(h http.Handler, securityHeaders map[string]string) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
for key, val := range securityHeaders {
rw.Header().Set(key, val)
}
h.ServeHTTP(rw, req)
})
}
// WithMethods writes an error response if the method of the request is not included.
func WithMethods(f http.HandlerFunc, methods ...string) http.HandlerFunc {
methodMap := make(map[string]struct{}, len(methods))
for _, m := range methods {
methodMap[m] = struct{}{}
}
return func(rw http.ResponseWriter, req *http.Request) {
if _, ok := methodMap[req.Method]; !ok {
httputil.ErrorResponse(rw, req, fmt.Sprintf("method %s not allowed", req.Method), http.StatusMethodNotAllowed)
return
}
f(rw, req)
}
}
// ValidateClientSecret checks the request header for the client secret and returns
// an error if it does not match the proxy client secret
func ValidateClientSecret(f http.HandlerFunc, sharedSecret string) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
return
}
clientSecret := req.Form.Get("shared_secret")
// check the request header for the client secret
if clientSecret == "" {
clientSecret = req.Header.Get("X-Client-Secret")
}
if clientSecret != sharedSecret {
log.Error().
Str("clientSecret", clientSecret).
Str("sharedSecret", sharedSecret).
Msg("middleware.ValidateClientSecret")
httputil.ErrorResponse(rw, req, "Invalid client secret", http.StatusUnauthorized)
return
}
f(rw, req)
}
}
// ValidateRedirectURI checks the redirect uri in the query parameters and ensures that
// the url's domain is one in the list of proxy root domains.
func ValidateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
return
}
redirectURI := req.Form.Get("redirect_uri")
if !validRedirectURI(redirectURI, proxyRootDomains) {
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
return
}
f(rw, req)
}
}
func validRedirectURI(uri string, rootDomains []string) bool {
redirectURL, err := url.Parse(uri)
if uri == "" || err != nil || redirectURL.Host == "" {
return false
}
for _, domain := range rootDomains {
if strings.HasSuffix(redirectURL.Hostname(), domain) {
return true
}
}
return false
}
// ValidateSignature ensures the request is valid and has been signed with
// the correspdoning client secret key
func ValidateSignature(f http.HandlerFunc, sharedSecret string) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
return
}
redirectURI := req.Form.Get("redirect_uri")
sigVal := req.Form.Get("sig")
timestamp := req.Form.Get("ts")
if !validSignature(redirectURI, sigVal, timestamp, sharedSecret) {
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
return
}
f(rw, req)
}
}
// ValidateHost ensures that each request's host is valid
func ValidateHost(h http.Handler, mux map[string]*http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if _, ok := mux[req.Host]; !ok {
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
return
}
h.ServeHTTP(rw, req)
})
}
// RequireHTTPS reroutes a HTTP request to HTTPS
// todo(bdd) : this is unreliable unless behind another reverser proxy
// todo(bdd) : header age seems extreme
func RequireHTTPS(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Strict-Transport-Security", "max-age=31536000")
// todo(bdd) : scheme and x-forwarded-proto cannot be trusted if not behind another load balancer
if (req.URL.Scheme == "http" && req.Header.Get("X-Forwarded-Proto") == "http") || &req.TLS == nil {
dest := &url.URL{
Scheme: "https",
Host: req.Host,
Path: req.URL.Path,
RawQuery: req.URL.RawQuery,
}
http.Redirect(rw, req, dest.String(), http.StatusMovedPermanently)
return
}
h.ServeHTTP(rw, req)
})
}
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
return false
}
_, err := url.Parse(redirectURI)
if err != nil {
return false
}
requestSig, err := base64.URLEncoding.DecodeString(sigVal)
if err != nil {
return false
}
i, err := strconv.ParseInt(timestamp, 10, 64)
if err != nil {
return false
}
tm := time.Unix(i, 0)
ttl := 5 * time.Minute
if time.Since(tm) > ttl {
return false
}
localSig := redirectURLSignature(redirectURI, tm, secret)
return hmac.Equal(requestSig, localSig)
}
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
h := hmac.New(sha256.New, []byte(secret))
h.Write([]byte(rawRedirect))
h.Write([]byte(fmt.Sprint(timestamp.Unix())))
return h.Sum(nil)
}