pomerium/internal/middleware/middleware.go
Bobby DeSimone 00c29f4e77
authenticate: handle XHR redirect flow (#387)
- authenticate: add cors preflight check support for sign_in endpoint
- internal/httputil: indicate responses that originate from pomerium vs the app
- proxy: detect XHR requests and do not redirect on failure.
- authenticate: removed default session duration; should be maintained out of band with rpc.
2019-11-14 19:37:31 -08:00

133 lines
4.5 KiB
Go

package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import (
"encoding/base64"
"fmt"
"net/http"
"strings"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil"
)
// SetHeaders returns middleware that, for every request, copies each
// key/value pair from headers onto the response via Header().Set (replacing
// any value already present for that key) and then invokes next with a
// request whose context carries the "middleware.SetHeaders" trace span.
func SetHeaders(headers map[string]string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		applyHeaders := func(w http.ResponseWriter, r *http.Request) {
			spanCtx, span := trace.StartSpan(r.Context(), "middleware.SetHeaders")
			defer span.End()

			for name, value := range headers {
				w.Header().Set(name, value)
			}
			next.ServeHTTP(w, r.WithContext(spanCtx))
		}
		return http.HandlerFunc(applyHeaders)
	}
}
// ValidateSignature returns middleware that verifies each request was signed
// with the client secret sharedSecret. It does this by checking the request's
// redirect_uri / ts / sig values with ValidateRedirectURI. A request that
// fails verification receives an "invalid signature" 400 Bad Request error
// response and is not forwarded. A request that passes is handed to next with
// the "middleware.ValidateSignature" trace span in its context.
func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		verify := func(w http.ResponseWriter, r *http.Request) {
			spanCtx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature")
			defer span.End()

			if ok := ValidateRedirectURI(r, sharedSecret); !ok {
				httpErr := httputil.Error("invalid signature", http.StatusBadRequest, nil)
				httputil.ErrorResponse(w, r, httpErr)
				return
			}
			next.ServeHTTP(w, r.WithContext(spanCtx))
		}
		return http.HandlerFunc(verify)
	}
}
// ValidateRedirectURI reports whether the request carries a correctly signed
// redirect. It reads the `redirect_uri`, `sig` and `ts` values through
// r.FormValue, which may parse the request body as a form. It then delegates
// all checking, including the HMAC comparison against key, to ValidSignature.
func ValidateRedirectURI(r *http.Request, key string) bool {
	// Read in the same order as the values are passed below.
	redirectURI := r.FormValue("redirect_uri")
	signature := r.FormValue("sig")
	timestamp := r.FormValue("ts")
	return ValidSignature(redirectURI, signature, timestamp, key)
}
// Healthcheck returns middleware that answers requests for endpoint itself.
// The path is compared case-insensitively with strings.EqualFold, and such
// requests never reach next:
//   - GET receives 200 OK, Content-Type text/plain, and msg as the body.
//   - HEAD receives the same status and headers with no body.
//   - Any other method receives 405 Method Not Allowed.
//
// Every other path is passed to next. This is useful for exposing a probe
// path such as `/ping` to load balancers or external uptime checks. Placing
// it above ACL middlewares lets those probes through.
func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			spanCtx, span := trace.StartSpan(r.Context(), "middleware.Healthcheck")
			defer span.End()

			if !strings.EqualFold(r.URL.Path, endpoint) {
				next.ServeHTTP(w, r.WithContext(spanCtx))
				return
			}

			// Only the safe, read-only methods are meaningful for a probe.
			isGet := r.Method == http.MethodGet
			if !isGet && r.Method != http.MethodHead {
				http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
				return
			}

			w.Header().Set("Content-Type", "text/plain")
			w.WriteHeader(http.StatusOK)
			if isGet {
				// Best effort: if the probe disconnected there is nothing to recover.
				_, _ = w.Write([]byte(msg))
			}
		})
	}
}
// ValidSignature reports whether sigVal is a valid HMAC, keyed by secret,
// over redirectURI followed by timestamp. It returns false if any of the
// following hold:
//   - any argument is empty;
//   - redirectURI is rejected by urlutil.ParseAndValidateURL;
//   - sigVal does not decode with base64.URLEncoding;
//   - timestamp is rejected by cryptutil.ValidTimestamp.
func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool {
	for _, required := range []string{redirectURI, sigVal, timestamp, secret} {
		if required == "" {
			return false
		}
	}

	if _, err := urlutil.ParseAndValidateURL(redirectURI); err != nil {
		return false
	}

	mac, err := base64.URLEncoding.DecodeString(sigVal)
	if err != nil {
		return false
	}

	if cryptutil.ValidTimestamp(timestamp) != nil {
		return false
	}

	// fmt.Sprint puts no separator between two string operands, so this is
	// redirectURI immediately followed by timestamp.
	signedPayload := fmt.Sprint(redirectURI, timestamp)
	return cryptutil.CheckHMAC([]byte(signedPayload), mac, secret)
}
// StripCookie returns middleware that removes cookies from the request before
// it is forwarded to next (the downstream request). A cookie is removed when
// its name begins with cookieName. Matching is by prefix, so other cookies
// whose names merely start with cookieName are removed as well.
//
// The remaining cookies are re-serialized into a single Cookie header, with
// pairs separated by "; " as specified by RFC 6265. If no cookies remain, the
// Cookie header is deleted rather than forwarded empty.
//
// The incoming request's headers are modified in place.
func StripCookie(cookieName string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx, span := trace.StartSpan(r.Context(), "middleware.StripCookie")
			defer span.End()

			// r.Cookies() re-parses the header on every call; parse it once.
			cookies := r.Cookies()
			kept := make([]string, 0, len(cookies))
			for _, cookie := range cookies {
				if !strings.HasPrefix(cookie.Name, cookieName) {
					kept = append(kept, cookie.String())
				}
			}

			if len(kept) == 0 {
				// Don't forward an empty "Cookie:" header, including for
				// requests that never carried one.
				r.Header.Del("Cookie")
			} else {
				// RFC 6265 §4.2.1: cookie-pairs are separated by ";" SP.
				r.Header.Set("Cookie", strings.Join(kept, "; "))
			}
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// TimeoutHandlerFunc returns middleware that wraps next in
// http.TimeoutHandler. If next does not finish within timeout, the client
// receives 503 Service Unavailable with timeoutError as the body. Each
// request is also traced under the "middleware.TimeoutHandlerFunc" span. The
// span's context is what http.TimeoutHandler derives its deadline from.
func TimeoutHandlerFunc(timeout time.Duration, timeoutError string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		// http.TimeoutHandler returns an ordinary http.Handler that can serve
		// many requests. Build it once per wrapped handler instead of
		// allocating a new one for every request.
		withTimeout := http.TimeoutHandler(next, timeout, timeoutError)
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx, span := trace.StartSpan(r.Context(), "middleware.TimeoutHandlerFunc")
			defer span.End()
			withTimeout.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}