authenticate: refactor middleware, logging, and tests (#30)

- Abstract remaining middleware from authenticate into internal.
- Use middleware chaining in authenticate.
- Standardize naming of Request and ResponseWriter to match std lib.
- Add healthcheck / ping as a middleware.
- Internalized wraped_writer package adapted from goji/middleware.
- Fixed indirection issue with reverse proxy map.
This commit is contained in:
Bobby DeSimone 2019-01-25 20:58:50 -08:00 committed by GitHub
parent b9c298d278
commit 7e1d1a7896
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 768 additions and 397 deletions

View file

@ -71,11 +71,10 @@ func OptionsFromEnvConfig() (*Options, error) {
return o, nil return o, nil
} }
// Validate checks to see if configuration values are valid for authentication service. // Validate checks to see if configuration values are valid for the authentication service.
// The checks do not modify the internal state of the Option structure. Function returns // The checks do not modify the internal state of the Option structure. Returns
// on first error found. // on first error found.
func (o *Options) Validate() error { func (o *Options) Validate() error {
if o.RedirectURL == nil { if o.RedirectURL == nil {
return errors.New("missing setting: identity provider redirect url") return errors.New("missing setting: identity provider redirect url")
} }
@ -105,11 +104,11 @@ func (o *Options) Validate() error {
if len(decodedCookieSecret) != 32 { if len(decodedCookieSecret) != 32 {
return fmt.Errorf("cookie secret expects 32 bytes but got %d", len(decodedCookieSecret)) return fmt.Errorf("cookie secret expects 32 bytes but got %d", len(decodedCookieSecret))
} }
return nil return nil
} }
// Authenticate stores all the information associated with proxying the request. // Authenticate is service for validating user authentication for proxied-requests
// against third-party identity provider (IdP) services.
type Authenticate struct { type Authenticate struct {
RedirectURL *url.URL RedirectURL *url.URL
@ -133,7 +132,7 @@ type Authenticate struct {
provider providers.Provider provider providers.Provider
} }
// New creates a Authenticate struct and applies the optional functions slice to the struct. // New validates and creates a new authentication service from a configuration options.
func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate, error) { func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate, error) {
if opts == nil { if opts == nil {
return nil, errors.New("options cannot be nil") return nil, errors.New("options cannot be nil")
@ -179,13 +178,13 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate
cipher: cipher, cipher: cipher,
skipProviderButton: opts.SkipProviderButton, skipProviderButton: opts.SkipProviderButton,
} }
// p.ServeMux = p.Handler()
p.provider, err = newProvider(opts) p.provider, err = newProvider(opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// apply the option functions // validation via dependency injected function
for _, optFunc := range optionFuncs { for _, optFunc := range optionFuncs {
err := optFunc(p) err := optFunc(p)
if err != nil { if err != nil {

View file

@ -12,7 +12,7 @@ import (
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
m "github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/internal/version"
) )
@ -28,45 +28,58 @@ var securityHeaders = map[string]string{
// Handler returns the Http.Handlers for authentication, callback, and refresh // Handler returns the Http.Handlers for authentication, callback, and refresh
func (p *Authenticate) Handler() http.Handler { func (p *Authenticate) Handler() http.Handler {
// set up our standard middlewares
stdMiddleware := middleware.NewChain()
stdMiddleware = stdMiddleware.Append(middleware.NewHandler(log.Logger))
stdMiddleware = stdMiddleware.Append(middleware.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
// executed after handler route handler
middleware.FromRequest(r).Info().
Str("method", r.Method).
Str("url", r.URL.String()).
Int("status", status).
Int("size", size).
Dur("duration", duration).
Msg("request")
}))
stdMiddleware = stdMiddleware.Append(middleware.SetHeaders(securityHeaders))
stdMiddleware = stdMiddleware.Append(middleware.ForwardedAddrHandler("fwd_ip"))
stdMiddleware = stdMiddleware.Append(middleware.RemoteAddrHandler("ip"))
stdMiddleware = stdMiddleware.Append(middleware.UserAgentHandler("user_agent"))
stdMiddleware = stdMiddleware.Append(middleware.RefererHandler("referer"))
stdMiddleware = stdMiddleware.Append(middleware.RequestIDHandler("req_id", "Request-Id"))
stdMiddleware = stdMiddleware.Append(middleware.Healthcheck("/ping", version.UserAgent()))
validateSignatureMiddleware := stdMiddleware.Append(
middleware.ValidateSignature(p.SharedKey),
middleware.ValidateRedirectURI(p.ProxyRootDomains))
validateClientSecret := stdMiddleware.Append(middleware.ValidateClientSecret(p.SharedKey))
mux := http.NewServeMux() mux := http.NewServeMux()
// we setup global endpoints that should respond to any hostname // we setup global endpoints that should respond to any hostname
mux.HandleFunc("/ping", m.WithMethods(p.PingPage, "GET")) mux.Handle("/ping", stdMiddleware.ThenFunc(p.PingPage))
serviceMux := http.NewServeMux() serviceMux := http.NewServeMux()
// standard rest and healthcheck endpoints // standard rest and healthcheck endpoints
serviceMux.HandleFunc("/ping", m.WithMethods(p.PingPage, "GET")) serviceMux.Handle("/ping", stdMiddleware.ThenFunc(p.PingPage))
serviceMux.HandleFunc("/robots.txt", m.WithMethods(p.RobotsTxt, "GET")) serviceMux.Handle("/robots.txt", stdMiddleware.ThenFunc(p.RobotsTxt))
// Identity Provider (IdP) endpoints and callbacks // Identity Provider (IdP) endpoints and callbacks
serviceMux.HandleFunc("/start", m.WithMethods(p.OAuthStart, "GET")) serviceMux.Handle("/start", stdMiddleware.ThenFunc(p.OAuthStart))
serviceMux.HandleFunc("/oauth2/callback", m.WithMethods(p.OAuthCallback, "GET")) serviceMux.Handle("/oauth2/callback", stdMiddleware.ThenFunc(p.OAuthCallback))
// authenticator-server endpoints, todo(bdd): make gRPC // authenticator-server endpoints, todo(bdd): make gRPC
serviceMux.HandleFunc("/sign_in", m.WithMethods(p.validateSignature(p.SignIn), "GET")) serviceMux.Handle("/sign_in", validateSignatureMiddleware.ThenFunc(p.SignIn))
serviceMux.HandleFunc("/sign_out", m.WithMethods(p.validateSignature(p.SignOut), "GET", "POST")) serviceMux.Handle("/sign_out", validateSignatureMiddleware.ThenFunc(p.SignOut)) // "GET", "POST"
serviceMux.HandleFunc("/profile", m.WithMethods(p.validateExisting(p.GetProfile), "GET")) serviceMux.Handle("/profile", validateClientSecret.ThenFunc(p.GetProfile)) // GET
serviceMux.HandleFunc("/validate", m.WithMethods(p.validateExisting(p.ValidateToken), "GET")) serviceMux.Handle("/validate", validateClientSecret.ThenFunc(p.ValidateToken)) // GET
serviceMux.HandleFunc("/redeem", m.WithMethods(p.validateExisting(p.Redeem), "POST")) serviceMux.Handle("/redeem", validateClientSecret.ThenFunc(p.Redeem)) // POST
serviceMux.HandleFunc("/refresh", m.WithMethods(p.validateExisting(p.Refresh), "POST")) serviceMux.Handle("/refresh", validateClientSecret.ThenFunc(p.Refresh)) //POST
// NOTE: we have to include trailing slash for the router to match the host header // NOTE: we have to include trailing slash for the router to match the host header
host := p.RedirectURL.Host host := p.RedirectURL.Host
if !strings.HasSuffix(host, "/") { if !strings.HasSuffix(host, "/") {
host = fmt.Sprintf("%s/", host) host = fmt.Sprintf("%s/", host)
} }
mux.Handle(host, serviceMux) // setup our service mux to only handle our required host header mux.Handle(host, serviceMux)
return m.SetHeadersOld(mux, securityHeaders) return mux
}
// validateSignature wraps a common collection of middlewares to validate signatures
func (p *Authenticate) validateSignature(f http.HandlerFunc) http.HandlerFunc {
return validateRedirectURI(validateSignature(f, p.SharedKey), p.ProxyRootDomains)
}
// validateSignature wraps a common collection of middlewares to validate
// a (presumably) existing user session
func (p *Authenticate) validateExisting(f http.HandlerFunc) http.HandlerFunc {
return m.ValidateClientSecret(f, p.SharedKey)
} }
// RobotsTxt handles the /robots.txt route. // RobotsTxt handles the /robots.txt route.
@ -83,10 +96,8 @@ func (p *Authenticate) PingPage(w http.ResponseWriter, r *http.Request) {
// SignInPage directs the user to the sign in page. Takes a `redirect_uri` param. // SignInPage directs the user to the sign in page. Takes a `redirect_uri` param.
func (p *Authenticate) SignInPage(w http.ResponseWriter, r *http.Request) { func (p *Authenticate) SignInPage(w http.ResponseWriter, r *http.Request) {
// requestLog := log.WithRequest(req, "authenticate.SignInPage")
redirectURL := p.RedirectURL.ResolveReference(r.URL) redirectURL := p.RedirectURL.ResolveReference(r.URL)
// validateRedirectURI middleware already ensures that this is a valid URL destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri")) // checked by middleware
destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri"))
t := struct { t := struct {
ProviderName string ProviderName string
AllowedDomains []string AllowedDomains []string
@ -100,28 +111,27 @@ func (p *Authenticate) SignInPage(w http.ResponseWriter, r *http.Request) {
Destination: destinationURL.Host, Destination: destinationURL.Host,
Version: version.FullVersion(), Version: version.FullVersion(),
} }
log.Ctx(r.Context()).Info(). log.FromRequest(r).Debug().
Str("ProviderName", p.provider.Data().ProviderName). Str("ProviderName", p.provider.Data().ProviderName).
Str("Redirect", redirectURL.String()). Str("Redirect", redirectURL.String()).
Str("Destination", destinationURL.Host). Str("Destination", destinationURL.Host).
Str("AllowedDomains", strings.Join(p.AllowedDomains, ", ")). Str("AllowedDomains", strings.Join(p.AllowedDomains, ", ")).
Msg("authenticate.SignInPage") Msg("authenticate: SignInPage")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
p.templates.ExecuteTemplate(w, "sign_in.html", t) p.templates.ExecuteTemplate(w, "sign_in.html", t)
} }
func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*sessions.SessionState, error) { func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*sessions.SessionState, error) {
// requestLog := log.WithRequest(req, "authenticate.authenticate")
session, err := p.sessionStore.LoadSession(r) session, err := p.sessionStore.LoadSession(r)
if err != nil { if err != nil {
log.Error().Err(err).Msg("authenticate.authenticate") log.Error().Err(err).Msg("authenticate: failed to load session")
p.sessionStore.ClearSession(w, r) p.sessionStore.ClearSession(w, r)
return nil, err return nil, err
} }
// ensure sessions lifetime has not expired // ensure sessions lifetime has not expired
if session.LifetimePeriodExpired() { if session.LifetimePeriodExpired() {
log.Ctx(r.Context()).Warn().Msg("lifetime expired") log.FromRequest(r).Warn().Msg("authenticate: lifetime expired")
p.sessionStore.ClearSession(w, r) p.sessionStore.ClearSession(w, r)
return nil, sessions.ErrLifetimeExpired return nil, sessions.ErrLifetimeExpired
} }
@ -129,12 +139,12 @@ func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
if session.RefreshPeriodExpired() { if session.RefreshPeriodExpired() {
ok, err := p.provider.RefreshSessionIfNeeded(session) ok, err := p.provider.RefreshSessionIfNeeded(session)
if err != nil { if err != nil {
log.Ctx(r.Context()).Error().Err(err).Msg("failed to refresh session") log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to refresh session")
p.sessionStore.ClearSession(w, r) p.sessionStore.ClearSession(w, r)
return nil, err return nil, err
} }
if !ok { if !ok {
log.Ctx(r.Context()).Error().Msg("user unauthorized after refresh") log.FromRequest(r).Error().Msg("user unauthorized after refresh")
p.sessionStore.ClearSession(w, r) p.sessionStore.ClearSession(w, r)
return nil, httputil.ErrUserNotAuthorized return nil, httputil.ErrUserNotAuthorized
} }
@ -144,7 +154,7 @@ func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
// We refreshed the session successfully, but failed to save it. // We refreshed the session successfully, but failed to save it.
// This could be from failing to encode the session properly. // This could be from failing to encode the session properly.
// But, we clear the session cookie and reject the request // But, we clear the session cookie and reject the request
log.Ctx(r.Context()).Error().Err(err).Msg("could not save refreshed session") log.FromRequest(r).Error().Err(err).Msg("could not save refreshed session")
p.sessionStore.ClearSession(w, r) p.sessionStore.ClearSession(w, r)
return nil, err return nil, err
} }
@ -152,20 +162,20 @@ func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
// The session has not exceeded it's lifetime or requires refresh // The session has not exceeded it's lifetime or requires refresh
ok := p.provider.ValidateSessionState(session) ok := p.provider.ValidateSessionState(session)
if !ok { if !ok {
log.Ctx(r.Context()).Error().Msg("invalid session state") log.FromRequest(r).Error().Msg("invalid session state")
p.sessionStore.ClearSession(w, r) p.sessionStore.ClearSession(w, r)
return nil, httputil.ErrUserNotAuthorized return nil, httputil.ErrUserNotAuthorized
} }
err = p.sessionStore.SaveSession(w, r, session) err = p.sessionStore.SaveSession(w, r, session)
if err != nil { if err != nil {
log.Ctx(r.Context()).Error().Err(err).Msg("failed to save valid session") log.FromRequest(r).Error().Err(err).Msg("failed to save valid session")
p.sessionStore.ClearSession(w, r) p.sessionStore.ClearSession(w, r)
return nil, err return nil, err
} }
} }
if !p.Validator(session.Email) { if !p.Validator(session.Email) {
log.Ctx(r.Context()).Error().Msg("invalid email user") log.FromRequest(r).Error().Msg("invalid email user")
return nil, httputil.ErrUserNotAuthorized return nil, httputil.ErrUserNotAuthorized
} }
return session, nil return session, nil
@ -316,7 +326,7 @@ func (p *Authenticate) SignOutPage(w http.ResponseWriter, r *http.Request, messa
signature := r.Form.Get("sig") signature := r.Form.Get("sig")
timestamp := r.Form.Get("ts") timestamp := r.Form.Get("ts")
destinationURL, _ := url.Parse(redirectURI) destinationURL, _ := url.Parse(redirectURI) //checked by middleware
// An error message indicates that an internal server error occurred // An error message indicates that an internal server error occurred
if message != "" { if message != "" {
@ -341,7 +351,6 @@ func (p *Authenticate) SignOutPage(w http.ResponseWriter, r *http.Request, messa
Version: version.FullVersion(), Version: version.FullVersion(),
} }
p.templates.ExecuteTemplate(w, "sign_out.html", t) p.templates.ExecuteTemplate(w, "sign_out.html", t)
return
} }
// OAuthStart starts the authentication process by redirecting to the provider. It provides a // OAuthStart starts the authentication process by redirecting to the provider. It provides a
@ -364,20 +373,20 @@ func (p *Authenticate) helperOAuthStart(w http.ResponseWriter, r *http.Request,
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey()) nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
p.csrfStore.SetCSRF(w, r, nonce) p.csrfStore.SetCSRF(w, r, nonce)
if !validRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) { if !middleware.ValidRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) {
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest) httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
return return
} }
proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri")) proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri"))
if err != nil || !validRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) { if err != nil || !middleware.ValidRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) {
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest) httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
return return
} }
proxyRedirectSig := authRedirectURL.Query().Get("sig") proxyRedirectSig := authRedirectURL.Query().Get("sig")
ts := authRedirectURL.Query().Get("ts") ts := authRedirectURL.Query().Get("ts")
if !validSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, p.SharedKey) { if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, p.SharedKey) {
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest) httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
return return
} }
@ -404,7 +413,6 @@ func (p *Authenticate) redeemCode(host, code string) (*sessions.SessionState, er
// getOAuthCallback completes the oauth cycle from an identity provider's callback // getOAuthCallback completes the oauth cycle from an identity provider's callback
func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) { func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) {
// requestLog := log.WithRequest(req, "authenticate.getOAuthCallback")
// finish the oauth cycle // finish the oauth cycle
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
@ -421,17 +429,18 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
session, err := p.redeemCode(r.Host, code) session, err := p.redeemCode(r.Host, code)
if err != nil { if err != nil {
log.Ctx(r.Context()).Error().Err(err).Msg("error redeeming authentication code") log.FromRequest(r).Error().Err(err).Msg("error redeeming authentication code")
return "", err return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
} }
bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state")) bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state"))
if err != nil { if err != nil {
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Invalid State"} log.FromRequest(r).Error().Err(err).Msg("failed decoding state")
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Couldn't decode state"}
} }
s := strings.SplitN(string(bytes), ":", 2) s := strings.SplitN(string(bytes), ":", 2)
if len(s) != 2 { if len(s) != 2 {
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Invalid State"} return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Invalid State"}
} }
nonce := s[0] nonce := s[0]
redirect := s[1] redirect := s[1]
@ -441,11 +450,11 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
} }
p.csrfStore.ClearCSRF(w, r) p.csrfStore.ClearCSRF(w, r)
if c.Value != nonce { if c.Value != nonce {
log.Ctx(r.Context()).Error().Err(err).Msg("csrf token mismatch") log.FromRequest(r).Error().Err(err).Msg("CSRF token mismatch")
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "csrf failed"} return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "CSRF failed"}
} }
if !validRedirectURI(redirect, p.ProxyRootDomains) { if !middleware.ValidRedirectURI(redirect, p.ProxyRootDomains) {
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Redirect URI"} return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Redirect URI"}
} }
@ -453,10 +462,10 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
// - for p.Validator see validator.go#newValidatorImpl for more info // - for p.Validator see validator.go#newValidatorImpl for more info
// - for p.provider.ValidateGroup see providers/google.go#ValidateGroup for more info // - for p.provider.ValidateGroup see providers/google.go#ValidateGroup for more info
if !p.Validator(session.Email) { if !p.Validator(session.Email) {
log.Ctx(r.Context()).Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied") log.FromRequest(r).Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied")
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Account"} return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Account"}
} }
log.Ctx(r.Context()).Info().Str("email", session.Email).Msg("authentication complete") log.FromRequest(r).Info().Str("email", session.Email).Msg("authentication complete")
err = p.sessionStore.SaveSession(w, r, session) err = p.sessionStore.SaveSession(w, r, session)
if err != nil { if err != nil {
log.Error().Err(err).Msg("internal error") log.Error().Err(err).Msg("internal error")
@ -487,7 +496,6 @@ func (p *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
func (p *Authenticate) Redeem(w http.ResponseWriter, r *http.Request) { func (p *Authenticate) Redeem(w http.ResponseWriter, r *http.Request) {
// The auth code is redeemed by the sso proxy for an access token, refresh token, // The auth code is redeemed by the sso proxy for an access token, refresh token,
// expiration, and email. // expiration, and email.
// requestLog := log.WithRequest(req, "authenticate.Redeem")
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest) http.Error(w, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
@ -496,19 +504,19 @@ func (p *Authenticate) Redeem(w http.ResponseWriter, r *http.Request) {
session, err := sessions.UnmarshalSession(r.Form.Get("code"), p.cipher) session, err := sessions.UnmarshalSession(r.Form.Get("code"), p.cipher)
if err != nil { if err != nil {
log.Ctx(r.Context()).Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid auth code") log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to unmarshal session")
http.Error(w, fmt.Sprintf("invalid auth code: %s", err.Error()), http.StatusUnauthorized) http.Error(w, fmt.Sprintf("invalid auth code: %s", err.Error()), http.StatusUnauthorized)
return return
} }
if session == nil { if session == nil {
log.Ctx(r.Context()).Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid session") log.FromRequest(r).Error().Err(err).Msg("empty session")
http.Error(w, fmt.Sprintf("invalid session: %s", err.Error()), http.StatusUnauthorized) http.Error(w, fmt.Sprintf("empty session: %s", err.Error()), http.StatusUnauthorized)
return return
} }
if session != nil && (session.RefreshPeriodExpired() || session.LifetimePeriodExpired()) { if session != nil && (session.RefreshPeriodExpired() || session.LifetimePeriodExpired()) {
log.Ctx(r.Context()).Error().Msg("expired session") log.FromRequest(r).Error().Msg("expired session")
p.sessionStore.ClearSession(w, r) p.sessionStore.ClearSession(w, r)
http.Error(w, fmt.Sprintf("expired session"), http.StatusUnauthorized) http.Error(w, fmt.Sprintf("expired session"), http.StatusUnauthorized)
return return
@ -524,7 +532,7 @@ func (p *Authenticate) Redeem(w http.ResponseWriter, r *http.Request) {
AccessToken: session.AccessToken, AccessToken: session.AccessToken,
RefreshToken: session.RefreshToken, RefreshToken: session.RefreshToken,
IDToken: session.IDToken, IDToken: session.IDToken,
ExpiresIn: int64(session.RefreshDeadline.Sub(time.Now()).Seconds()), ExpiresIn: int64(time.Until(session.RefreshDeadline).Seconds()),
Email: session.Email, Email: session.Email,
} }
@ -643,7 +651,5 @@ func (p *Authenticate) ValidateToken(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
return
} }

View file

@ -77,7 +77,7 @@ func TestAuthenticate_SignInPage(t *testing.T) {
if status := rr.Code; status != http.StatusOK { if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
} }
body := []byte(rr.Body.String()) body := rr.Body.Bytes()
tests := []struct { tests := []struct {
name string name string

View file

@ -1,109 +0,0 @@
package authenticate // import "github.com/pomerium/pomerium/authenticate"
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/pomerium/pomerium/internal/httputil"
)
var defaultSignatureValidityDuration = 5 * time.Minute
// validateRedirectURI checks the redirect uri in the query parameters and ensures that
// the url's domain is one in the list of proxy root domains.
func validateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
return
}
redirectURI := req.Form.Get("redirect_uri")
if !validRedirectURI(redirectURI, proxyRootDomains) {
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
return
}
f(rw, req)
}
}
func validRedirectURI(uri string, rootDomains []string) bool {
if uri == "" || len(rootDomains) == 0 {
return false
}
redirectURL, err := url.Parse(uri)
if err != nil || redirectURL.Host == "" {
return false
}
for _, domain := range rootDomains {
if domain == "" {
return false
}
if strings.HasSuffix(redirectURL.Hostname(), domain) {
return true
}
}
return false
}
func validateSignature(f http.HandlerFunc, sharedKey string) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
return
}
redirectURI := req.Form.Get("redirect_uri")
sigVal := req.Form.Get("sig")
timestamp := req.Form.Get("ts")
if !validSignature(redirectURI, sigVal, timestamp, sharedKey) {
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
return
}
f(rw, req)
}
}
// validateSignature ensures the validity of the redirect url by comparing the hmac
// digest, and ensuring that the included timestamp is fresh
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
return false
}
_, err := url.Parse(redirectURI)
if err != nil {
return false
}
requestSig, err := base64.URLEncoding.DecodeString(sigVal)
if err != nil {
return false
}
i, err := strconv.ParseInt(timestamp, 10, 64)
if err != nil {
return false
}
tm := time.Unix(i, 0)
if time.Now().Sub(tm) > defaultSignatureValidityDuration {
return false
}
localSig := redirectURLSignature(redirectURI, tm, secret)
return hmac.Equal(requestSig, localSig)
}
// redirectURLSignature generates a hmac digest from a
// redirect url, a timestamp, and a secret.
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
h := hmac.New(sha256.New, []byte(secret))
h.Write([]byte(rawRedirect))
h.Write([]byte(fmt.Sprint(timestamp.Unix())))
return h.Sum(nil)
}

View file

@ -1,85 +0,0 @@
package authenticate
import (
"encoding/base64"
"fmt"
"testing"
"time"
)
func Test_validRedirectURI(t *testing.T) {
tests := []struct {
name string
uri string
rootDomains []string
want bool
}{
{"good url redirect", "https://example.com/redirect", []string{"example.com"}, true},
{"bad domain", "https://example.com/redirect", []string{"notexample.com"}, false},
{"malformed url", "^example.com/redirect", []string{"notexample.com"}, false},
{"empty domain list", "https://example.com/redirect", []string{}, false},
{"empty domain", "https://example.com/redirect", []string{""}, false},
{"empty url", "", []string{"example.com"}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := validRedirectURI(tt.uri, tt.rootDomains); got != tt.want {
t.Errorf("validRedirectURI() = %v, want %v", got, tt.want)
}
})
}
}
func Test_validSignature(t *testing.T) {
goodURL := "https://example.com/redirect"
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
now := fmt.Sprint(time.Now().Unix())
rawSig := redirectURLSignature(goodURL, time.Now(), secretA)
sig := base64.URLEncoding.EncodeToString(rawSig)
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
tests := []struct {
name string
redirectURI string
sigVal string
timestamp string
secret string
want bool
}{
{"good signature", goodURL, string(sig), now, secretA, true},
{"empty redirect url", "", string(sig), now, secretA, false},
{"bad redirect url", "https://google.com^", string(sig), now, secretA, false},
{"malformed signature", goodURL, string(sig + "^"), now, "&*&@**($&#(", false},
{"malformed timestamp", goodURL, string(sig), now + "^", secretA, false},
{"stale timestamp", goodURL, string(sig), staleTime, secretA, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := validSignature(tt.redirectURI, tt.sigVal, tt.timestamp, tt.secret); got != tt.want {
t.Errorf("validSignature() = %v, want %v", got, tt.want)
}
})
}
}
func Test_redirectURLSignature(t *testing.T) {
tests := []struct {
name string
rawRedirect string
timestamp time.Time
secret string
want string
}{
{"good signature", "https://example.com/redirect", time.Unix(1546797901, 0), "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A=", "GIDyWKjrG_7MwXwIq1o51f2pDT_rH9aLHdsHxSBEwy8="},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := redirectURLSignature(tt.rawRedirect, tt.timestamp, tt.secret)
out := base64.URLEncoding.EncodeToString(got)
if out != tt.want {
t.Errorf("redirectURLSignature() = %v, want %v", tt.want, out)
}
})
}
}

View file

@ -258,7 +258,7 @@ func (p *ProviderData) RefreshAccessToken(refreshToken string) (string, time.Dur
log.Error().Err(err).Msg("authenticate/providers.RefreshAccessToken") log.Error().Err(err).Msg("authenticate/providers.RefreshAccessToken")
return "", 0, err return "", 0, err
} }
return newToken.AccessToken, newToken.Expiry.Sub(time.Now()), nil return newToken.AccessToken, time.Until(newToken.Expiry), nil
} }
// Revoke enables a user to revoke her token. If the identity provider supports revocation // Revoke enables a user to revoke her token. If the identity provider supports revocation

View file

@ -75,8 +75,3 @@ func (tp *TestProvider) Redeem(code string) (*sessions.SessionState, error) {
return tp.Session, tp.RedeemError return tp.Session, tp.RedeemError
} }
// Stop fulfills the Provider interface
func (tp *TestProvider) Stop() {
return
}

View file

@ -7,10 +7,10 @@
# Certificates can be loaded as files or base64 encoded bytes. If neither is set, a # Certificates can be loaded as files or base64 encoded bytes. If neither is set, a
# pomerium will attempt to locate a pair in the root directory # pomerium will attempt to locate a pair in the root directory
export CERTIFICATE="xxxxxx" # base64 encoded cert, eg. `base64 -i cert.pem`
export CERTIFICATE_KEY="xxxx" # base64 encoded key, eg. `base64 -i privkey.pem`
export CERTIFICATE_FILE="./cert.pem" # optional, defaults to `./cert.pem` export CERTIFICATE_FILE="./cert.pem" # optional, defaults to `./cert.pem`
export CERTIFICATE_KEY_FILE="./privkey.pem" # optional, defaults to `./certprivkey.pem` export CERTIFICATE_KEY_FILE="./privkey.pem" # optional, defaults to `./certprivkey.pem`
# export CERTIFICATE="xxxxxx" # base64 encoded cert, eg. `base64 -i cert.pem`
# export CERTIFICATE_KEY="xxxx" # base64 encoded key, eg. `base64 -i privkey.pem`
# The URL that the identity provider will call back after authenticating the user # The URL that the identity provider will call back after authenticating the user
export REDIRECT_URL="https://sso-auth.corp.example.com/oauth2/callback" export REDIRECT_URL="https://sso-auth.corp.example.com/oauth2/callback"

1
go.mod
View file

@ -9,7 +9,6 @@ require (
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/rs/zerolog v1.11.0 github.com/rs/zerolog v1.11.0
github.com/stretchr/testify v1.2.2 // indirect github.com/stretchr/testify v1.2.2 // indirect
github.com/zenazn/goji v0.9.0
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 // indirect golang.org/x/net v0.0.0-20181220203305-927f97764cc3 // indirect
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890 golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890

2
go.sum
View file

@ -16,8 +16,6 @@ github.com/rs/zerolog v1.11.0 h1:DRuq/S+4k52uJzBQciUcofXx45GrMC6yrEbb/CoK6+M=
github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/zenazn/goji v0.9.0 h1:RSQQAbXGArQ0dIDEq+PI6WqN6if+5KHu6x2Cx/GXLTQ=
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=

View file

@ -14,7 +14,7 @@ import (
) )
// ErrTokenRevoked signifies a token revokation or expiration error // ErrTokenRevoked signifies a token revokation or expiration error
var ErrTokenRevoked = errors.New("Token expired or revoked") var ErrTokenRevoked = errors.New("token expired or revoked")
var httpClient = &http.Client{ var httpClient = &http.Client{
Timeout: time.Second * 5, Timeout: time.Second * 5,

View file

@ -3,6 +3,7 @@ package log // import "github.com/pomerium/pomerium/internal/log"
import ( import (
"context" "context"
"net/http"
"os" "os"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -102,3 +103,9 @@ func Printf(format string, v ...interface{}) {
func Ctx(ctx context.Context) *zerolog.Logger { func Ctx(ctx context.Context) *zerolog.Logger {
return zerolog.Ctx(ctx) return zerolog.Ctx(ctx)
} }
// FromRequest gets the logger in the request's context.
// This is a shortcut for log.Ctx(r.Context())
func FromRequest(r *http.Request) *zerolog.Logger {
return Ctx(r.Context())
}

View file

@ -0,0 +1,2 @@
// Package middleware provides a standard set of middleware implementations for pomerium.
package middleware // import "github.com/pomerium/pomerium/internal/middleware"

View file

@ -1,5 +1,4 @@
// Package log provides a set of http.Handler helpers for zerolog. package middleware // import "github.com/pomerium/pomerium/internal/middleware"
package log // import "github.com/pomerium/pomerium/internal/log"
import ( import (
"context" "context"
@ -10,14 +9,14 @@ import (
"strings" "strings"
"time" "time"
"github.com/pomerium/pomerium/internal/log"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/zenazn/goji/web/mutil"
) )
// FromRequest gets the logger in the request's context. // FromRequest gets the logger in the request's context.
// This is a shortcut for log.Ctx(r.Context()) // This is a shortcut for log.Ctx(r.Context())
func FromRequest(r *http.Request) *zerolog.Logger { func FromRequest(r *http.Request) *zerolog.Logger {
return Ctx(r.Context()) return log.Ctx(r.Context())
} }
// NewHandler injects log into requests context. // NewHandler injects log into requests context.
@ -172,7 +171,7 @@ func AccessHandler(f func(r *http.Request, status, size int, duration time.Durat
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now() start := time.Now()
lw := mutil.WrapWriter(w) lw := NewWrapResponseWriter(w, 2)
next.ServeHTTP(lw, r) next.ServeHTTP(lw, r)
f(r, lw.Status(), lw.BytesWritten(), time.Since(start)) f(r, lw.Status(), lw.BytesWritten(), time.Since(start))
}) })

View file

@ -1,5 +1,4 @@
// Package log provides a set of http.Handler helpers for zerolog. package middleware // import "github.com/pomerium/pomerium/internal/middleware"
package log // import "github.com/pomerium/pomerium/internal/log"
import ( import (
"bytes" "bytes"
@ -258,3 +257,21 @@ func BenchmarkDataRace(b *testing.B) {
} }
}) })
} }
func TestForwardedAddrHandler(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{
Header: http.Header{
"X-Forwarded-For": []string{"client", "proxy1", "proxy2"},
},
}
h := ForwardedAddrHandler("fwd_ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
}))
h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(nil, r)
if want, got := `{"fwd_ip":"client"}`+"\n", decodeIfBinary(out); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

View file

@ -1,4 +1,3 @@
// Package middleware provides a standard set of middleware implementations for pomerium.
package middleware // import "github.com/pomerium/pomerium/internal/middleware" package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import ( import (
@ -15,91 +14,76 @@ import (
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
) )
// SetHeadersOld ensures that every response includes some basic security headers
func SetHeadersOld(h http.Handler, securityHeaders map[string]string) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
for key, val := range securityHeaders {
rw.Header().Set(key, val)
}
h.ServeHTTP(rw, req)
})
}
// SetHeaders ensures that every response includes some basic security headers // SetHeaders ensures that every response includes some basic security headers
func SetHeaders(securityHeaders map[string]string) func(next http.Handler) http.Handler { func SetHeaders(securityHeaders map[string]string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for key, val := range securityHeaders { for key, val := range securityHeaders {
rw.Header().Set(key, val) w.Header().Set(key, val)
} }
next.ServeHTTP(rw, req) next.ServeHTTP(w, r)
}) })
} }
} }
// WithMethods writes an error response if the method of the request is not included.
func WithMethods(f http.HandlerFunc, methods ...string) http.HandlerFunc {
methodMap := make(map[string]struct{}, len(methods))
for _, m := range methods {
methodMap[m] = struct{}{}
}
return func(rw http.ResponseWriter, req *http.Request) {
if _, ok := methodMap[req.Method]; !ok {
httputil.ErrorResponse(rw, req, fmt.Sprintf("method %s not allowed", req.Method), http.StatusMethodNotAllowed)
return
}
f(rw, req)
}
}
// ValidateClientSecret checks the request header for the client secret and returns // ValidateClientSecret checks the request header for the client secret and returns
// an error if it does not match the proxy client secret // an error if it does not match the proxy client secret
func ValidateClientSecret(f http.HandlerFunc, sharedSecret string) http.HandlerFunc { func ValidateClientSecret(sharedSecret string) func(next http.Handler) http.Handler {
return func(rw http.ResponseWriter, req *http.Request) { return func(next http.Handler) http.Handler {
err := req.ParseForm() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err != nil { err := r.ParseForm()
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError) if err != nil {
return httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
} return
clientSecret := req.Form.Get("shared_secret") }
// check the request header for the client secret clientSecret := r.Form.Get("shared_secret")
if clientSecret == "" { // check the request header for the client secret
clientSecret = req.Header.Get("X-Client-Secret") if clientSecret == "" {
} clientSecret = r.Header.Get("X-Client-Secret")
}
if clientSecret != sharedSecret { if clientSecret != sharedSecret {
httputil.ErrorResponse(rw, req, "Invalid client secret", http.StatusUnauthorized) httputil.ErrorResponse(w, r, "Invalid client secret", http.StatusUnauthorized)
return return
} }
f(rw, req) next.ServeHTTP(w, r)
})
} }
} }
// ValidateRedirectURI checks the redirect uri in the query parameters and ensures that // ValidateRedirectURI checks the redirect uri in the query parameters and ensures that
// the url's domain is one in the list of proxy root domains. // the its domain is in the list of proxy root domains.
func ValidateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc { func ValidateRedirectURI(proxyRootDomains []string) func(next http.Handler) http.Handler {
return func(rw http.ResponseWriter, req *http.Request) { return func(next http.Handler) http.Handler {
err := req.ParseForm() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err != nil { err := r.ParseForm()
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest) if err != nil {
return httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
} return
redirectURI := req.Form.Get("redirect_uri") }
if !validRedirectURI(redirectURI, proxyRootDomains) { redirectURI := r.Form.Get("redirect_uri")
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest) if !ValidRedirectURI(redirectURI, proxyRootDomains) {
return httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
} return
}
f(rw, req) next.ServeHTTP(w, r)
})
} }
} }
func validRedirectURI(uri string, rootDomains []string) bool { // ValidRedirectURI checks if a URL's domain is one in the list of proxy root domains.
func ValidRedirectURI(uri string, rootDomains []string) bool {
if uri == "" || len(rootDomains) == 0 {
return false
}
redirectURL, err := url.Parse(uri) redirectURL, err := url.Parse(uri)
if uri == "" || err != nil || redirectURL.Host == "" { if err != nil || redirectURL.Host == "" {
return false return false
} }
for _, domain := range rootDomains { for _, domain := range rootDomains {
if domain == "" {
return false
}
if strings.HasSuffix(redirectURL.Hostname(), domain) { if strings.HasSuffix(redirectURL.Hostname(), domain) {
return true return true
} }
@ -109,35 +93,36 @@ func validRedirectURI(uri string, rootDomains []string) bool {
// ValidateSignature ensures the request is valid and has been signed with // ValidateSignature ensures the request is valid and has been signed with
// the correspdoning client secret key // the correspdoning client secret key
func ValidateSignature(f http.HandlerFunc, sharedSecret string) http.HandlerFunc { func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler {
return func(rw http.ResponseWriter, req *http.Request) { return func(next http.Handler) http.Handler {
err := req.ParseForm() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err != nil { err := r.ParseForm()
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest) if err != nil {
return httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
} return
redirectURI := req.Form.Get("redirect_uri") }
sigVal := req.Form.Get("sig") redirectURI := r.Form.Get("redirect_uri")
timestamp := req.Form.Get("ts") sigVal := r.Form.Get("sig")
if !validSignature(redirectURI, sigVal, timestamp, sharedSecret) { timestamp := r.Form.Get("ts")
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest) if !ValidSignature(redirectURI, sigVal, timestamp, sharedSecret) {
return httputil.ErrorResponse(w, r, "Cross service signature failed to validate", http.StatusUnauthorized)
} return
}
f(rw, req) next.ServeHTTP(w, r)
})
} }
} }
// ValidateHost ensures that each request's host is valid // ValidateHost ensures that each request's host is valid
func ValidateHost(mux map[string]*http.Handler) func(next http.Handler) http.Handler { func ValidateHost(mux map[string]http.Handler) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { if _, ok := mux[r.Host]; !ok {
if _, ok := mux[req.Host]; !ok { httputil.ErrorResponse(w, r, "Unknown host to route", http.StatusNotFound)
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
return return
} }
next.ServeHTTP(rw, req) next.ServeHTTP(w, r)
}) })
} }
} }
@ -145,25 +130,47 @@ func ValidateHost(mux map[string]*http.Handler) func(next http.Handler) http.Han
// RequireHTTPS reroutes a HTTP request to HTTPS // RequireHTTPS reroutes a HTTP request to HTTPS
// todo(bdd) : this is unreliable unless behind another reverser proxy // todo(bdd) : this is unreliable unless behind another reverser proxy
// todo(bdd) : header age seems extreme // todo(bdd) : header age seems extreme
func RequireHTTPS(h http.Handler) http.Handler { func RequireHTTPS(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rw.Header().Set("Strict-Transport-Security", "max-age=31536000") w.Header().Set("Strict-Transport-Security", "max-age=31536000")
// todo(bdd) : scheme and x-forwarded-proto cannot be trusted if not behind another load balancer // todo(bdd) : scheme and x-forwarded-proto cannot be trusted if not behind another load balancer
if (req.URL.Scheme == "http" && req.Header.Get("X-Forwarded-Proto") == "http") || &req.TLS == nil { if (r.URL.Scheme == "http" && r.Header.Get("X-Forwarded-Proto") == "http") || &r.TLS == nil {
dest := &url.URL{ dest := &url.URL{
Scheme: "https", Scheme: "https",
Host: req.Host, Host: r.Host,
Path: req.URL.Path, Path: r.URL.Path,
RawQuery: req.URL.RawQuery, RawQuery: r.URL.RawQuery,
} }
http.Redirect(rw, req, dest.String(), http.StatusMovedPermanently) http.Redirect(w, r, dest.String(), http.StatusMovedPermanently)
return return
} }
h.ServeHTTP(rw, req) next.ServeHTTP(w, r)
}) })
} }
func validSignature(redirectURI, sigVal, timestamp, secret string) bool { // Healthcheck endpoint middleware useful to setting up a path like
// `/ping` that load balancers or uptime testing external services
// can make a request before hitting any routes. It's also convenient
// to place this above ACL middlewares as well.
func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler {
f := func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" && strings.EqualFold(r.URL.Path, endpoint) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(msg))
return
}
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
return f
}
// ValidSignature checks to see if a signature is valid. Compares hmac of
// redirect uri, timestamp, and secret and signature.
func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool {
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" { if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
return false return false
} }

View file

@ -0,0 +1,324 @@
package middleware
import (
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
)
func Test_ValidRedirectURI(t *testing.T) {
tests := []struct {
name string
uri string
rootDomains []string
want bool
}{
{"good url redirect", "https://example.com/redirect", []string{"example.com"}, true},
{"bad domain", "https://example.com/redirect", []string{"notexample.com"}, false},
{"malformed url", "^example.com/redirect", []string{"notexample.com"}, false},
{"empty domain list", "https://example.com/redirect", []string{}, false},
{"empty domain", "https://example.com/redirect", []string{""}, false},
{"empty url", "", []string{"example.com"}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ValidRedirectURI(tt.uri, tt.rootDomains); got != tt.want {
t.Errorf("ValidRedirectURI() = %v, want %v", got, tt.want)
}
})
}
}
func Test_ValidSignature(t *testing.T) {
goodURL := "https://example.com/redirect"
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
now := fmt.Sprint(time.Now().Unix())
rawSig := redirectURLSignature(goodURL, time.Now(), secretA)
sig := base64.URLEncoding.EncodeToString(rawSig)
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
tests := []struct {
name string
redirectURI string
sigVal string
timestamp string
secret string
want bool
}{
{"good signature", goodURL, string(sig), now, secretA, true},
{"empty redirect url", "", string(sig), now, secretA, false},
{"bad redirect url", "https://google.com^", string(sig), now, secretA, false},
{"malformed signature", goodURL, string(sig + "^"), now, "&*&@**($&#(", false},
{"malformed timestamp", goodURL, string(sig), now + "^", secretA, false},
{"stale timestamp", goodURL, string(sig), staleTime, secretA, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ValidSignature(tt.redirectURI, tt.sigVal, tt.timestamp, tt.secret); got != tt.want {
t.Errorf("ValidSignature() = %v, want %v", got, tt.want)
}
})
}
}
func Test_redirectURLSignature(t *testing.T) {
tests := []struct {
name string
rawRedirect string
timestamp time.Time
secret string
want string
}{
{"good signature", "https://example.com/redirect", time.Unix(1546797901, 0), "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A=", "GIDyWKjrG_7MwXwIq1o51f2pDT_rH9aLHdsHxSBEwy8="},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := redirectURLSignature(tt.rawRedirect, tt.timestamp, tt.secret)
out := base64.URLEncoding.EncodeToString(got)
if out != tt.want {
t.Errorf("redirectURLSignature() = %v, want %v", tt.want, out)
}
})
}
}
func TestSetHeaders(t *testing.T) {
tests := []struct {
name string
securityHeaders map[string]string
}{
{"one option", map[string]string{"X-Frame-Options": "DENY"}},
{"two options", map[string]string{"X-Frame-Options": "DENY", "A": "B"}},
}
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for k, want := range tt.securityHeaders {
if got := w.Header().Get(k); want != got {
t.Errorf("want %s got %q", want, got)
}
}
})
rr := httptest.NewRecorder()
handler := SetHeaders(tt.securityHeaders)(testHandler)
handler.ServeHTTP(rr, req)
})
}
}
func TestValidateRedirectURI(t *testing.T) {
tests := []struct {
name string
proxyRootDomains []string
redirectURI string
status int
}{
{"simple", []string{"google.com"}, "https://google.com", http.StatusOK},
{"bad match", []string{"aol.com"}, "https://google.com", http.StatusBadRequest},
{"with cname", []string{"google.com"}, "https://www.google.com", http.StatusOK},
{"with path", []string{"google.com"}, "https://www.google.com/path", http.StatusOK},
{"http", []string{"google.com"}, "http://www.google.com/path", http.StatusOK},
{"malformed, invalid hex digits", []string{"google.com"}, "%zzzzz", http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := &http.Request{
Method: http.MethodGet,
URL: &url.URL{RawQuery: fmt.Sprintf("redirect_uri=%s", tt.redirectURI)},
}
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hi"))
})
rr := httptest.NewRecorder()
handler := ValidateRedirectURI(tt.proxyRootDomains)(testHandler)
handler.ServeHTTP(rr, req)
if rr.Code != tt.status {
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
t.Errorf("%s", rr.Body)
}
})
}
}
func TestValidateClientSecret(t *testing.T) {
tests := []struct {
name string
sharedSecret string
clientGetValue string
clientHeaderValue string
status int
}{
{"simple", "secret", "secret", "secret", http.StatusOK},
{"missing get param, valid header", "secret", "", "secret", http.StatusOK},
{"missing both", "secret", "", "", http.StatusUnauthorized},
{"simple bad", "bad-secret", "secret", "", http.StatusUnauthorized},
{"malformed, invalid hex digits", "secret", "%zzzzz", "", http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := &http.Request{
Method: http.MethodGet,
Header: http.Header{"X-Client-Secret": []string{tt.clientHeaderValue}},
URL: &url.URL{RawQuery: fmt.Sprintf("shared_secret=%s", tt.clientGetValue)},
}
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hi"))
})
rr := httptest.NewRecorder()
handler := ValidateClientSecret(tt.sharedSecret)(testHandler)
handler.ServeHTTP(rr, req)
if rr.Code != tt.status {
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
t.Errorf("%s", rr.Body)
}
})
}
}
func TestValidateSignature(t *testing.T) {
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
now := fmt.Sprint(time.Now().Unix())
goodURL := "https://example.com/redirect"
rawSig := redirectURLSignature(goodURL, time.Now(), secretA)
sig := base64.URLEncoding.EncodeToString(rawSig)
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
tests := []struct {
name string
sharedSecret string
redirectURI string
sig string
ts string
status int
}{
{"valid signature", secretA, goodURL, sig, now, http.StatusOK},
{"stale signature", secretA, goodURL, sig, staleTime, http.StatusUnauthorized},
{"malformed", secretA, goodURL, "%zzzzz", now, http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := url.Values{}
v.Set("redirect_uri", tt.redirectURI)
v.Set("ts", tt.ts)
v.Set("sig", tt.sig)
req := &http.Request{
Method: http.MethodGet,
URL: &url.URL{RawQuery: v.Encode()}}
if tt.name == "malformed" {
req.URL.RawQuery = "sig=%zzzzz"
}
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hi"))
})
rr := httptest.NewRecorder()
handler := ValidateSignature(tt.sharedSecret)(testHandler)
handler.ServeHTTP(rr, req)
if rr.Code != tt.status {
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
t.Errorf("%s", rr.Body)
}
})
}
}
func TestHealthCheck(t *testing.T) {
tests := []struct {
name string
method string
clientPath string
expected []byte
}{
{"good", http.MethodGet, "/ping", []byte("OK")},
//tood(bdd): miss?
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, tt.clientPath, nil)
if err != nil {
t.Fatal(err)
}
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hi"))
})
rr := httptest.NewRecorder()
handler := Healthcheck(tt.clientPath, string(tt.expected))(testHandler)
handler.ServeHTTP(rr, req)
if rr.Body.String() != string(tt.expected) {
t.Errorf("body differs. got %ss want %ss", rr.Body, tt.expected)
t.Errorf("%s", rr.Body)
}
})
}
}
// Redirect to a fixed URL
type handlerHelper struct {
msg string
}
func (rh *handlerHelper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(rh.msg))
}
func handlerHelp(msg string) http.Handler {
return &handlerHelper{msg}
}
func TestValidateHost(t *testing.T) {
m := make(map[string]http.Handler)
m["google.com"] = handlerHelp("google")
tests := []struct {
name string
validHosts map[string]http.Handler
clientPath string
expected []byte
status int
}{
{"good", m, "google.com", []byte("google"), 200},
{"no route", m, "googles.com", []byte("google"), 404},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, tt.clientPath, nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
var testHandler http.Handler
if tt.validHosts[tt.clientPath] != nil {
tt.validHosts[tt.clientPath].ServeHTTP(rr, req)
testHandler = tt.validHosts[tt.clientPath]
} else {
testHandler = handlerHelp("ok")
}
handler := ValidateHost(tt.validHosts)(testHandler)
handler.ServeHTTP(rr, req)
if rr.Code != tt.status {
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
t.Errorf("%s", rr.Body)
}
})
}
}

View file

@ -0,0 +1,183 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
// The original work was derived from Goji's middleware, source:
// https://github.com/zenazn/goji/tree/master/web/middleware
import (
"bufio"
"io"
"net"
"net/http"
)
// NewWrapResponseWriter wraps an http.ResponseWriter, returning a proxy that allows you to
// hook into various parts of the response process.
func NewWrapResponseWriter(w http.ResponseWriter, protoMajor int) WrapResponseWriter {
_, fl := w.(http.Flusher)
bw := basicWriter{ResponseWriter: w}
if protoMajor == 2 {
_, ps := w.(http.Pusher)
if fl && ps {
return &http2FancyWriter{bw}
}
} else {
_, hj := w.(http.Hijacker)
_, rf := w.(io.ReaderFrom)
if fl && hj && rf {
return &httpFancyWriter{bw}
}
}
if fl {
return &flushWriter{bw}
}
return &bw
}
// WrapResponseWriter is a proxy around an http.ResponseWriter that allows you to hook
// into various parts of the response process.
type WrapResponseWriter interface {
http.ResponseWriter
// Status returns the HTTP status of the request, or 0 if one has not
// yet been sent.
Status() int
// BytesWritten returns the total number of bytes sent to the client.
BytesWritten() int
// Tee causes the response body to be written to the given io.Writer in
// addition to proxying the writes through. Only one io.Writer can be
// tee'd to at once: setting a second one will overwrite the first.
// Writes will be sent to the proxy before being written to this
// io.Writer. It is illegal for the tee'd writer to be modified
// concurrently with writes.
Tee(io.Writer)
// Unwrap returns the original proxied target.
Unwrap() http.ResponseWriter
}
// basicWriter wraps a http.ResponseWriter that implements the minimal
// http.ResponseWriter interface.
type basicWriter struct {
http.ResponseWriter
wroteHeader bool
code int
bytes int
tee io.Writer
}
func (b *basicWriter) WriteHeader(code int) {
if !b.wroteHeader {
b.code = code
b.wroteHeader = true
b.ResponseWriter.WriteHeader(code)
}
}
func (b *basicWriter) Write(buf []byte) (int, error) {
b.WriteHeader(http.StatusOK)
n, err := b.ResponseWriter.Write(buf)
if b.tee != nil {
_, err2 := b.tee.Write(buf[:n])
// Prefer errors generated by the proxied writer.
if err == nil {
err = err2
}
}
b.bytes += n
return n, err
}
func (b *basicWriter) maybeWriteHeader() {
if !b.wroteHeader {
b.WriteHeader(http.StatusOK)
}
}
func (b *basicWriter) Status() int {
return b.code
}
func (b *basicWriter) BytesWritten() int {
return b.bytes
}
func (b *basicWriter) Tee(w io.Writer) {
b.tee = w
}
func (b *basicWriter) Unwrap() http.ResponseWriter {
return b.ResponseWriter
}
type flushWriter struct {
basicWriter
}
func (f *flushWriter) Flush() {
f.wroteHeader = true
fl := f.basicWriter.ResponseWriter.(http.Flusher)
fl.Flush()
}
var _ http.Flusher = &flushWriter{}
// httpFancyWriter is a HTTP writer that additionally satisfies
// http.Flusher, http.Hijacker, and io.ReaderFrom. It exists for the common case
// of wrapping the http.ResponseWriter that package http gives you, in order to
// make the proxied object support the full method set of the proxied object.
type httpFancyWriter struct {
basicWriter
}
func (f *httpFancyWriter) Flush() {
f.wroteHeader = true
fl := f.basicWriter.ResponseWriter.(http.Flusher)
fl.Flush()
}
func (f *httpFancyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hj := f.basicWriter.ResponseWriter.(http.Hijacker)
return hj.Hijack()
}
func (f *http2FancyWriter) Push(target string, opts *http.PushOptions) error {
return f.basicWriter.ResponseWriter.(http.Pusher).Push(target, opts)
}
func (f *httpFancyWriter) ReadFrom(r io.Reader) (int64, error) {
if f.basicWriter.tee != nil {
n, err := io.Copy(&f.basicWriter, r)
f.basicWriter.bytes += int(n)
return n, err
}
rf := f.basicWriter.ResponseWriter.(io.ReaderFrom)
f.basicWriter.maybeWriteHeader()
n, err := rf.ReadFrom(r)
f.basicWriter.bytes += int(n)
return n, err
}
var _ http.Flusher = &httpFancyWriter{}
var _ http.Hijacker = &httpFancyWriter{}
var _ http.Pusher = &http2FancyWriter{}
var _ io.ReaderFrom = &httpFancyWriter{}
// http2FancyWriter is a HTTP2 writer that additionally satisfies
// http.Flusher, and io.ReaderFrom. It exists for the common case
// of wrapping the http.ResponseWriter that package http gives you, in order to
// make the proxied object support the full method set of the proxied object.
type http2FancyWriter struct {
basicWriter
}
func (f *http2FancyWriter) Flush() {
f.wroteHeader = true
fl := f.basicWriter.ResponseWriter.(http.Flusher)
fl.Flush()
}
var _ http.Flusher = &http2FancyWriter{}

View file

@ -0,0 +1,33 @@
package middleware
import (
"net/http/httptest"
"testing"
)
func TestFlushWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
f := &flushWriter{basicWriter{ResponseWriter: httptest.NewRecorder()}}
f.Flush()
if !f.wroteHeader {
t.Fatal("want Flush to have set wroteHeader=true")
}
}
func TestHttpFancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
f := &httpFancyWriter{basicWriter{ResponseWriter: httptest.NewRecorder()}}
f.Flush()
if !f.wroteHeader {
t.Fatal("want Flush to have set wroteHeader=true")
}
}
func TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
f := &http2FancyWriter{basicWriter{ResponseWriter: httptest.NewRecorder()}}
f.Flush()
if !f.wroteHeader {
t.Fatal("want Flush to have set wroteHeader=true")
}
}

View file

@ -46,9 +46,9 @@ func (p *Proxy) Handler() http.Handler {
// Middleware chain // Middleware chain
c := middleware.NewChain() c := middleware.NewChain()
c = c.Append(log.NewHandler(log.Logger)) c = c.Append(middleware.NewHandler(log.Logger))
c = c.Append(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) { c = c.Append(middleware.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
log.FromRequest(r).Info(). middleware.FromRequest(r).Info().
Str("method", r.Method). Str("method", r.Method).
Str("url", r.URL.String()). Str("url", r.URL.String()).
Int("status", status). Int("status", status).
@ -60,11 +60,11 @@ func (p *Proxy) Handler() http.Handler {
})) }))
c = c.Append(middleware.SetHeaders(securityHeaders)) c = c.Append(middleware.SetHeaders(securityHeaders))
c = c.Append(middleware.RequireHTTPS) c = c.Append(middleware.RequireHTTPS)
c = c.Append(log.ForwardedAddrHandler("fwd_ip")) c = c.Append(middleware.ForwardedAddrHandler("fwd_ip"))
c = c.Append(log.RemoteAddrHandler("ip")) c = c.Append(middleware.RemoteAddrHandler("ip"))
c = c.Append(log.UserAgentHandler("user_agent")) c = c.Append(middleware.UserAgentHandler("user_agent"))
c = c.Append(log.RefererHandler("referer")) c = c.Append(middleware.RefererHandler("referer"))
c = c.Append(log.RequestIDHandler("req_id", "Request-Id")) c = c.Append(middleware.RequestIDHandler("req_id", "Request-Id"))
c = c.Append(middleware.ValidateHost(p.mux)) c = c.Append(middleware.ValidateHost(p.mux))
h := c.Then(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { h := c.Then(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip host validation for /ping requests because they hit the LB directly. // Skip host validation for /ping requests because they hit the LB directly.
@ -260,8 +260,7 @@ func (p *Proxy) AuthenticateOnly(w http.ResponseWriter, r *http.Request) {
// or starting the authentication process if not. // or starting the authentication process if not.
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
// Attempts to validate the user and their cookie. // Attempts to validate the user and their cookie.
var err error err := p.Authenticate(w, r)
err = p.Authenticate(w, r)
// If the authentication is not successful we proceed to start the OAuth Flow with // If the authentication is not successful we proceed to start the OAuth Flow with
// OAuthStart. If successful, we proceed to proxy to the configured upstream. // OAuthStart. If successful, we proceed to proxy to the configured upstream.
if err != nil { if err != nil {
@ -387,7 +386,7 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
// Handle constructs a route from the given host string and matches it to the provided http.Handler and UpstreamConfig // Handle constructs a route from the given host string and matches it to the provided http.Handler and UpstreamConfig
func (p *Proxy) Handle(host string, handler http.Handler) { func (p *Proxy) Handle(host string, handler http.Handler) {
p.mux[host] = &handler p.mux[host] = handler
} }
// router attempts to find a route for a request. If a route is successfully matched, // router attempts to find a route for a request. If a route is successfully matched,
@ -396,7 +395,7 @@ func (p *Proxy) Handle(host string, handler http.Handler) {
func (p *Proxy) router(r *http.Request) (http.Handler, bool) { func (p *Proxy) router(r *http.Request) (http.Handler, bool) {
route, ok := p.mux[r.Host] route, ok := p.mux[r.Host]
if ok { if ok {
return *route, true return route, true
} }
return nil, false return nil, false
} }

View file

@ -135,7 +135,7 @@ type Proxy struct {
redirectURL *url.URL redirectURL *url.URL
templates *template.Template templates *template.Template
mux map[string]*http.Handler mux map[string]http.Handler
} }
// StateParameter holds the redirect id along with the session id. // StateParameter holds the redirect id along with the session id.
@ -184,7 +184,7 @@ func New(opts *Options) (*Proxy, error) {
p := &Proxy{ p := &Proxy{
// these fields make up the routing mechanism // these fields make up the routing mechanism
mux: make(map[string]*http.Handler), mux: make(map[string]http.Handler),
// session state // session state
cipher: cipher, cipher: cipher,
csrfStore: cookieStore, csrfStore: cookieStore,
@ -243,15 +243,12 @@ func deleteUpstreamCookies(req *http.Request, cookieName string) {
req.Header.Set("Cookie", strings.Join(headers, ";")) req.Header.Set("Cookie", strings.Join(headers, ";"))
} }
// signRequest signs a g
func (u *UpstreamProxy) signRequest(req *http.Request) { func (u *UpstreamProxy) signRequest(req *http.Request) {
if u.signer != nil { if u.signer != nil {
jwt, err := u.signer.SignJWT(req.Header.Get(HeaderUserID), req.Header.Get(HeaderEmail)) jwt, err := u.signer.SignJWT(req.Header.Get(HeaderUserID), req.Header.Get(HeaderEmail))
if err == nil { if err == nil {
req.Header.Set(HeaderJWT, jwt) req.Header.Set(HeaderJWT, jwt)
} }
} else {
} }
} }