pomerium/internal/middleware/middleware.go
Bobby DeSimone 5edfa7b03f
telemetry: add tracing
- telemetry/trace: add traces throughout code
- telemetry/metrics: nest metrics and trace under telemetry
- telemetry/trace: add service name span to HTTPMetricsHandler.
- telemetry/metrics: removed chain dependency middleware_tests.
- telemetry/metrics: wrap and encapsulate variadic view registration.
- telemetry/trace: add jaeger support for tracing.
- cmd/pomerium: move `parseOptions` to internal/config.
- cmd/pomerium: offload server handling to httputil and sub pkgs.
- httputil: standardize creation/shutdown of http listeners.
- httputil: prefer curve X25519 to P256 when negotiating TLS.
- fileutil: use standardized Getw

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
2019-07-24 09:20:16 -07:00

211 lines
6.5 KiB
Go

package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import (
"crypto/hmac"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"golang.org/x/net/publicsuffix"
)
// SetHeaders returns middleware that copies every name/value pair from
// securityHeaders onto the response headers (replacing any value already set
// for that name) and then hands the request to next, carrying a context that
// includes the "middleware.SetHeaders" trace span.
func SetHeaders(securityHeaders map[string]string) func(next http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		handler := func(w http.ResponseWriter, r *http.Request) {
			tracedCtx, span := trace.StartSpan(r.Context(), "middleware.SetHeaders")
			defer span.End()
			for name, value := range securityHeaders {
				w.Header().Set(name, value)
			}
			next.ServeHTTP(w, r.WithContext(tracedCtx))
		}
		return http.HandlerFunc(handler)
	}
	return wrap
}
// ValidateClientSecret returns middleware that authenticates a request by
// comparing a caller-supplied secret with sharedSecret.
//
// The secret is taken from the "shared_secret" form value (populated by
// r.ParseForm) and, when that is empty, from the "X-Client-Secret" request
// header. A form that cannot be parsed is answered with 400 Bad Request; a
// secret that does not match is answered with 500 Internal Server Error.
// NOTE(review): 500 for an auth failure is unusual (401/403 would be more
// conventional); kept unchanged for compatibility with existing callers.
// Matching requests are passed to next with the traced context.
func ValidateClientSecret(sharedSecret string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateClientSecret")
			defer span.End()
			if err := r.ParseForm(); err != nil {
				httpErr := &httputil.Error{Message: err.Error(), Code: http.StatusBadRequest}
				httputil.ErrorResponse(w, r, httpErr)
				return
			}
			clientSecret := r.Form.Get("shared_secret")
			// Fall back to the request header when no form value was supplied.
			if clientSecret == "" {
				clientSecret = r.Header.Get("X-Client-Secret")
			}
			// Compare in constant time so response timing does not reveal how
			// much of the secret a caller has guessed correctly.
			if !hmac.Equal([]byte(clientSecret), []byte(sharedSecret)) {
				httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusInternalServerError})
				return
			}
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// ValidateRedirectURI returns middleware that only admits requests whose
// "redirect_uri" form value shares a registrable domain with rootDomain, as
// decided by SameDomain.
//
// A form that cannot be parsed, a redirect URI that cannot be parsed, and a
// redirect URI on a different domain are all answered with 400 Bad Request.
// Otherwise the request is forwarded to next with the traced context.
func ValidateRedirectURI(rootDomain *url.URL) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			tracedCtx, span := trace.StartSpan(r.Context(), "middleware.ValidateRedirectURI")
			defer span.End()

			// badRequest writes a 400 response carrying msg.
			badRequest := func(msg string) {
				httputil.ErrorResponse(w, r, &httputil.Error{Message: msg, Code: http.StatusBadRequest})
			}

			if formErr := r.ParseForm(); formErr != nil {
				badRequest(formErr.Error())
				return
			}
			target, urlErr := url.Parse(r.Form.Get("redirect_uri"))
			switch {
			case urlErr != nil:
				badRequest(urlErr.Error())
			case !SameDomain(target, rootDomain):
				badRequest("Invalid redirect parameter")
			default:
				next.ServeHTTP(w, r.WithContext(tracedCtx))
			}
		})
	}
}
// SameDomain reports whether u and j share the same registrable domain: the
// effective top-level domain plus one label (eTLD+1) according to the public
// suffix list.
//
// Host names are lowercased before the lookup. DNS names are case-insensitive,
// but url.Parse preserves the host's case and the suffix-list lookup matches
// labels exactly, so mixed-case hosts would otherwise get a wrong eTLD+1.
// It returns false if either URL is nil, or if publicsuffix cannot derive an
// eTLD+1 for either host.
func SameDomain(u, j *url.URL) bool {
	if u == nil || j == nil {
		return false
	}
	a, err := publicsuffix.EffectiveTLDPlusOne(strings.ToLower(u.Hostname()))
	if err != nil {
		return false
	}
	b, err := publicsuffix.EffectiveTLDPlusOne(strings.ToLower(j.Hostname()))
	if err != nil {
		return false
	}
	return a == b
}
// ValidateSignature returns middleware that only admits requests carrying a
// valid cross-service signature.
//
// The "redirect_uri", "sig" and "ts" form values are checked against
// sharedSecret with ValidSignature. A form that cannot be parsed is answered
// with 400 Bad Request, and a failed signature check with 401 Unauthorized.
// Otherwise the request is passed to next with the traced context.
func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			tracedCtx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature")
			defer span.End()
			if formErr := r.ParseForm(); formErr != nil {
				httputil.ErrorResponse(w, r, &httputil.Error{Message: formErr.Error(), Code: http.StatusBadRequest})
				return
			}
			form := r.Form
			signed := ValidSignature(form.Get("redirect_uri"), form.Get("sig"), form.Get("ts"), sharedSecret)
			if !signed {
				httputil.ErrorResponse(w, r, &httputil.Error{
					Message: "Cross service signature failed to validate",
					Code:    http.StatusUnauthorized,
				})
				return
			}
			next.ServeHTTP(w, r.WithContext(tracedCtx))
		})
	}
}
// ValidateHost returns middleware that asks validHost whether the request's
// Host is acceptable. Accepted requests go to next with the traced context;
// rejected requests are answered with 404 Not Found.
func ValidateHost(validHost func(host string) bool) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			tracedCtx, span := trace.StartSpan(r.Context(), "middleware.ValidateHost")
			defer span.End()
			if validHost(r.Host) {
				next.ServeHTTP(w, r.WithContext(tracedCtx))
				return
			}
			httputil.ErrorResponse(w, r, &httputil.Error{Code: http.StatusNotFound})
		})
	}
}
// Healthcheck returns middleware that answers health probes directly.
//
// A GET request whose URL path equals endpoint (compared case-insensitively)
// receives 200 OK with msg as a text/plain body, and next is not invoked.
// Every other request is passed to next with the traced context. This makes
// it easy to expose a path like "/ping" for load balancers or external uptime
// checks, and it can be placed ahead of ACL middleware so probes are not
// subject to it.
func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx, span := trace.StartSpan(r.Context(), "middleware.Healthcheck")
			defer span.End()
			if r.Method == http.MethodGet && strings.EqualFold(r.URL.Path, endpoint) {
				w.Header().Set("Content-Type", "text/plain")
				w.WriteHeader(http.StatusOK)
				// The status is already committed, so a failed body write (for
				// example, the client disconnected) cannot be reported to the
				// caller; discard the error explicitly.
				_, _ = w.Write([]byte(msg))
				return
			}
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// ValidSignature reports whether sigVal is a valid signature for redirectURI
// at the given timestamp under secret.
//
// Inputs:
//   - sigVal must be URL-safe base64 (base64.URLEncoding).
//   - timestamp must be a base-10 Unix time in seconds.
//   - redirectURI must parse with url.Parse.
//
// The signature is accepted only while the timestamp lies within five minutes
// of the current time, in either direction. The forward allowance tolerates
// clock skew between the signer and this service. Any empty argument,
// unparsable input, out-of-window timestamp, or signature mismatch yields
// false. The final comparison uses hmac.Equal (constant time).
func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool {
	if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
		return false
	}
	if _, err := url.Parse(redirectURI); err != nil {
		return false
	}
	requestSig, err := base64.URLEncoding.DecodeString(sigVal)
	if err != nil {
		return false
	}
	i, err := strconv.ParseInt(timestamp, 10, 64)
	if err != nil {
		return false
	}
	tm := time.Unix(i, 0)
	const ttl = 5 * time.Minute
	// Reject timestamps that are too old AND ones too far in the future.
	// Checking only the past lets a future-dated signature stay valid forever.
	// This is exploitable because redirectURLSignature concatenates the URI
	// and the decimal timestamp with no separator: trailing digits of a signed
	// URI can be shifted onto the front of the timestamp without changing the
	// signature, producing a far-future ts.
	// time.Since saturates instead of overflowing, so extreme values are safe.
	age := time.Since(tm)
	if age > ttl || age < -ttl {
		return false
	}
	localSig := redirectURLSignature(redirectURI, tm, secret)
	return hmac.Equal(requestSig, localSig)
}
// redirectURLSignature computes the signature checked by ValidSignature. It
// passes secret and a payload built from rawRedirect and timestamp's
// Unix-seconds value to cryptutil.Hash (presumably an HMAC — confirm in
// cryptutil).
//
// The payload uses fmt.Sprint, which inserts no space here because the first
// operand is a string. The two values are therefore simply concatenated, e.g.
// "https://a.example/cb" and 1564000000 become
// "https://a.example/cb1564000000".
// NOTE(review): this encoding is ambiguous when the URI ends in a digit.
// Changing it would also require changing whatever produces the signatures.
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
	payload := fmt.Sprint(rawRedirect, timestamp.Unix())
	return cryptutil.Hash(secret, []byte(payload))
}