pomerium/internal/middleware/middleware.go
Caleb Doxsey 6d1d2bec54
crypto: use actual bytes of shared secret, not the base64 encoded representation (#2075)
* crypto: use actual bytes of shared secret, not the base64 encoded representation

* return errors

* return errors
2021-04-08 20:04:01 -06:00

107 lines
3.7 KiB
Go

package middleware
import (
"crypto/sha256"
"crypto/subtle"
"net/http"
"strings"
"time"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil"
)
// SetHeaders returns middleware that writes every key/value pair of headers
// onto the response before handing the request to the next handler.
//
// The work is wrapped in a "middleware.SetHeaders" trace span, and the span's
// context is attached to the request passed downstream.
func SetHeaders(headers map[string]string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		apply := func(w http.ResponseWriter, r *http.Request) {
			spanCtx, span := trace.StartSpan(r.Context(), "middleware.SetHeaders")
			defer span.End()

			// Set (not Add) is used, so an existing value for the same key is replaced.
			for name, value := range headers {
				w.Header().Set(name, value)
			}

			traced := r.WithContext(spanCtx)
			next.ServeHTTP(w, traced)
		}
		return http.HandlerFunc(apply)
	}
}
// ValidateSignature ensures the request is valid and has been signed with
// the correspdoning client secret key
func ValidateSignature(sharedKey []byte) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature")
defer span.End()
if err := ValidateRequestURL(r, sharedKey); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
next.ServeHTTP(w, r.WithContext(ctx))
return nil
})
}
}
// ValidateRequestURL checks that the absolute URL of r was signed with key.
//
// It builds a signed URL from key and the request's absolute URL and returns
// the result of validating it: nil on success, the validation error otherwise.
func ValidateRequestURL(r *http.Request, key []byte) error {
	absolute := urlutil.GetAbsoluteURL(r)
	signed := urlutil.NewSignedURL(key, absolute)
	return signed.Validate()
}
// StripCookie strips the cookie from the downstream request.
//
// Every cookie whose name starts with cookieName is dropped from the request's
// Cookie header before the request is handed to next. The remaining cookies
// are re-serialized with Cookie.String and joined into a single Cookie header.
// When no cookie remains, the Cookie header is removed rather than being
// forwarded with an empty value.
//
// The incoming request's header is modified in place.
func StripCookie(cookieName string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx, span := trace.StartSpan(r.Context(), "middleware.StripCookie")
			defer span.End()

			// r.Cookies parses the Cookie header on every call, so parse it once.
			cookies := r.Cookies()
			kept := make([]string, 0, len(cookies))
			for _, cookie := range cookies {
				if !strings.HasPrefix(cookie.Name, cookieName) {
					kept = append(kept, cookie.String())
				}
			}

			if len(kept) == 0 {
				// Nothing left to forward: drop the header instead of sending "Cookie: ".
				r.Header.Del("Cookie")
			} else {
				// RFC 6265 section 4.2.1 separates cookie pairs with "; ".
				r.Header.Set("Cookie", strings.Join(kept, "; "))
			}

			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// TimeoutHandlerFunc wraps http.TimeoutHandler.
//
// The returned middleware runs next through http.TimeoutHandler with the given
// timeout and timeoutError message, inside a "middleware.TimeoutHandlerFunc"
// trace span whose context is attached to the request.
func TimeoutHandlerFunc(timeout time.Duration, timeoutError string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		// next, timeout and timeoutError are fixed for this middleware instance,
		// so build the timeout handler once instead of on every request.
		limited := http.TimeoutHandler(next, timeout, timeoutError)
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx, span := trace.StartSpan(r.Context(), "middleware.TimeoutHandlerFunc")
			defer span.End()
			limited.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// RequireBasicAuth creates a new handler that requires basic auth from the client before
// calling the underlying handler.
//
// Requests that carry no basic auth credentials, or whose credentials do not
// match username and password, receive a 401 Unauthorized response with a
// `WWW-Authenticate: Basic realm="Restricted"` challenge and never reach next.
func RequireBasicAuth(username, password string) func(next http.Handler) http.Handler {
	// The expected credentials never change, so hash them once up front
	// instead of on every request.
	requiredUser := sha256.Sum256([]byte(username))
	requiredPass := sha256.Sum256([]byte(password))

	// unauthorized writes the challenge response shared by both rejection paths.
	unauthorized := func(w http.ResponseWriter) {
		w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
		http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			u, p, ok := r.BasicAuth()
			if !ok {
				unauthorized(w)
				return
			}

			// Hashing makes both sides the same length; ConstantTimeCompare
			// returns immediately when the lengths differ.
			givenUser := sha256.Sum256([]byte(u))
			givenPass := sha256.Sum256([]byte(p))

			// Run both comparisons unconditionally so a wrong username and a
			// wrong password take the same code path (no short-circuit).
			userOK := subtle.ConstantTimeCompare(givenUser[:], requiredUser[:])
			passOK := subtle.ConstantTimeCompare(givenPass[:], requiredPass[:])
			if userOK&passOK != 1 {
				unauthorized(w)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}