mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-29 08:57:18 +02:00
authenticate: refactor middleware, logging, and tests (#30)
- Abstract remaining middleware from authenticate into internal. - Use middleware chaining in authenticate. - Standardize naming of Request and ResponseWriter to match std lib. - Add healthcheck / ping as a middleware. - Internalized wraped_writer package adapted from goji/middleware. - Fixed indirection issue with reverse proxy map.
This commit is contained in:
parent
b9c298d278
commit
7e1d1a7896
21 changed files with 768 additions and 397 deletions
|
@ -71,11 +71,10 @@ func OptionsFromEnvConfig() (*Options, error) {
|
|||
return o, nil
|
||||
}
|
||||
|
||||
// Validate checks to see if configuration values are valid for authentication service.
|
||||
// The checks do not modify the internal state of the Option structure. Function returns
|
||||
// Validate checks to see if configuration values are valid for the authentication service.
|
||||
// The checks do not modify the internal state of the Option structure. Returns
|
||||
// on first error found.
|
||||
func (o *Options) Validate() error {
|
||||
|
||||
if o.RedirectURL == nil {
|
||||
return errors.New("missing setting: identity provider redirect url")
|
||||
}
|
||||
|
@ -105,11 +104,11 @@ func (o *Options) Validate() error {
|
|||
if len(decodedCookieSecret) != 32 {
|
||||
return fmt.Errorf("cookie secret expects 32 bytes but got %d", len(decodedCookieSecret))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Authenticate stores all the information associated with proxying the request.
|
||||
// Authenticate is service for validating user authentication for proxied-requests
|
||||
// against third-party identity provider (IdP) services.
|
||||
type Authenticate struct {
|
||||
RedirectURL *url.URL
|
||||
|
||||
|
@ -133,7 +132,7 @@ type Authenticate struct {
|
|||
provider providers.Provider
|
||||
}
|
||||
|
||||
// New creates a Authenticate struct and applies the optional functions slice to the struct.
|
||||
// New validates and creates a new authentication service from a configuration options.
|
||||
func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate, error) {
|
||||
if opts == nil {
|
||||
return nil, errors.New("options cannot be nil")
|
||||
|
@ -179,13 +178,13 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate
|
|||
cipher: cipher,
|
||||
skipProviderButton: opts.SkipProviderButton,
|
||||
}
|
||||
// p.ServeMux = p.Handler()
|
||||
|
||||
p.provider, err = newProvider(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// apply the option functions
|
||||
// validation via dependency injected function
|
||||
for _, optFunc := range optionFuncs {
|
||||
err := optFunc(p)
|
||||
if err != nil {
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
m "github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
@ -28,45 +28,58 @@ var securityHeaders = map[string]string{
|
|||
|
||||
// Handler returns the Http.Handlers for authentication, callback, and refresh
|
||||
func (p *Authenticate) Handler() http.Handler {
|
||||
// set up our standard middlewares
|
||||
stdMiddleware := middleware.NewChain()
|
||||
stdMiddleware = stdMiddleware.Append(middleware.NewHandler(log.Logger))
|
||||
stdMiddleware = stdMiddleware.Append(middleware.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
||||
// executed after handler route handler
|
||||
middleware.FromRequest(r).Info().
|
||||
Str("method", r.Method).
|
||||
Str("url", r.URL.String()).
|
||||
Int("status", status).
|
||||
Int("size", size).
|
||||
Dur("duration", duration).
|
||||
Msg("request")
|
||||
}))
|
||||
stdMiddleware = stdMiddleware.Append(middleware.SetHeaders(securityHeaders))
|
||||
stdMiddleware = stdMiddleware.Append(middleware.ForwardedAddrHandler("fwd_ip"))
|
||||
stdMiddleware = stdMiddleware.Append(middleware.RemoteAddrHandler("ip"))
|
||||
stdMiddleware = stdMiddleware.Append(middleware.UserAgentHandler("user_agent"))
|
||||
stdMiddleware = stdMiddleware.Append(middleware.RefererHandler("referer"))
|
||||
stdMiddleware = stdMiddleware.Append(middleware.RequestIDHandler("req_id", "Request-Id"))
|
||||
stdMiddleware = stdMiddleware.Append(middleware.Healthcheck("/ping", version.UserAgent()))
|
||||
validateSignatureMiddleware := stdMiddleware.Append(
|
||||
middleware.ValidateSignature(p.SharedKey),
|
||||
middleware.ValidateRedirectURI(p.ProxyRootDomains))
|
||||
|
||||
validateClientSecret := stdMiddleware.Append(middleware.ValidateClientSecret(p.SharedKey))
|
||||
|
||||
mux := http.NewServeMux()
|
||||
// we setup global endpoints that should respond to any hostname
|
||||
mux.HandleFunc("/ping", m.WithMethods(p.PingPage, "GET"))
|
||||
|
||||
mux.Handle("/ping", stdMiddleware.ThenFunc(p.PingPage))
|
||||
serviceMux := http.NewServeMux()
|
||||
// standard rest and healthcheck endpoints
|
||||
serviceMux.HandleFunc("/ping", m.WithMethods(p.PingPage, "GET"))
|
||||
serviceMux.HandleFunc("/robots.txt", m.WithMethods(p.RobotsTxt, "GET"))
|
||||
serviceMux.Handle("/ping", stdMiddleware.ThenFunc(p.PingPage))
|
||||
serviceMux.Handle("/robots.txt", stdMiddleware.ThenFunc(p.RobotsTxt))
|
||||
// Identity Provider (IdP) endpoints and callbacks
|
||||
serviceMux.HandleFunc("/start", m.WithMethods(p.OAuthStart, "GET"))
|
||||
serviceMux.HandleFunc("/oauth2/callback", m.WithMethods(p.OAuthCallback, "GET"))
|
||||
serviceMux.Handle("/start", stdMiddleware.ThenFunc(p.OAuthStart))
|
||||
serviceMux.Handle("/oauth2/callback", stdMiddleware.ThenFunc(p.OAuthCallback))
|
||||
// authenticator-server endpoints, todo(bdd): make gRPC
|
||||
serviceMux.HandleFunc("/sign_in", m.WithMethods(p.validateSignature(p.SignIn), "GET"))
|
||||
serviceMux.HandleFunc("/sign_out", m.WithMethods(p.validateSignature(p.SignOut), "GET", "POST"))
|
||||
serviceMux.HandleFunc("/profile", m.WithMethods(p.validateExisting(p.GetProfile), "GET"))
|
||||
serviceMux.HandleFunc("/validate", m.WithMethods(p.validateExisting(p.ValidateToken), "GET"))
|
||||
serviceMux.HandleFunc("/redeem", m.WithMethods(p.validateExisting(p.Redeem), "POST"))
|
||||
serviceMux.HandleFunc("/refresh", m.WithMethods(p.validateExisting(p.Refresh), "POST"))
|
||||
serviceMux.Handle("/sign_in", validateSignatureMiddleware.ThenFunc(p.SignIn))
|
||||
serviceMux.Handle("/sign_out", validateSignatureMiddleware.ThenFunc(p.SignOut)) // "GET", "POST"
|
||||
serviceMux.Handle("/profile", validateClientSecret.ThenFunc(p.GetProfile)) // GET
|
||||
serviceMux.Handle("/validate", validateClientSecret.ThenFunc(p.ValidateToken)) // GET
|
||||
serviceMux.Handle("/redeem", validateClientSecret.ThenFunc(p.Redeem)) // POST
|
||||
serviceMux.Handle("/refresh", validateClientSecret.ThenFunc(p.Refresh)) //POST
|
||||
|
||||
// NOTE: we have to include trailing slash for the router to match the host header
|
||||
host := p.RedirectURL.Host
|
||||
if !strings.HasSuffix(host, "/") {
|
||||
host = fmt.Sprintf("%s/", host)
|
||||
}
|
||||
mux.Handle(host, serviceMux) // setup our service mux to only handle our required host header
|
||||
mux.Handle(host, serviceMux)
|
||||
|
||||
return m.SetHeadersOld(mux, securityHeaders)
|
||||
}
|
||||
|
||||
// validateSignature wraps a common collection of middlewares to validate signatures
|
||||
func (p *Authenticate) validateSignature(f http.HandlerFunc) http.HandlerFunc {
|
||||
return validateRedirectURI(validateSignature(f, p.SharedKey), p.ProxyRootDomains)
|
||||
|
||||
}
|
||||
|
||||
// validateSignature wraps a common collection of middlewares to validate
|
||||
// a (presumably) existing user session
|
||||
func (p *Authenticate) validateExisting(f http.HandlerFunc) http.HandlerFunc {
|
||||
return m.ValidateClientSecret(f, p.SharedKey)
|
||||
return mux
|
||||
}
|
||||
|
||||
// RobotsTxt handles the /robots.txt route.
|
||||
|
@ -83,10 +96,8 @@ func (p *Authenticate) PingPage(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// SignInPage directs the user to the sign in page. Takes a `redirect_uri` param.
|
||||
func (p *Authenticate) SignInPage(w http.ResponseWriter, r *http.Request) {
|
||||
// requestLog := log.WithRequest(req, "authenticate.SignInPage")
|
||||
redirectURL := p.RedirectURL.ResolveReference(r.URL)
|
||||
// validateRedirectURI middleware already ensures that this is a valid URL
|
||||
destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri"))
|
||||
destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri")) // checked by middleware
|
||||
t := struct {
|
||||
ProviderName string
|
||||
AllowedDomains []string
|
||||
|
@ -100,28 +111,27 @@ func (p *Authenticate) SignInPage(w http.ResponseWriter, r *http.Request) {
|
|||
Destination: destinationURL.Host,
|
||||
Version: version.FullVersion(),
|
||||
}
|
||||
log.Ctx(r.Context()).Info().
|
||||
log.FromRequest(r).Debug().
|
||||
Str("ProviderName", p.provider.Data().ProviderName).
|
||||
Str("Redirect", redirectURL.String()).
|
||||
Str("Destination", destinationURL.Host).
|
||||
Str("AllowedDomains", strings.Join(p.AllowedDomains, ", ")).
|
||||
Msg("authenticate.SignInPage")
|
||||
Msg("authenticate: SignInPage")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
p.templates.ExecuteTemplate(w, "sign_in.html", t)
|
||||
}
|
||||
|
||||
func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*sessions.SessionState, error) {
|
||||
// requestLog := log.WithRequest(req, "authenticate.authenticate")
|
||||
session, err := p.sessionStore.LoadSession(r)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate.authenticate")
|
||||
log.Error().Err(err).Msg("authenticate: failed to load session")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// ensure sessions lifetime has not expired
|
||||
if session.LifetimePeriodExpired() {
|
||||
log.Ctx(r.Context()).Warn().Msg("lifetime expired")
|
||||
log.FromRequest(r).Warn().Msg("authenticate: lifetime expired")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
return nil, sessions.ErrLifetimeExpired
|
||||
}
|
||||
|
@ -129,12 +139,12 @@ func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
|
|||
if session.RefreshPeriodExpired() {
|
||||
ok, err := p.provider.RefreshSessionIfNeeded(session)
|
||||
if err != nil {
|
||||
log.Ctx(r.Context()).Error().Err(err).Msg("failed to refresh session")
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to refresh session")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
log.Ctx(r.Context()).Error().Msg("user unauthorized after refresh")
|
||||
log.FromRequest(r).Error().Msg("user unauthorized after refresh")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
return nil, httputil.ErrUserNotAuthorized
|
||||
}
|
||||
|
@ -144,7 +154,7 @@ func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
|
|||
// We refreshed the session successfully, but failed to save it.
|
||||
// This could be from failing to encode the session properly.
|
||||
// But, we clear the session cookie and reject the request
|
||||
log.Ctx(r.Context()).Error().Err(err).Msg("could not save refreshed session")
|
||||
log.FromRequest(r).Error().Err(err).Msg("could not save refreshed session")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
return nil, err
|
||||
}
|
||||
|
@ -152,20 +162,20 @@ func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
|
|||
// The session has not exceeded it's lifetime or requires refresh
|
||||
ok := p.provider.ValidateSessionState(session)
|
||||
if !ok {
|
||||
log.Ctx(r.Context()).Error().Msg("invalid session state")
|
||||
log.FromRequest(r).Error().Msg("invalid session state")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
return nil, httputil.ErrUserNotAuthorized
|
||||
}
|
||||
err = p.sessionStore.SaveSession(w, r, session)
|
||||
if err != nil {
|
||||
log.Ctx(r.Context()).Error().Err(err).Msg("failed to save valid session")
|
||||
log.FromRequest(r).Error().Err(err).Msg("failed to save valid session")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if !p.Validator(session.Email) {
|
||||
log.Ctx(r.Context()).Error().Msg("invalid email user")
|
||||
log.FromRequest(r).Error().Msg("invalid email user")
|
||||
return nil, httputil.ErrUserNotAuthorized
|
||||
}
|
||||
return session, nil
|
||||
|
@ -316,7 +326,7 @@ func (p *Authenticate) SignOutPage(w http.ResponseWriter, r *http.Request, messa
|
|||
|
||||
signature := r.Form.Get("sig")
|
||||
timestamp := r.Form.Get("ts")
|
||||
destinationURL, _ := url.Parse(redirectURI)
|
||||
destinationURL, _ := url.Parse(redirectURI) //checked by middleware
|
||||
|
||||
// An error message indicates that an internal server error occurred
|
||||
if message != "" {
|
||||
|
@ -341,7 +351,6 @@ func (p *Authenticate) SignOutPage(w http.ResponseWriter, r *http.Request, messa
|
|||
Version: version.FullVersion(),
|
||||
}
|
||||
p.templates.ExecuteTemplate(w, "sign_out.html", t)
|
||||
return
|
||||
}
|
||||
|
||||
// OAuthStart starts the authentication process by redirecting to the provider. It provides a
|
||||
|
@ -364,20 +373,20 @@ func (p *Authenticate) helperOAuthStart(w http.ResponseWriter, r *http.Request,
|
|||
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
|
||||
p.csrfStore.SetCSRF(w, r, nonce)
|
||||
|
||||
if !validRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) {
|
||||
if !middleware.ValidRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) {
|
||||
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri"))
|
||||
if err != nil || !validRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) {
|
||||
if err != nil || !middleware.ValidRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) {
|
||||
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
proxyRedirectSig := authRedirectURL.Query().Get("sig")
|
||||
ts := authRedirectURL.Query().Get("ts")
|
||||
if !validSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, p.SharedKey) {
|
||||
if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, p.SharedKey) {
|
||||
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
@ -404,7 +413,6 @@ func (p *Authenticate) redeemCode(host, code string) (*sessions.SessionState, er
|
|||
|
||||
// getOAuthCallback completes the oauth cycle from an identity provider's callback
|
||||
func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
// requestLog := log.WithRequest(req, "authenticate.getOAuthCallback")
|
||||
// finish the oauth cycle
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
|
@ -421,17 +429,18 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
|
||||
session, err := p.redeemCode(r.Host, code)
|
||||
if err != nil {
|
||||
log.Ctx(r.Context()).Error().Err(err).Msg("error redeeming authentication code")
|
||||
return "", err
|
||||
log.FromRequest(r).Error().Err(err).Msg("error redeeming authentication code")
|
||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
|
||||
}
|
||||
|
||||
bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state"))
|
||||
if err != nil {
|
||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Invalid State"}
|
||||
log.FromRequest(r).Error().Err(err).Msg("failed decoding state")
|
||||
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Couldn't decode state"}
|
||||
}
|
||||
s := strings.SplitN(string(bytes), ":", 2)
|
||||
if len(s) != 2 {
|
||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Invalid State"}
|
||||
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Invalid State"}
|
||||
}
|
||||
nonce := s[0]
|
||||
redirect := s[1]
|
||||
|
@ -441,11 +450,11 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
p.csrfStore.ClearCSRF(w, r)
|
||||
if c.Value != nonce {
|
||||
log.Ctx(r.Context()).Error().Err(err).Msg("csrf token mismatch")
|
||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "csrf failed"}
|
||||
log.FromRequest(r).Error().Err(err).Msg("CSRF token mismatch")
|
||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "CSRF failed"}
|
||||
}
|
||||
|
||||
if !validRedirectURI(redirect, p.ProxyRootDomains) {
|
||||
if !middleware.ValidRedirectURI(redirect, p.ProxyRootDomains) {
|
||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Redirect URI"}
|
||||
}
|
||||
|
||||
|
@ -453,10 +462,10 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
// - for p.Validator see validator.go#newValidatorImpl for more info
|
||||
// - for p.provider.ValidateGroup see providers/google.go#ValidateGroup for more info
|
||||
if !p.Validator(session.Email) {
|
||||
log.Ctx(r.Context()).Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied")
|
||||
log.FromRequest(r).Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied")
|
||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Account"}
|
||||
}
|
||||
log.Ctx(r.Context()).Info().Str("email", session.Email).Msg("authentication complete")
|
||||
log.FromRequest(r).Info().Str("email", session.Email).Msg("authentication complete")
|
||||
err = p.sessionStore.SaveSession(w, r, session)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("internal error")
|
||||
|
@ -487,7 +496,6 @@ func (p *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
|||
func (p *Authenticate) Redeem(w http.ResponseWriter, r *http.Request) {
|
||||
// The auth code is redeemed by the sso proxy for an access token, refresh token,
|
||||
// expiration, and email.
|
||||
// requestLog := log.WithRequest(req, "authenticate.Redeem")
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
|
||||
|
@ -496,19 +504,19 @@ func (p *Authenticate) Redeem(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
session, err := sessions.UnmarshalSession(r.Form.Get("code"), p.cipher)
|
||||
if err != nil {
|
||||
log.Ctx(r.Context()).Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid auth code")
|
||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to unmarshal session")
|
||||
http.Error(w, fmt.Sprintf("invalid auth code: %s", err.Error()), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if session == nil {
|
||||
log.Ctx(r.Context()).Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid session")
|
||||
http.Error(w, fmt.Sprintf("invalid session: %s", err.Error()), http.StatusUnauthorized)
|
||||
log.FromRequest(r).Error().Err(err).Msg("empty session")
|
||||
http.Error(w, fmt.Sprintf("empty session: %s", err.Error()), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if session != nil && (session.RefreshPeriodExpired() || session.LifetimePeriodExpired()) {
|
||||
log.Ctx(r.Context()).Error().Msg("expired session")
|
||||
log.FromRequest(r).Error().Msg("expired session")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
http.Error(w, fmt.Sprintf("expired session"), http.StatusUnauthorized)
|
||||
return
|
||||
|
@ -524,7 +532,7 @@ func (p *Authenticate) Redeem(w http.ResponseWriter, r *http.Request) {
|
|||
AccessToken: session.AccessToken,
|
||||
RefreshToken: session.RefreshToken,
|
||||
IDToken: session.IDToken,
|
||||
ExpiresIn: int64(session.RefreshDeadline.Sub(time.Now()).Seconds()),
|
||||
ExpiresIn: int64(time.Until(session.RefreshDeadline).Seconds()),
|
||||
Email: session.Email,
|
||||
}
|
||||
|
||||
|
@ -643,7 +651,5 @@ func (p *Authenticate) ValidateToken(w http.ResponseWriter, r *http.Request) {
|
|||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -77,7 +77,7 @@ func TestAuthenticate_SignInPage(t *testing.T) {
|
|||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
|
||||
}
|
||||
body := []byte(rr.Body.String())
|
||||
body := rr.Body.Bytes()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
|
@ -1,109 +0,0 @@
|
|||
package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
)
|
||||
|
||||
var defaultSignatureValidityDuration = 5 * time.Minute
|
||||
|
||||
// validateRedirectURI checks the redirect uri in the query parameters and ensures that
|
||||
// the url's domain is one in the list of proxy root domains.
|
||||
func validateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
redirectURI := req.Form.Get("redirect_uri")
|
||||
if !validRedirectURI(redirectURI, proxyRootDomains) {
|
||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
f(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func validRedirectURI(uri string, rootDomains []string) bool {
|
||||
if uri == "" || len(rootDomains) == 0 {
|
||||
return false
|
||||
}
|
||||
redirectURL, err := url.Parse(uri)
|
||||
if err != nil || redirectURL.Host == "" {
|
||||
return false
|
||||
}
|
||||
for _, domain := range rootDomains {
|
||||
if domain == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasSuffix(redirectURL.Hostname(), domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func validateSignature(f http.HandlerFunc, sharedKey string) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
redirectURI := req.Form.Get("redirect_uri")
|
||||
sigVal := req.Form.Get("sig")
|
||||
timestamp := req.Form.Get("ts")
|
||||
if !validSignature(redirectURI, sigVal, timestamp, sharedKey) {
|
||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
f(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
// validateSignature ensures the validity of the redirect url by comparing the hmac
|
||||
// digest, and ensuring that the included timestamp is fresh
|
||||
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
||||
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
|
||||
return false
|
||||
}
|
||||
_, err := url.Parse(redirectURI)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
requestSig, err := base64.URLEncoding.DecodeString(sigVal)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
i, err := strconv.ParseInt(timestamp, 10, 64)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
tm := time.Unix(i, 0)
|
||||
if time.Now().Sub(tm) > defaultSignatureValidityDuration {
|
||||
return false
|
||||
}
|
||||
localSig := redirectURLSignature(redirectURI, tm, secret)
|
||||
return hmac.Equal(requestSig, localSig)
|
||||
}
|
||||
|
||||
// redirectURLSignature generates a hmac digest from a
|
||||
// redirect url, a timestamp, and a secret.
|
||||
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write([]byte(rawRedirect))
|
||||
h.Write([]byte(fmt.Sprint(timestamp.Unix())))
|
||||
return h.Sum(nil)
|
||||
}
|
|
@ -1,85 +0,0 @@
|
|||
package authenticate
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_validRedirectURI(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uri string
|
||||
rootDomains []string
|
||||
want bool
|
||||
}{
|
||||
{"good url redirect", "https://example.com/redirect", []string{"example.com"}, true},
|
||||
{"bad domain", "https://example.com/redirect", []string{"notexample.com"}, false},
|
||||
{"malformed url", "^example.com/redirect", []string{"notexample.com"}, false},
|
||||
{"empty domain list", "https://example.com/redirect", []string{}, false},
|
||||
{"empty domain", "https://example.com/redirect", []string{""}, false},
|
||||
{"empty url", "", []string{"example.com"}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := validRedirectURI(tt.uri, tt.rootDomains); got != tt.want {
|
||||
t.Errorf("validRedirectURI() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_validSignature(t *testing.T) {
|
||||
goodURL := "https://example.com/redirect"
|
||||
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
|
||||
now := fmt.Sprint(time.Now().Unix())
|
||||
rawSig := redirectURLSignature(goodURL, time.Now(), secretA)
|
||||
sig := base64.URLEncoding.EncodeToString(rawSig)
|
||||
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
redirectURI string
|
||||
sigVal string
|
||||
timestamp string
|
||||
secret string
|
||||
want bool
|
||||
}{
|
||||
{"good signature", goodURL, string(sig), now, secretA, true},
|
||||
{"empty redirect url", "", string(sig), now, secretA, false},
|
||||
{"bad redirect url", "https://google.com^", string(sig), now, secretA, false},
|
||||
{"malformed signature", goodURL, string(sig + "^"), now, "&*&@**($&#(", false},
|
||||
{"malformed timestamp", goodURL, string(sig), now + "^", secretA, false},
|
||||
{"stale timestamp", goodURL, string(sig), staleTime, secretA, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := validSignature(tt.redirectURI, tt.sigVal, tt.timestamp, tt.secret); got != tt.want {
|
||||
t.Errorf("validSignature() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_redirectURLSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawRedirect string
|
||||
timestamp time.Time
|
||||
secret string
|
||||
want string
|
||||
}{
|
||||
{"good signature", "https://example.com/redirect", time.Unix(1546797901, 0), "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A=", "GIDyWKjrG_7MwXwIq1o51f2pDT_rH9aLHdsHxSBEwy8="},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := redirectURLSignature(tt.rawRedirect, tt.timestamp, tt.secret)
|
||||
out := base64.URLEncoding.EncodeToString(got)
|
||||
if out != tt.want {
|
||||
t.Errorf("redirectURLSignature() = %v, want %v", tt.want, out)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -258,7 +258,7 @@ func (p *ProviderData) RefreshAccessToken(refreshToken string) (string, time.Dur
|
|||
log.Error().Err(err).Msg("authenticate/providers.RefreshAccessToken")
|
||||
return "", 0, err
|
||||
}
|
||||
return newToken.AccessToken, newToken.Expiry.Sub(time.Now()), nil
|
||||
return newToken.AccessToken, time.Until(newToken.Expiry), nil
|
||||
}
|
||||
|
||||
// Revoke enables a user to revoke her token. If the identity provider supports revocation
|
||||
|
|
|
@ -75,8 +75,3 @@ func (tp *TestProvider) Redeem(code string) (*sessions.SessionState, error) {
|
|||
return tp.Session, tp.RedeemError
|
||||
|
||||
}
|
||||
|
||||
// Stop fulfills the Provider interface
|
||||
func (tp *TestProvider) Stop() {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -7,10 +7,10 @@
|
|||
|
||||
# Certificates can be loaded as files or base64 encoded bytes. If neither is set, a
|
||||
# pomerium will attempt to locate a pair in the root directory
|
||||
export CERTIFICATE="xxxxxx" # base64 encoded cert, eg. `base64 -i cert.pem`
|
||||
export CERTIFICATE_KEY="xxxx" # base64 encoded key, eg. `base64 -i privkey.pem`
|
||||
export CERTIFICATE_FILE="./cert.pem" # optional, defaults to `./cert.pem`
|
||||
export CERTIFICATE_KEY_FILE="./privkey.pem" # optional, defaults to `./certprivkey.pem`
|
||||
# export CERTIFICATE="xxxxxx" # base64 encoded cert, eg. `base64 -i cert.pem`
|
||||
# export CERTIFICATE_KEY="xxxx" # base64 encoded key, eg. `base64 -i privkey.pem`
|
||||
|
||||
# The URL that the identity provider will call back after authenticating the user
|
||||
export REDIRECT_URL="https://sso-auth.corp.example.com/oauth2/callback"
|
||||
|
|
1
go.mod
1
go.mod
|
@ -9,7 +9,6 @@ require (
|
|||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||
github.com/rs/zerolog v1.11.0
|
||||
github.com/stretchr/testify v1.2.2 // indirect
|
||||
github.com/zenazn/goji v0.9.0
|
||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9
|
||||
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 // indirect
|
||||
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890
|
||||
|
|
2
go.sum
2
go.sum
|
@ -16,8 +16,6 @@ github.com/rs/zerolog v1.11.0 h1:DRuq/S+4k52uJzBQciUcofXx45GrMC6yrEbb/CoK6+M=
|
|||
github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
||||
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/zenazn/goji v0.9.0 h1:RSQQAbXGArQ0dIDEq+PI6WqN6if+5KHu6x2Cx/GXLTQ=
|
||||
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
|
||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0=
|
||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
)
|
||||
|
||||
// ErrTokenRevoked signifies a token revokation or expiration error
|
||||
var ErrTokenRevoked = errors.New("Token expired or revoked")
|
||||
var ErrTokenRevoked = errors.New("token expired or revoked")
|
||||
|
||||
var httpClient = &http.Client{
|
||||
Timeout: time.Second * 5,
|
||||
|
|
|
@ -3,6 +3,7 @@ package log // import "github.com/pomerium/pomerium/internal/log"
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
@ -102,3 +103,9 @@ func Printf(format string, v ...interface{}) {
|
|||
func Ctx(ctx context.Context) *zerolog.Logger {
|
||||
return zerolog.Ctx(ctx)
|
||||
}
|
||||
|
||||
// FromRequest gets the logger in the request's context.
|
||||
// This is a shortcut for log.Ctx(r.Context())
|
||||
func FromRequest(r *http.Request) *zerolog.Logger {
|
||||
return Ctx(r.Context())
|
||||
}
|
||||
|
|
2
internal/middleware/doc.go
Normal file
2
internal/middleware/doc.go
Normal file
|
@ -0,0 +1,2 @@
|
|||
// Package middleware provides a standard set of middleware implementations for pomerium.
|
||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
|
@ -1,5 +1,4 @@
|
|||
// Package log provides a set of http.Handler helpers for zerolog.
|
||||
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -10,14 +9,14 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/zenazn/goji/web/mutil"
|
||||
)
|
||||
|
||||
// FromRequest gets the logger in the request's context.
|
||||
// This is a shortcut for log.Ctx(r.Context())
|
||||
func FromRequest(r *http.Request) *zerolog.Logger {
|
||||
return Ctx(r.Context())
|
||||
return log.Ctx(r.Context())
|
||||
}
|
||||
|
||||
// NewHandler injects log into requests context.
|
||||
|
@ -172,7 +171,7 @@ func AccessHandler(f func(r *http.Request, status, size int, duration time.Durat
|
|||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
lw := mutil.WrapWriter(w)
|
||||
lw := NewWrapResponseWriter(w, 2)
|
||||
next.ServeHTTP(lw, r)
|
||||
f(r, lw.Status(), lw.BytesWritten(), time.Since(start))
|
||||
})
|
|
@ -1,5 +1,4 @@
|
|||
// Package log provides a set of http.Handler helpers for zerolog.
|
||||
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -258,3 +257,21 @@ func BenchmarkDataRace(b *testing.B) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestForwardedAddrHandler(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
Header: http.Header{
|
||||
"X-Forwarded-For": []string{"client", "proxy1", "proxy2"},
|
||||
},
|
||||
}
|
||||
h := ForwardedAddrHandler("fwd_ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h.ServeHTTP(nil, r)
|
||||
if want, got := `{"fwd_ip":"client"}`+"\n", decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}
|
|
@ -1,4 +1,3 @@
|
|||
// Package middleware provides a standard set of middleware implementations for pomerium.
|
||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
import (
|
||||
|
@ -15,91 +14,76 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
)
|
||||
|
||||
// SetHeadersOld ensures that every response includes some basic security headers
|
||||
func SetHeadersOld(h http.Handler, securityHeaders map[string]string) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
for key, val := range securityHeaders {
|
||||
rw.Header().Set(key, val)
|
||||
}
|
||||
h.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
|
||||
// SetHeaders ensures that every response includes some basic security headers
|
||||
func SetHeaders(securityHeaders map[string]string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for key, val := range securityHeaders {
|
||||
rw.Header().Set(key, val)
|
||||
w.Header().Set(key, val)
|
||||
}
|
||||
next.ServeHTTP(rw, req)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WithMethods writes an error response if the method of the request is not included.
|
||||
func WithMethods(f http.HandlerFunc, methods ...string) http.HandlerFunc {
|
||||
methodMap := make(map[string]struct{}, len(methods))
|
||||
for _, m := range methods {
|
||||
methodMap[m] = struct{}{}
|
||||
}
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
if _, ok := methodMap[req.Method]; !ok {
|
||||
httputil.ErrorResponse(rw, req, fmt.Sprintf("method %s not allowed", req.Method), http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
f(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateClientSecret checks the request header for the client secret and returns
|
||||
// an error if it does not match the proxy client secret
|
||||
func ValidateClientSecret(f http.HandlerFunc, sharedSecret string) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
func ValidateClientSecret(sharedSecret string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
|
||||
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
clientSecret := req.Form.Get("shared_secret")
|
||||
clientSecret := r.Form.Get("shared_secret")
|
||||
// check the request header for the client secret
|
||||
if clientSecret == "" {
|
||||
clientSecret = req.Header.Get("X-Client-Secret")
|
||||
clientSecret = r.Header.Get("X-Client-Secret")
|
||||
}
|
||||
|
||||
if clientSecret != sharedSecret {
|
||||
httputil.ErrorResponse(rw, req, "Invalid client secret", http.StatusUnauthorized)
|
||||
httputil.ErrorResponse(w, r, "Invalid client secret", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
f(rw, req)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateRedirectURI checks the redirect uri in the query parameters and ensures that
|
||||
// the url's domain is one in the list of proxy root domains.
|
||||
func ValidateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
// the its domain is in the list of proxy root domains.
|
||||
func ValidateRedirectURI(proxyRootDomains []string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
||||
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
redirectURI := req.Form.Get("redirect_uri")
|
||||
if !validRedirectURI(redirectURI, proxyRootDomains) {
|
||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
redirectURI := r.Form.Get("redirect_uri")
|
||||
if !ValidRedirectURI(redirectURI, proxyRootDomains) {
|
||||
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
f(rw, req)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func validRedirectURI(uri string, rootDomains []string) bool {
|
||||
// ValidRedirectURI checks if a URL's domain is one in the list of proxy root domains.
|
||||
func ValidRedirectURI(uri string, rootDomains []string) bool {
|
||||
if uri == "" || len(rootDomains) == 0 {
|
||||
return false
|
||||
}
|
||||
redirectURL, err := url.Parse(uri)
|
||||
if uri == "" || err != nil || redirectURL.Host == "" {
|
||||
if err != nil || redirectURL.Host == "" {
|
||||
return false
|
||||
}
|
||||
for _, domain := range rootDomains {
|
||||
if domain == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasSuffix(redirectURL.Hostname(), domain) {
|
||||
return true
|
||||
}
|
||||
|
@ -109,35 +93,36 @@ func validRedirectURI(uri string, rootDomains []string) bool {
|
|||
|
||||
// ValidateSignature ensures the request is valid and has been signed with
|
||||
// the correspdoning client secret key
|
||||
func ValidateSignature(f http.HandlerFunc, sharedSecret string) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
err := req.ParseForm()
|
||||
func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
||||
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
redirectURI := req.Form.Get("redirect_uri")
|
||||
sigVal := req.Form.Get("sig")
|
||||
timestamp := req.Form.Get("ts")
|
||||
if !validSignature(redirectURI, sigVal, timestamp, sharedSecret) {
|
||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
||||
redirectURI := r.Form.Get("redirect_uri")
|
||||
sigVal := r.Form.Get("sig")
|
||||
timestamp := r.Form.Get("ts")
|
||||
if !ValidSignature(redirectURI, sigVal, timestamp, sharedSecret) {
|
||||
httputil.ErrorResponse(w, r, "Cross service signature failed to validate", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
f(rw, req)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateHost ensures that each request's host is valid
|
||||
func ValidateHost(mux map[string]*http.Handler) func(next http.Handler) http.Handler {
|
||||
func ValidateHost(mux map[string]http.Handler) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if _, ok := mux[req.Host]; !ok {
|
||||
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if _, ok := mux[r.Host]; !ok {
|
||||
httputil.ErrorResponse(w, r, "Unknown host to route", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(rw, req)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -145,25 +130,47 @@ func ValidateHost(mux map[string]*http.Handler) func(next http.Handler) http.Han
|
|||
// RequireHTTPS reroutes a HTTP request to HTTPS
|
||||
// todo(bdd) : this is unreliable unless behind another reverser proxy
|
||||
// todo(bdd) : header age seems extreme
|
||||
func RequireHTTPS(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("Strict-Transport-Security", "max-age=31536000")
|
||||
func RequireHTTPS(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Strict-Transport-Security", "max-age=31536000")
|
||||
// todo(bdd) : scheme and x-forwarded-proto cannot be trusted if not behind another load balancer
|
||||
if (req.URL.Scheme == "http" && req.Header.Get("X-Forwarded-Proto") == "http") || &req.TLS == nil {
|
||||
if (r.URL.Scheme == "http" && r.Header.Get("X-Forwarded-Proto") == "http") || &r.TLS == nil {
|
||||
dest := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: req.Host,
|
||||
Path: req.URL.Path,
|
||||
RawQuery: req.URL.RawQuery,
|
||||
Host: r.Host,
|
||||
Path: r.URL.Path,
|
||||
RawQuery: r.URL.RawQuery,
|
||||
}
|
||||
http.Redirect(rw, req, dest.String(), http.StatusMovedPermanently)
|
||||
http.Redirect(w, r, dest.String(), http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(rw, req)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
||||
// Healthcheck endpoint middleware useful to setting up a path like
|
||||
// `/ping` that load balancers or uptime testing external services
|
||||
// can make a request before hitting any routes. It's also convenient
|
||||
// to place this above ACL middlewares as well.
|
||||
func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler {
|
||||
f := func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "GET" && strings.EqualFold(r.URL.Path, endpoint) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(msg))
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
// ValidSignature checks to see if a signature is valid. Compares hmac of
|
||||
// redirect uri, timestamp, and secret and signature.
|
||||
func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
||||
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
|
||||
return false
|
||||
}
|
||||
|
|
324
internal/middleware/middleware_test.go
Normal file
324
internal/middleware/middleware_test.go
Normal file
|
@ -0,0 +1,324 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_ValidRedirectURI(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uri string
|
||||
rootDomains []string
|
||||
want bool
|
||||
}{
|
||||
{"good url redirect", "https://example.com/redirect", []string{"example.com"}, true},
|
||||
{"bad domain", "https://example.com/redirect", []string{"notexample.com"}, false},
|
||||
{"malformed url", "^example.com/redirect", []string{"notexample.com"}, false},
|
||||
{"empty domain list", "https://example.com/redirect", []string{}, false},
|
||||
{"empty domain", "https://example.com/redirect", []string{""}, false},
|
||||
{"empty url", "", []string{"example.com"}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ValidRedirectURI(tt.uri, tt.rootDomains); got != tt.want {
|
||||
t.Errorf("ValidRedirectURI() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ValidSignature(t *testing.T) {
|
||||
goodURL := "https://example.com/redirect"
|
||||
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
|
||||
now := fmt.Sprint(time.Now().Unix())
|
||||
rawSig := redirectURLSignature(goodURL, time.Now(), secretA)
|
||||
sig := base64.URLEncoding.EncodeToString(rawSig)
|
||||
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
redirectURI string
|
||||
sigVal string
|
||||
timestamp string
|
||||
secret string
|
||||
want bool
|
||||
}{
|
||||
{"good signature", goodURL, string(sig), now, secretA, true},
|
||||
{"empty redirect url", "", string(sig), now, secretA, false},
|
||||
{"bad redirect url", "https://google.com^", string(sig), now, secretA, false},
|
||||
{"malformed signature", goodURL, string(sig + "^"), now, "&*&@**($&#(", false},
|
||||
{"malformed timestamp", goodURL, string(sig), now + "^", secretA, false},
|
||||
{"stale timestamp", goodURL, string(sig), staleTime, secretA, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ValidSignature(tt.redirectURI, tt.sigVal, tt.timestamp, tt.secret); got != tt.want {
|
||||
t.Errorf("ValidSignature() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_redirectURLSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawRedirect string
|
||||
timestamp time.Time
|
||||
secret string
|
||||
want string
|
||||
}{
|
||||
{"good signature", "https://example.com/redirect", time.Unix(1546797901, 0), "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A=", "GIDyWKjrG_7MwXwIq1o51f2pDT_rH9aLHdsHxSBEwy8="},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := redirectURLSignature(tt.rawRedirect, tt.timestamp, tt.secret)
|
||||
out := base64.URLEncoding.EncodeToString(got)
|
||||
if out != tt.want {
|
||||
t.Errorf("redirectURLSignature() = %v, want %v", tt.want, out)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
securityHeaders map[string]string
|
||||
}{
|
||||
{"one option", map[string]string{"X-Frame-Options": "DENY"}},
|
||||
{"two options", map[string]string{"X-Frame-Options": "DENY", "A": "B"}},
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for k, want := range tt.securityHeaders {
|
||||
if got := w.Header().Get(k); want != got {
|
||||
t.Errorf("want %s got %q", want, got)
|
||||
}
|
||||
|
||||
}
|
||||
})
|
||||
rr := httptest.NewRecorder()
|
||||
handler := SetHeaders(tt.securityHeaders)(testHandler)
|
||||
handler.ServeHTTP(rr, req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRedirectURI(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
proxyRootDomains []string
|
||||
redirectURI string
|
||||
status int
|
||||
}{
|
||||
{"simple", []string{"google.com"}, "https://google.com", http.StatusOK},
|
||||
{"bad match", []string{"aol.com"}, "https://google.com", http.StatusBadRequest},
|
||||
{"with cname", []string{"google.com"}, "https://www.google.com", http.StatusOK},
|
||||
{"with path", []string{"google.com"}, "https://www.google.com/path", http.StatusOK},
|
||||
{"http", []string{"google.com"}, "http://www.google.com/path", http.StatusOK},
|
||||
{"malformed, invalid hex digits", []string{"google.com"}, "%zzzzz", http.StatusBadRequest},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{RawQuery: fmt.Sprintf("redirect_uri=%s", tt.redirectURI)},
|
||||
}
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Hi"))
|
||||
})
|
||||
rr := httptest.NewRecorder()
|
||||
handler := ValidateRedirectURI(tt.proxyRootDomains)(testHandler)
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != tt.status {
|
||||
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
|
||||
t.Errorf("%s", rr.Body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateClientSecret(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sharedSecret string
|
||||
clientGetValue string
|
||||
clientHeaderValue string
|
||||
status int
|
||||
}{
|
||||
{"simple", "secret", "secret", "secret", http.StatusOK},
|
||||
{"missing get param, valid header", "secret", "", "secret", http.StatusOK},
|
||||
{"missing both", "secret", "", "", http.StatusUnauthorized},
|
||||
{"simple bad", "bad-secret", "secret", "", http.StatusUnauthorized},
|
||||
{"malformed, invalid hex digits", "secret", "%zzzzz", "", http.StatusBadRequest},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := &http.Request{
|
||||
Method: http.MethodGet,
|
||||
Header: http.Header{"X-Client-Secret": []string{tt.clientHeaderValue}},
|
||||
URL: &url.URL{RawQuery: fmt.Sprintf("shared_secret=%s", tt.clientGetValue)},
|
||||
}
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Hi"))
|
||||
})
|
||||
rr := httptest.NewRecorder()
|
||||
handler := ValidateClientSecret(tt.sharedSecret)(testHandler)
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != tt.status {
|
||||
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
|
||||
t.Errorf("%s", rr.Body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSignature(t *testing.T) {
|
||||
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
|
||||
now := fmt.Sprint(time.Now().Unix())
|
||||
goodURL := "https://example.com/redirect"
|
||||
rawSig := redirectURLSignature(goodURL, time.Now(), secretA)
|
||||
sig := base64.URLEncoding.EncodeToString(rawSig)
|
||||
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sharedSecret string
|
||||
redirectURI string
|
||||
sig string
|
||||
ts string
|
||||
status int
|
||||
}{
|
||||
{"valid signature", secretA, goodURL, sig, now, http.StatusOK},
|
||||
{"stale signature", secretA, goodURL, sig, staleTime, http.StatusUnauthorized},
|
||||
{"malformed", secretA, goodURL, "%zzzzz", now, http.StatusBadRequest},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := url.Values{}
|
||||
v.Set("redirect_uri", tt.redirectURI)
|
||||
v.Set("ts", tt.ts)
|
||||
v.Set("sig", tt.sig)
|
||||
|
||||
req := &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{RawQuery: v.Encode()}}
|
||||
if tt.name == "malformed" {
|
||||
req.URL.RawQuery = "sig=%zzzzz"
|
||||
}
|
||||
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Hi"))
|
||||
})
|
||||
rr := httptest.NewRecorder()
|
||||
handler := ValidateSignature(tt.sharedSecret)(testHandler)
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != tt.status {
|
||||
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
|
||||
t.Errorf("%s", rr.Body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
clientPath string
|
||||
expected []byte
|
||||
}{
|
||||
{"good", http.MethodGet, "/ping", []byte("OK")},
|
||||
//tood(bdd): miss?
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, tt.clientPath, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Hi"))
|
||||
})
|
||||
rr := httptest.NewRecorder()
|
||||
handler := Healthcheck(tt.clientPath, string(tt.expected))(testHandler)
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Body.String() != string(tt.expected) {
|
||||
t.Errorf("body differs. got %ss want %ss", rr.Body, tt.expected)
|
||||
t.Errorf("%s", rr.Body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Redirect to a fixed URL
|
||||
type handlerHelper struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (rh *handlerHelper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(rh.msg))
|
||||
}
|
||||
|
||||
func handlerHelp(msg string) http.Handler {
|
||||
return &handlerHelper{msg}
|
||||
}
|
||||
func TestValidateHost(t *testing.T) {
|
||||
m := make(map[string]http.Handler)
|
||||
m["google.com"] = handlerHelp("google")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
validHosts map[string]http.Handler
|
||||
clientPath string
|
||||
expected []byte
|
||||
status int
|
||||
}{
|
||||
{"good", m, "google.com", []byte("google"), 200},
|
||||
{"no route", m, "googles.com", []byte("google"), 404},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, tt.clientPath, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
var testHandler http.Handler
|
||||
if tt.validHosts[tt.clientPath] != nil {
|
||||
tt.validHosts[tt.clientPath].ServeHTTP(rr, req)
|
||||
testHandler = tt.validHosts[tt.clientPath]
|
||||
} else {
|
||||
testHandler = handlerHelp("ok")
|
||||
}
|
||||
handler := ValidateHost(tt.validHosts)(testHandler)
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tt.status {
|
||||
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
|
||||
t.Errorf("%s", rr.Body)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
183
internal/middleware/wrap_writer.go
Normal file
183
internal/middleware/wrap_writer.go
Normal file
|
@ -0,0 +1,183 @@
|
|||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
// The original work was derived from Goji's middleware, source:
|
||||
// https://github.com/zenazn/goji/tree/master/web/middleware
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// NewWrapResponseWriter wraps an http.ResponseWriter, returning a proxy that allows you to
|
||||
// hook into various parts of the response process.
|
||||
func NewWrapResponseWriter(w http.ResponseWriter, protoMajor int) WrapResponseWriter {
|
||||
_, fl := w.(http.Flusher)
|
||||
|
||||
bw := basicWriter{ResponseWriter: w}
|
||||
|
||||
if protoMajor == 2 {
|
||||
_, ps := w.(http.Pusher)
|
||||
if fl && ps {
|
||||
return &http2FancyWriter{bw}
|
||||
}
|
||||
} else {
|
||||
_, hj := w.(http.Hijacker)
|
||||
_, rf := w.(io.ReaderFrom)
|
||||
if fl && hj && rf {
|
||||
return &httpFancyWriter{bw}
|
||||
}
|
||||
}
|
||||
if fl {
|
||||
return &flushWriter{bw}
|
||||
}
|
||||
|
||||
return &bw
|
||||
}
|
||||
|
||||
// WrapResponseWriter is a proxy around an http.ResponseWriter that allows you to hook
|
||||
// into various parts of the response process.
|
||||
type WrapResponseWriter interface {
|
||||
http.ResponseWriter
|
||||
// Status returns the HTTP status of the request, or 0 if one has not
|
||||
// yet been sent.
|
||||
Status() int
|
||||
// BytesWritten returns the total number of bytes sent to the client.
|
||||
BytesWritten() int
|
||||
// Tee causes the response body to be written to the given io.Writer in
|
||||
// addition to proxying the writes through. Only one io.Writer can be
|
||||
// tee'd to at once: setting a second one will overwrite the first.
|
||||
// Writes will be sent to the proxy before being written to this
|
||||
// io.Writer. It is illegal for the tee'd writer to be modified
|
||||
// concurrently with writes.
|
||||
Tee(io.Writer)
|
||||
// Unwrap returns the original proxied target.
|
||||
Unwrap() http.ResponseWriter
|
||||
}
|
||||
|
||||
// basicWriter wraps a http.ResponseWriter that implements the minimal
|
||||
// http.ResponseWriter interface.
|
||||
type basicWriter struct {
|
||||
http.ResponseWriter
|
||||
wroteHeader bool
|
||||
code int
|
||||
bytes int
|
||||
tee io.Writer
|
||||
}
|
||||
|
||||
func (b *basicWriter) WriteHeader(code int) {
|
||||
if !b.wroteHeader {
|
||||
b.code = code
|
||||
b.wroteHeader = true
|
||||
b.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *basicWriter) Write(buf []byte) (int, error) {
|
||||
b.WriteHeader(http.StatusOK)
|
||||
n, err := b.ResponseWriter.Write(buf)
|
||||
if b.tee != nil {
|
||||
_, err2 := b.tee.Write(buf[:n])
|
||||
// Prefer errors generated by the proxied writer.
|
||||
if err == nil {
|
||||
err = err2
|
||||
}
|
||||
}
|
||||
b.bytes += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (b *basicWriter) maybeWriteHeader() {
|
||||
if !b.wroteHeader {
|
||||
b.WriteHeader(http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *basicWriter) Status() int {
|
||||
return b.code
|
||||
}
|
||||
|
||||
func (b *basicWriter) BytesWritten() int {
|
||||
return b.bytes
|
||||
}
|
||||
|
||||
func (b *basicWriter) Tee(w io.Writer) {
|
||||
b.tee = w
|
||||
}
|
||||
|
||||
func (b *basicWriter) Unwrap() http.ResponseWriter {
|
||||
return b.ResponseWriter
|
||||
}
|
||||
|
||||
type flushWriter struct {
|
||||
basicWriter
|
||||
}
|
||||
|
||||
func (f *flushWriter) Flush() {
|
||||
f.wroteHeader = true
|
||||
|
||||
fl := f.basicWriter.ResponseWriter.(http.Flusher)
|
||||
fl.Flush()
|
||||
}
|
||||
|
||||
var _ http.Flusher = &flushWriter{}
|
||||
|
||||
// httpFancyWriter is a HTTP writer that additionally satisfies
|
||||
// http.Flusher, http.Hijacker, and io.ReaderFrom. It exists for the common case
|
||||
// of wrapping the http.ResponseWriter that package http gives you, in order to
|
||||
// make the proxied object support the full method set of the proxied object.
|
||||
type httpFancyWriter struct {
|
||||
basicWriter
|
||||
}
|
||||
|
||||
func (f *httpFancyWriter) Flush() {
|
||||
f.wroteHeader = true
|
||||
|
||||
fl := f.basicWriter.ResponseWriter.(http.Flusher)
|
||||
fl.Flush()
|
||||
}
|
||||
|
||||
func (f *httpFancyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hj := f.basicWriter.ResponseWriter.(http.Hijacker)
|
||||
return hj.Hijack()
|
||||
}
|
||||
|
||||
func (f *http2FancyWriter) Push(target string, opts *http.PushOptions) error {
|
||||
return f.basicWriter.ResponseWriter.(http.Pusher).Push(target, opts)
|
||||
}
|
||||
|
||||
func (f *httpFancyWriter) ReadFrom(r io.Reader) (int64, error) {
|
||||
if f.basicWriter.tee != nil {
|
||||
n, err := io.Copy(&f.basicWriter, r)
|
||||
f.basicWriter.bytes += int(n)
|
||||
return n, err
|
||||
}
|
||||
rf := f.basicWriter.ResponseWriter.(io.ReaderFrom)
|
||||
f.basicWriter.maybeWriteHeader()
|
||||
n, err := rf.ReadFrom(r)
|
||||
f.basicWriter.bytes += int(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
var _ http.Flusher = &httpFancyWriter{}
|
||||
var _ http.Hijacker = &httpFancyWriter{}
|
||||
var _ http.Pusher = &http2FancyWriter{}
|
||||
var _ io.ReaderFrom = &httpFancyWriter{}
|
||||
|
||||
// http2FancyWriter is a HTTP2 writer that additionally satisfies
|
||||
// http.Flusher, and io.ReaderFrom. It exists for the common case
|
||||
// of wrapping the http.ResponseWriter that package http gives you, in order to
|
||||
// make the proxied object support the full method set of the proxied object.
|
||||
type http2FancyWriter struct {
|
||||
basicWriter
|
||||
}
|
||||
|
||||
func (f *http2FancyWriter) Flush() {
|
||||
f.wroteHeader = true
|
||||
|
||||
fl := f.basicWriter.ResponseWriter.(http.Flusher)
|
||||
fl.Flush()
|
||||
}
|
||||
|
||||
var _ http.Flusher = &http2FancyWriter{}
|
33
internal/middleware/wrap_writer_test.go
Normal file
33
internal/middleware/wrap_writer_test.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFlushWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
|
||||
f := &flushWriter{basicWriter{ResponseWriter: httptest.NewRecorder()}}
|
||||
f.Flush()
|
||||
|
||||
if !f.wroteHeader {
|
||||
t.Fatal("want Flush to have set wroteHeader=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpFancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
|
||||
f := &httpFancyWriter{basicWriter{ResponseWriter: httptest.NewRecorder()}}
|
||||
f.Flush()
|
||||
|
||||
if !f.wroteHeader {
|
||||
t.Fatal("want Flush to have set wroteHeader=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
|
||||
f := &http2FancyWriter{basicWriter{ResponseWriter: httptest.NewRecorder()}}
|
||||
f.Flush()
|
||||
|
||||
if !f.wroteHeader {
|
||||
t.Fatal("want Flush to have set wroteHeader=true")
|
||||
}
|
||||
}
|
|
@ -46,9 +46,9 @@ func (p *Proxy) Handler() http.Handler {
|
|||
|
||||
// Middleware chain
|
||||
c := middleware.NewChain()
|
||||
c = c.Append(log.NewHandler(log.Logger))
|
||||
c = c.Append(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
||||
log.FromRequest(r).Info().
|
||||
c = c.Append(middleware.NewHandler(log.Logger))
|
||||
c = c.Append(middleware.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
||||
middleware.FromRequest(r).Info().
|
||||
Str("method", r.Method).
|
||||
Str("url", r.URL.String()).
|
||||
Int("status", status).
|
||||
|
@ -60,11 +60,11 @@ func (p *Proxy) Handler() http.Handler {
|
|||
}))
|
||||
c = c.Append(middleware.SetHeaders(securityHeaders))
|
||||
c = c.Append(middleware.RequireHTTPS)
|
||||
c = c.Append(log.ForwardedAddrHandler("fwd_ip"))
|
||||
c = c.Append(log.RemoteAddrHandler("ip"))
|
||||
c = c.Append(log.UserAgentHandler("user_agent"))
|
||||
c = c.Append(log.RefererHandler("referer"))
|
||||
c = c.Append(log.RequestIDHandler("req_id", "Request-Id"))
|
||||
c = c.Append(middleware.ForwardedAddrHandler("fwd_ip"))
|
||||
c = c.Append(middleware.RemoteAddrHandler("ip"))
|
||||
c = c.Append(middleware.UserAgentHandler("user_agent"))
|
||||
c = c.Append(middleware.RefererHandler("referer"))
|
||||
c = c.Append(middleware.RequestIDHandler("req_id", "Request-Id"))
|
||||
c = c.Append(middleware.ValidateHost(p.mux))
|
||||
h := c.Then(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip host validation for /ping requests because they hit the LB directly.
|
||||
|
@ -260,8 +260,7 @@ func (p *Proxy) AuthenticateOnly(w http.ResponseWriter, r *http.Request) {
|
|||
// or starting the authentication process if not.
|
||||
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||
// Attempts to validate the user and their cookie.
|
||||
var err error
|
||||
err = p.Authenticate(w, r)
|
||||
err := p.Authenticate(w, r)
|
||||
// If the authentication is not successful we proceed to start the OAuth Flow with
|
||||
// OAuthStart. If successful, we proceed to proxy to the configured upstream.
|
||||
if err != nil {
|
||||
|
@ -387,7 +386,7 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
|
|||
|
||||
// Handle constructs a route from the given host string and matches it to the provided http.Handler and UpstreamConfig
|
||||
func (p *Proxy) Handle(host string, handler http.Handler) {
|
||||
p.mux[host] = &handler
|
||||
p.mux[host] = handler
|
||||
}
|
||||
|
||||
// router attempts to find a route for a request. If a route is successfully matched,
|
||||
|
@ -396,7 +395,7 @@ func (p *Proxy) Handle(host string, handler http.Handler) {
|
|||
func (p *Proxy) router(r *http.Request) (http.Handler, bool) {
|
||||
route, ok := p.mux[r.Host]
|
||||
if ok {
|
||||
return *route, true
|
||||
return route, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
|
|
@ -135,7 +135,7 @@ type Proxy struct {
|
|||
|
||||
redirectURL *url.URL
|
||||
templates *template.Template
|
||||
mux map[string]*http.Handler
|
||||
mux map[string]http.Handler
|
||||
}
|
||||
|
||||
// StateParameter holds the redirect id along with the session id.
|
||||
|
@ -184,7 +184,7 @@ func New(opts *Options) (*Proxy, error) {
|
|||
|
||||
p := &Proxy{
|
||||
// these fields make up the routing mechanism
|
||||
mux: make(map[string]*http.Handler),
|
||||
mux: make(map[string]http.Handler),
|
||||
// session state
|
||||
cipher: cipher,
|
||||
csrfStore: cookieStore,
|
||||
|
@ -243,15 +243,12 @@ func deleteUpstreamCookies(req *http.Request, cookieName string) {
|
|||
req.Header.Set("Cookie", strings.Join(headers, ";"))
|
||||
}
|
||||
|
||||
// signRequest signs a g
|
||||
func (u *UpstreamProxy) signRequest(req *http.Request) {
|
||||
if u.signer != nil {
|
||||
jwt, err := u.signer.SignJWT(req.Header.Get(HeaderUserID), req.Header.Get(HeaderEmail))
|
||||
if err == nil {
|
||||
req.Header.Set(HeaderJWT, jwt)
|
||||
}
|
||||
} else {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue