mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-29 08:57:18 +02:00
authenticate: refactor middleware, logging, and tests (#30)
- Abstract remaining middleware from authenticate into internal. - Use middleware chaining in authenticate. - Standardize naming of Request and ResponseWriter to match std lib. - Add healthcheck / ping as a middleware. - Internalized wraped_writer package adapted from goji/middleware. - Fixed indirection issue with reverse proxy map.
This commit is contained in:
parent
b9c298d278
commit
7e1d1a7896
21 changed files with 768 additions and 397 deletions
|
@ -71,11 +71,10 @@ func OptionsFromEnvConfig() (*Options, error) {
|
||||||
return o, nil
|
return o, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate checks to see if configuration values are valid for authentication service.
|
// Validate checks to see if configuration values are valid for the authentication service.
|
||||||
// The checks do not modify the internal state of the Option structure. Function returns
|
// The checks do not modify the internal state of the Option structure. Returns
|
||||||
// on first error found.
|
// on first error found.
|
||||||
func (o *Options) Validate() error {
|
func (o *Options) Validate() error {
|
||||||
|
|
||||||
if o.RedirectURL == nil {
|
if o.RedirectURL == nil {
|
||||||
return errors.New("missing setting: identity provider redirect url")
|
return errors.New("missing setting: identity provider redirect url")
|
||||||
}
|
}
|
||||||
|
@ -105,11 +104,11 @@ func (o *Options) Validate() error {
|
||||||
if len(decodedCookieSecret) != 32 {
|
if len(decodedCookieSecret) != 32 {
|
||||||
return fmt.Errorf("cookie secret expects 32 bytes but got %d", len(decodedCookieSecret))
|
return fmt.Errorf("cookie secret expects 32 bytes but got %d", len(decodedCookieSecret))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate stores all the information associated with proxying the request.
|
// Authenticate is service for validating user authentication for proxied-requests
|
||||||
|
// against third-party identity provider (IdP) services.
|
||||||
type Authenticate struct {
|
type Authenticate struct {
|
||||||
RedirectURL *url.URL
|
RedirectURL *url.URL
|
||||||
|
|
||||||
|
@ -133,7 +132,7 @@ type Authenticate struct {
|
||||||
provider providers.Provider
|
provider providers.Provider
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a Authenticate struct and applies the optional functions slice to the struct.
|
// New validates and creates a new authentication service from a configuration options.
|
||||||
func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate, error) {
|
func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate, error) {
|
||||||
if opts == nil {
|
if opts == nil {
|
||||||
return nil, errors.New("options cannot be nil")
|
return nil, errors.New("options cannot be nil")
|
||||||
|
@ -179,13 +178,13 @@ func New(opts *Options, optionFuncs ...func(*Authenticate) error) (*Authenticate
|
||||||
cipher: cipher,
|
cipher: cipher,
|
||||||
skipProviderButton: opts.SkipProviderButton,
|
skipProviderButton: opts.SkipProviderButton,
|
||||||
}
|
}
|
||||||
// p.ServeMux = p.Handler()
|
|
||||||
p.provider, err = newProvider(opts)
|
p.provider, err = newProvider(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// apply the option functions
|
// validation via dependency injected function
|
||||||
for _, optFunc := range optionFuncs {
|
for _, optFunc := range optionFuncs {
|
||||||
err := optFunc(p)
|
err := optFunc(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -12,7 +12,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
m "github.com/pomerium/pomerium/internal/middleware"
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
)
|
)
|
||||||
|
@ -28,45 +28,58 @@ var securityHeaders = map[string]string{
|
||||||
|
|
||||||
// Handler returns the Http.Handlers for authentication, callback, and refresh
|
// Handler returns the Http.Handlers for authentication, callback, and refresh
|
||||||
func (p *Authenticate) Handler() http.Handler {
|
func (p *Authenticate) Handler() http.Handler {
|
||||||
|
// set up our standard middlewares
|
||||||
|
stdMiddleware := middleware.NewChain()
|
||||||
|
stdMiddleware = stdMiddleware.Append(middleware.NewHandler(log.Logger))
|
||||||
|
stdMiddleware = stdMiddleware.Append(middleware.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
||||||
|
// executed after handler route handler
|
||||||
|
middleware.FromRequest(r).Info().
|
||||||
|
Str("method", r.Method).
|
||||||
|
Str("url", r.URL.String()).
|
||||||
|
Int("status", status).
|
||||||
|
Int("size", size).
|
||||||
|
Dur("duration", duration).
|
||||||
|
Msg("request")
|
||||||
|
}))
|
||||||
|
stdMiddleware = stdMiddleware.Append(middleware.SetHeaders(securityHeaders))
|
||||||
|
stdMiddleware = stdMiddleware.Append(middleware.ForwardedAddrHandler("fwd_ip"))
|
||||||
|
stdMiddleware = stdMiddleware.Append(middleware.RemoteAddrHandler("ip"))
|
||||||
|
stdMiddleware = stdMiddleware.Append(middleware.UserAgentHandler("user_agent"))
|
||||||
|
stdMiddleware = stdMiddleware.Append(middleware.RefererHandler("referer"))
|
||||||
|
stdMiddleware = stdMiddleware.Append(middleware.RequestIDHandler("req_id", "Request-Id"))
|
||||||
|
stdMiddleware = stdMiddleware.Append(middleware.Healthcheck("/ping", version.UserAgent()))
|
||||||
|
validateSignatureMiddleware := stdMiddleware.Append(
|
||||||
|
middleware.ValidateSignature(p.SharedKey),
|
||||||
|
middleware.ValidateRedirectURI(p.ProxyRootDomains))
|
||||||
|
|
||||||
|
validateClientSecret := stdMiddleware.Append(middleware.ValidateClientSecret(p.SharedKey))
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
// we setup global endpoints that should respond to any hostname
|
// we setup global endpoints that should respond to any hostname
|
||||||
mux.HandleFunc("/ping", m.WithMethods(p.PingPage, "GET"))
|
mux.Handle("/ping", stdMiddleware.ThenFunc(p.PingPage))
|
||||||
|
|
||||||
serviceMux := http.NewServeMux()
|
serviceMux := http.NewServeMux()
|
||||||
// standard rest and healthcheck endpoints
|
// standard rest and healthcheck endpoints
|
||||||
serviceMux.HandleFunc("/ping", m.WithMethods(p.PingPage, "GET"))
|
serviceMux.Handle("/ping", stdMiddleware.ThenFunc(p.PingPage))
|
||||||
serviceMux.HandleFunc("/robots.txt", m.WithMethods(p.RobotsTxt, "GET"))
|
serviceMux.Handle("/robots.txt", stdMiddleware.ThenFunc(p.RobotsTxt))
|
||||||
// Identity Provider (IdP) endpoints and callbacks
|
// Identity Provider (IdP) endpoints and callbacks
|
||||||
serviceMux.HandleFunc("/start", m.WithMethods(p.OAuthStart, "GET"))
|
serviceMux.Handle("/start", stdMiddleware.ThenFunc(p.OAuthStart))
|
||||||
serviceMux.HandleFunc("/oauth2/callback", m.WithMethods(p.OAuthCallback, "GET"))
|
serviceMux.Handle("/oauth2/callback", stdMiddleware.ThenFunc(p.OAuthCallback))
|
||||||
// authenticator-server endpoints, todo(bdd): make gRPC
|
// authenticator-server endpoints, todo(bdd): make gRPC
|
||||||
serviceMux.HandleFunc("/sign_in", m.WithMethods(p.validateSignature(p.SignIn), "GET"))
|
serviceMux.Handle("/sign_in", validateSignatureMiddleware.ThenFunc(p.SignIn))
|
||||||
serviceMux.HandleFunc("/sign_out", m.WithMethods(p.validateSignature(p.SignOut), "GET", "POST"))
|
serviceMux.Handle("/sign_out", validateSignatureMiddleware.ThenFunc(p.SignOut)) // "GET", "POST"
|
||||||
serviceMux.HandleFunc("/profile", m.WithMethods(p.validateExisting(p.GetProfile), "GET"))
|
serviceMux.Handle("/profile", validateClientSecret.ThenFunc(p.GetProfile)) // GET
|
||||||
serviceMux.HandleFunc("/validate", m.WithMethods(p.validateExisting(p.ValidateToken), "GET"))
|
serviceMux.Handle("/validate", validateClientSecret.ThenFunc(p.ValidateToken)) // GET
|
||||||
serviceMux.HandleFunc("/redeem", m.WithMethods(p.validateExisting(p.Redeem), "POST"))
|
serviceMux.Handle("/redeem", validateClientSecret.ThenFunc(p.Redeem)) // POST
|
||||||
serviceMux.HandleFunc("/refresh", m.WithMethods(p.validateExisting(p.Refresh), "POST"))
|
serviceMux.Handle("/refresh", validateClientSecret.ThenFunc(p.Refresh)) //POST
|
||||||
|
|
||||||
// NOTE: we have to include trailing slash for the router to match the host header
|
// NOTE: we have to include trailing slash for the router to match the host header
|
||||||
host := p.RedirectURL.Host
|
host := p.RedirectURL.Host
|
||||||
if !strings.HasSuffix(host, "/") {
|
if !strings.HasSuffix(host, "/") {
|
||||||
host = fmt.Sprintf("%s/", host)
|
host = fmt.Sprintf("%s/", host)
|
||||||
}
|
}
|
||||||
mux.Handle(host, serviceMux) // setup our service mux to only handle our required host header
|
mux.Handle(host, serviceMux)
|
||||||
|
|
||||||
return m.SetHeadersOld(mux, securityHeaders)
|
return mux
|
||||||
}
|
|
||||||
|
|
||||||
// validateSignature wraps a common collection of middlewares to validate signatures
|
|
||||||
func (p *Authenticate) validateSignature(f http.HandlerFunc) http.HandlerFunc {
|
|
||||||
return validateRedirectURI(validateSignature(f, p.SharedKey), p.ProxyRootDomains)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateSignature wraps a common collection of middlewares to validate
|
|
||||||
// a (presumably) existing user session
|
|
||||||
func (p *Authenticate) validateExisting(f http.HandlerFunc) http.HandlerFunc {
|
|
||||||
return m.ValidateClientSecret(f, p.SharedKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RobotsTxt handles the /robots.txt route.
|
// RobotsTxt handles the /robots.txt route.
|
||||||
|
@ -83,10 +96,8 @@ func (p *Authenticate) PingPage(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// SignInPage directs the user to the sign in page. Takes a `redirect_uri` param.
|
// SignInPage directs the user to the sign in page. Takes a `redirect_uri` param.
|
||||||
func (p *Authenticate) SignInPage(w http.ResponseWriter, r *http.Request) {
|
func (p *Authenticate) SignInPage(w http.ResponseWriter, r *http.Request) {
|
||||||
// requestLog := log.WithRequest(req, "authenticate.SignInPage")
|
|
||||||
redirectURL := p.RedirectURL.ResolveReference(r.URL)
|
redirectURL := p.RedirectURL.ResolveReference(r.URL)
|
||||||
// validateRedirectURI middleware already ensures that this is a valid URL
|
destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri")) // checked by middleware
|
||||||
destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri"))
|
|
||||||
t := struct {
|
t := struct {
|
||||||
ProviderName string
|
ProviderName string
|
||||||
AllowedDomains []string
|
AllowedDomains []string
|
||||||
|
@ -100,28 +111,27 @@ func (p *Authenticate) SignInPage(w http.ResponseWriter, r *http.Request) {
|
||||||
Destination: destinationURL.Host,
|
Destination: destinationURL.Host,
|
||||||
Version: version.FullVersion(),
|
Version: version.FullVersion(),
|
||||||
}
|
}
|
||||||
log.Ctx(r.Context()).Info().
|
log.FromRequest(r).Debug().
|
||||||
Str("ProviderName", p.provider.Data().ProviderName).
|
Str("ProviderName", p.provider.Data().ProviderName).
|
||||||
Str("Redirect", redirectURL.String()).
|
Str("Redirect", redirectURL.String()).
|
||||||
Str("Destination", destinationURL.Host).
|
Str("Destination", destinationURL.Host).
|
||||||
Str("AllowedDomains", strings.Join(p.AllowedDomains, ", ")).
|
Str("AllowedDomains", strings.Join(p.AllowedDomains, ", ")).
|
||||||
Msg("authenticate.SignInPage")
|
Msg("authenticate: SignInPage")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
p.templates.ExecuteTemplate(w, "sign_in.html", t)
|
p.templates.ExecuteTemplate(w, "sign_in.html", t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*sessions.SessionState, error) {
|
func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*sessions.SessionState, error) {
|
||||||
// requestLog := log.WithRequest(req, "authenticate.authenticate")
|
|
||||||
session, err := p.sessionStore.LoadSession(r)
|
session, err := p.sessionStore.LoadSession(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("authenticate.authenticate")
|
log.Error().Err(err).Msg("authenticate: failed to load session")
|
||||||
p.sessionStore.ClearSession(w, r)
|
p.sessionStore.ClearSession(w, r)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensure sessions lifetime has not expired
|
// ensure sessions lifetime has not expired
|
||||||
if session.LifetimePeriodExpired() {
|
if session.LifetimePeriodExpired() {
|
||||||
log.Ctx(r.Context()).Warn().Msg("lifetime expired")
|
log.FromRequest(r).Warn().Msg("authenticate: lifetime expired")
|
||||||
p.sessionStore.ClearSession(w, r)
|
p.sessionStore.ClearSession(w, r)
|
||||||
return nil, sessions.ErrLifetimeExpired
|
return nil, sessions.ErrLifetimeExpired
|
||||||
}
|
}
|
||||||
|
@ -129,12 +139,12 @@ func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
|
||||||
if session.RefreshPeriodExpired() {
|
if session.RefreshPeriodExpired() {
|
||||||
ok, err := p.provider.RefreshSessionIfNeeded(session)
|
ok, err := p.provider.RefreshSessionIfNeeded(session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Ctx(r.Context()).Error().Err(err).Msg("failed to refresh session")
|
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to refresh session")
|
||||||
p.sessionStore.ClearSession(w, r)
|
p.sessionStore.ClearSession(w, r)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Ctx(r.Context()).Error().Msg("user unauthorized after refresh")
|
log.FromRequest(r).Error().Msg("user unauthorized after refresh")
|
||||||
p.sessionStore.ClearSession(w, r)
|
p.sessionStore.ClearSession(w, r)
|
||||||
return nil, httputil.ErrUserNotAuthorized
|
return nil, httputil.ErrUserNotAuthorized
|
||||||
}
|
}
|
||||||
|
@ -144,7 +154,7 @@ func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
|
||||||
// We refreshed the session successfully, but failed to save it.
|
// We refreshed the session successfully, but failed to save it.
|
||||||
// This could be from failing to encode the session properly.
|
// This could be from failing to encode the session properly.
|
||||||
// But, we clear the session cookie and reject the request
|
// But, we clear the session cookie and reject the request
|
||||||
log.Ctx(r.Context()).Error().Err(err).Msg("could not save refreshed session")
|
log.FromRequest(r).Error().Err(err).Msg("could not save refreshed session")
|
||||||
p.sessionStore.ClearSession(w, r)
|
p.sessionStore.ClearSession(w, r)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -152,20 +162,20 @@ func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*se
|
||||||
// The session has not exceeded it's lifetime or requires refresh
|
// The session has not exceeded it's lifetime or requires refresh
|
||||||
ok := p.provider.ValidateSessionState(session)
|
ok := p.provider.ValidateSessionState(session)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Ctx(r.Context()).Error().Msg("invalid session state")
|
log.FromRequest(r).Error().Msg("invalid session state")
|
||||||
p.sessionStore.ClearSession(w, r)
|
p.sessionStore.ClearSession(w, r)
|
||||||
return nil, httputil.ErrUserNotAuthorized
|
return nil, httputil.ErrUserNotAuthorized
|
||||||
}
|
}
|
||||||
err = p.sessionStore.SaveSession(w, r, session)
|
err = p.sessionStore.SaveSession(w, r, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Ctx(r.Context()).Error().Err(err).Msg("failed to save valid session")
|
log.FromRequest(r).Error().Err(err).Msg("failed to save valid session")
|
||||||
p.sessionStore.ClearSession(w, r)
|
p.sessionStore.ClearSession(w, r)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !p.Validator(session.Email) {
|
if !p.Validator(session.Email) {
|
||||||
log.Ctx(r.Context()).Error().Msg("invalid email user")
|
log.FromRequest(r).Error().Msg("invalid email user")
|
||||||
return nil, httputil.ErrUserNotAuthorized
|
return nil, httputil.ErrUserNotAuthorized
|
||||||
}
|
}
|
||||||
return session, nil
|
return session, nil
|
||||||
|
@ -316,7 +326,7 @@ func (p *Authenticate) SignOutPage(w http.ResponseWriter, r *http.Request, messa
|
||||||
|
|
||||||
signature := r.Form.Get("sig")
|
signature := r.Form.Get("sig")
|
||||||
timestamp := r.Form.Get("ts")
|
timestamp := r.Form.Get("ts")
|
||||||
destinationURL, _ := url.Parse(redirectURI)
|
destinationURL, _ := url.Parse(redirectURI) //checked by middleware
|
||||||
|
|
||||||
// An error message indicates that an internal server error occurred
|
// An error message indicates that an internal server error occurred
|
||||||
if message != "" {
|
if message != "" {
|
||||||
|
@ -341,7 +351,6 @@ func (p *Authenticate) SignOutPage(w http.ResponseWriter, r *http.Request, messa
|
||||||
Version: version.FullVersion(),
|
Version: version.FullVersion(),
|
||||||
}
|
}
|
||||||
p.templates.ExecuteTemplate(w, "sign_out.html", t)
|
p.templates.ExecuteTemplate(w, "sign_out.html", t)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuthStart starts the authentication process by redirecting to the provider. It provides a
|
// OAuthStart starts the authentication process by redirecting to the provider. It provides a
|
||||||
|
@ -364,20 +373,20 @@ func (p *Authenticate) helperOAuthStart(w http.ResponseWriter, r *http.Request,
|
||||||
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
|
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
|
||||||
p.csrfStore.SetCSRF(w, r, nonce)
|
p.csrfStore.SetCSRF(w, r, nonce)
|
||||||
|
|
||||||
if !validRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) {
|
if !middleware.ValidRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) {
|
||||||
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri"))
|
proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri"))
|
||||||
if err != nil || !validRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) {
|
if err != nil || !middleware.ValidRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) {
|
||||||
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyRedirectSig := authRedirectURL.Query().Get("sig")
|
proxyRedirectSig := authRedirectURL.Query().Get("sig")
|
||||||
ts := authRedirectURL.Query().Get("ts")
|
ts := authRedirectURL.Query().Get("ts")
|
||||||
if !validSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, p.SharedKey) {
|
if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, p.SharedKey) {
|
||||||
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -404,7 +413,6 @@ func (p *Authenticate) redeemCode(host, code string) (*sessions.SessionState, er
|
||||||
|
|
||||||
// getOAuthCallback completes the oauth cycle from an identity provider's callback
|
// getOAuthCallback completes the oauth cycle from an identity provider's callback
|
||||||
func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) {
|
func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||||
// requestLog := log.WithRequest(req, "authenticate.getOAuthCallback")
|
|
||||||
// finish the oauth cycle
|
// finish the oauth cycle
|
||||||
err := r.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -421,17 +429,18 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
||||||
|
|
||||||
session, err := p.redeemCode(r.Host, code)
|
session, err := p.redeemCode(r.Host, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Ctx(r.Context()).Error().Err(err).Msg("error redeeming authentication code")
|
log.FromRequest(r).Error().Err(err).Msg("error redeeming authentication code")
|
||||||
return "", err
|
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
|
||||||
}
|
}
|
||||||
|
|
||||||
bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state"))
|
bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Invalid State"}
|
log.FromRequest(r).Error().Err(err).Msg("failed decoding state")
|
||||||
|
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Couldn't decode state"}
|
||||||
}
|
}
|
||||||
s := strings.SplitN(string(bytes), ":", 2)
|
s := strings.SplitN(string(bytes), ":", 2)
|
||||||
if len(s) != 2 {
|
if len(s) != 2 {
|
||||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Invalid State"}
|
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Invalid State"}
|
||||||
}
|
}
|
||||||
nonce := s[0]
|
nonce := s[0]
|
||||||
redirect := s[1]
|
redirect := s[1]
|
||||||
|
@ -441,11 +450,11 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
||||||
}
|
}
|
||||||
p.csrfStore.ClearCSRF(w, r)
|
p.csrfStore.ClearCSRF(w, r)
|
||||||
if c.Value != nonce {
|
if c.Value != nonce {
|
||||||
log.Ctx(r.Context()).Error().Err(err).Msg("csrf token mismatch")
|
log.FromRequest(r).Error().Err(err).Msg("CSRF token mismatch")
|
||||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "csrf failed"}
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "CSRF failed"}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !validRedirectURI(redirect, p.ProxyRootDomains) {
|
if !middleware.ValidRedirectURI(redirect, p.ProxyRootDomains) {
|
||||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Redirect URI"}
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Redirect URI"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -453,10 +462,10 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
||||||
// - for p.Validator see validator.go#newValidatorImpl for more info
|
// - for p.Validator see validator.go#newValidatorImpl for more info
|
||||||
// - for p.provider.ValidateGroup see providers/google.go#ValidateGroup for more info
|
// - for p.provider.ValidateGroup see providers/google.go#ValidateGroup for more info
|
||||||
if !p.Validator(session.Email) {
|
if !p.Validator(session.Email) {
|
||||||
log.Ctx(r.Context()).Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied")
|
log.FromRequest(r).Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied")
|
||||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Account"}
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Account"}
|
||||||
}
|
}
|
||||||
log.Ctx(r.Context()).Info().Str("email", session.Email).Msg("authentication complete")
|
log.FromRequest(r).Info().Str("email", session.Email).Msg("authentication complete")
|
||||||
err = p.sessionStore.SaveSession(w, r, session)
|
err = p.sessionStore.SaveSession(w, r, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("internal error")
|
log.Error().Err(err).Msg("internal error")
|
||||||
|
@ -487,7 +496,6 @@ func (p *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
func (p *Authenticate) Redeem(w http.ResponseWriter, r *http.Request) {
|
func (p *Authenticate) Redeem(w http.ResponseWriter, r *http.Request) {
|
||||||
// The auth code is redeemed by the sso proxy for an access token, refresh token,
|
// The auth code is redeemed by the sso proxy for an access token, refresh token,
|
||||||
// expiration, and email.
|
// expiration, and email.
|
||||||
// requestLog := log.WithRequest(req, "authenticate.Redeem")
|
|
||||||
err := r.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
|
http.Error(w, fmt.Sprintf("Bad Request: %s", err.Error()), http.StatusBadRequest)
|
||||||
|
@ -496,19 +504,19 @@ func (p *Authenticate) Redeem(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
session, err := sessions.UnmarshalSession(r.Form.Get("code"), p.cipher)
|
session, err := sessions.UnmarshalSession(r.Form.Get("code"), p.cipher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Ctx(r.Context()).Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid auth code")
|
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to unmarshal session")
|
||||||
http.Error(w, fmt.Sprintf("invalid auth code: %s", err.Error()), http.StatusUnauthorized)
|
http.Error(w, fmt.Sprintf("invalid auth code: %s", err.Error()), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if session == nil {
|
if session == nil {
|
||||||
log.Ctx(r.Context()).Error().Err(err).Int("http-status", http.StatusUnauthorized).Msg("invalid session")
|
log.FromRequest(r).Error().Err(err).Msg("empty session")
|
||||||
http.Error(w, fmt.Sprintf("invalid session: %s", err.Error()), http.StatusUnauthorized)
|
http.Error(w, fmt.Sprintf("empty session: %s", err.Error()), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if session != nil && (session.RefreshPeriodExpired() || session.LifetimePeriodExpired()) {
|
if session != nil && (session.RefreshPeriodExpired() || session.LifetimePeriodExpired()) {
|
||||||
log.Ctx(r.Context()).Error().Msg("expired session")
|
log.FromRequest(r).Error().Msg("expired session")
|
||||||
p.sessionStore.ClearSession(w, r)
|
p.sessionStore.ClearSession(w, r)
|
||||||
http.Error(w, fmt.Sprintf("expired session"), http.StatusUnauthorized)
|
http.Error(w, fmt.Sprintf("expired session"), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
|
@ -524,7 +532,7 @@ func (p *Authenticate) Redeem(w http.ResponseWriter, r *http.Request) {
|
||||||
AccessToken: session.AccessToken,
|
AccessToken: session.AccessToken,
|
||||||
RefreshToken: session.RefreshToken,
|
RefreshToken: session.RefreshToken,
|
||||||
IDToken: session.IDToken,
|
IDToken: session.IDToken,
|
||||||
ExpiresIn: int64(session.RefreshDeadline.Sub(time.Now()).Seconds()),
|
ExpiresIn: int64(time.Until(session.RefreshDeadline).Seconds()),
|
||||||
Email: session.Email,
|
Email: session.Email,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -643,7 +651,5 @@ func (p *Authenticate) ValidateToken(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,7 +77,7 @@ func TestAuthenticate_SignInPage(t *testing.T) {
|
||||||
if status := rr.Code; status != http.StatusOK {
|
if status := rr.Code; status != http.StatusOK {
|
||||||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
|
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
|
||||||
}
|
}
|
||||||
body := []byte(rr.Body.String())
|
body := rr.Body.Bytes()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
|
@ -1,109 +0,0 @@
|
||||||
package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
|
||||||
)
|
|
||||||
|
|
||||||
var defaultSignatureValidityDuration = 5 * time.Minute
|
|
||||||
|
|
||||||
// validateRedirectURI checks the redirect uri in the query parameters and ensures that
|
|
||||||
// the url's domain is one in the list of proxy root domains.
|
|
||||||
func validateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
|
|
||||||
return func(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
err := req.ParseForm()
|
|
||||||
if err != nil {
|
|
||||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
redirectURI := req.Form.Get("redirect_uri")
|
|
||||||
if !validRedirectURI(redirectURI, proxyRootDomains) {
|
|
||||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f(rw, req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func validRedirectURI(uri string, rootDomains []string) bool {
|
|
||||||
if uri == "" || len(rootDomains) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
redirectURL, err := url.Parse(uri)
|
|
||||||
if err != nil || redirectURL.Host == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for _, domain := range rootDomains {
|
|
||||||
if domain == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if strings.HasSuffix(redirectURL.Hostname(), domain) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateSignature(f http.HandlerFunc, sharedKey string) http.HandlerFunc {
|
|
||||||
return func(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
err := req.ParseForm()
|
|
||||||
if err != nil {
|
|
||||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
redirectURI := req.Form.Get("redirect_uri")
|
|
||||||
sigVal := req.Form.Get("sig")
|
|
||||||
timestamp := req.Form.Get("ts")
|
|
||||||
if !validSignature(redirectURI, sigVal, timestamp, sharedKey) {
|
|
||||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f(rw, req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateSignature ensures the validity of the redirect url by comparing the hmac
|
|
||||||
// digest, and ensuring that the included timestamp is fresh
|
|
||||||
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
|
||||||
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
_, err := url.Parse(redirectURI)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
requestSig, err := base64.URLEncoding.DecodeString(sigVal)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
i, err := strconv.ParseInt(timestamp, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
tm := time.Unix(i, 0)
|
|
||||||
if time.Now().Sub(tm) > defaultSignatureValidityDuration {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
localSig := redirectURLSignature(redirectURI, tm, secret)
|
|
||||||
return hmac.Equal(requestSig, localSig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// redirectURLSignature generates a hmac digest from a
|
|
||||||
// redirect url, a timestamp, and a secret.
|
|
||||||
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
|
|
||||||
h := hmac.New(sha256.New, []byte(secret))
|
|
||||||
h.Write([]byte(rawRedirect))
|
|
||||||
h.Write([]byte(fmt.Sprint(timestamp.Unix())))
|
|
||||||
return h.Sum(nil)
|
|
||||||
}
|
|
|
@ -1,85 +0,0 @@
|
||||||
package authenticate
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_validRedirectURI(t *testing.T) {
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
uri string
|
|
||||||
rootDomains []string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"good url redirect", "https://example.com/redirect", []string{"example.com"}, true},
|
|
||||||
{"bad domain", "https://example.com/redirect", []string{"notexample.com"}, false},
|
|
||||||
{"malformed url", "^example.com/redirect", []string{"notexample.com"}, false},
|
|
||||||
{"empty domain list", "https://example.com/redirect", []string{}, false},
|
|
||||||
{"empty domain", "https://example.com/redirect", []string{""}, false},
|
|
||||||
{"empty url", "", []string{"example.com"}, false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := validRedirectURI(tt.uri, tt.rootDomains); got != tt.want {
|
|
||||||
t.Errorf("validRedirectURI() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_validSignature(t *testing.T) {
|
|
||||||
goodURL := "https://example.com/redirect"
|
|
||||||
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
|
|
||||||
now := fmt.Sprint(time.Now().Unix())
|
|
||||||
rawSig := redirectURLSignature(goodURL, time.Now(), secretA)
|
|
||||||
sig := base64.URLEncoding.EncodeToString(rawSig)
|
|
||||||
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
redirectURI string
|
|
||||||
sigVal string
|
|
||||||
timestamp string
|
|
||||||
secret string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"good signature", goodURL, string(sig), now, secretA, true},
|
|
||||||
{"empty redirect url", "", string(sig), now, secretA, false},
|
|
||||||
{"bad redirect url", "https://google.com^", string(sig), now, secretA, false},
|
|
||||||
{"malformed signature", goodURL, string(sig + "^"), now, "&*&@**($&#(", false},
|
|
||||||
{"malformed timestamp", goodURL, string(sig), now + "^", secretA, false},
|
|
||||||
{"stale timestamp", goodURL, string(sig), staleTime, secretA, false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := validSignature(tt.redirectURI, tt.sigVal, tt.timestamp, tt.secret); got != tt.want {
|
|
||||||
t.Errorf("validSignature() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_redirectURLSignature(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
rawRedirect string
|
|
||||||
timestamp time.Time
|
|
||||||
secret string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"good signature", "https://example.com/redirect", time.Unix(1546797901, 0), "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A=", "GIDyWKjrG_7MwXwIq1o51f2pDT_rH9aLHdsHxSBEwy8="},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := redirectURLSignature(tt.rawRedirect, tt.timestamp, tt.secret)
|
|
||||||
out := base64.URLEncoding.EncodeToString(got)
|
|
||||||
if out != tt.want {
|
|
||||||
t.Errorf("redirectURLSignature() = %v, want %v", tt.want, out)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -258,7 +258,7 @@ func (p *ProviderData) RefreshAccessToken(refreshToken string) (string, time.Dur
|
||||||
log.Error().Err(err).Msg("authenticate/providers.RefreshAccessToken")
|
log.Error().Err(err).Msg("authenticate/providers.RefreshAccessToken")
|
||||||
return "", 0, err
|
return "", 0, err
|
||||||
}
|
}
|
||||||
return newToken.AccessToken, newToken.Expiry.Sub(time.Now()), nil
|
return newToken.AccessToken, time.Until(newToken.Expiry), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Revoke enables a user to revoke her token. If the identity provider supports revocation
|
// Revoke enables a user to revoke her token. If the identity provider supports revocation
|
||||||
|
|
|
@ -75,8 +75,3 @@ func (tp *TestProvider) Redeem(code string) (*sessions.SessionState, error) {
|
||||||
return tp.Session, tp.RedeemError
|
return tp.Session, tp.RedeemError
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop fulfills the Provider interface
|
|
||||||
func (tp *TestProvider) Stop() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
|
@ -7,10 +7,10 @@
|
||||||
|
|
||||||
# Certificates can be loaded as files or base64 encoded bytes. If neither is set, a
|
# Certificates can be loaded as files or base64 encoded bytes. If neither is set, a
|
||||||
# pomerium will attempt to locate a pair in the root directory
|
# pomerium will attempt to locate a pair in the root directory
|
||||||
export CERTIFICATE="xxxxxx" # base64 encoded cert, eg. `base64 -i cert.pem`
|
|
||||||
export CERTIFICATE_KEY="xxxx" # base64 encoded key, eg. `base64 -i privkey.pem`
|
|
||||||
export CERTIFICATE_FILE="./cert.pem" # optional, defaults to `./cert.pem`
|
export CERTIFICATE_FILE="./cert.pem" # optional, defaults to `./cert.pem`
|
||||||
export CERTIFICATE_KEY_FILE="./privkey.pem" # optional, defaults to `./certprivkey.pem`
|
export CERTIFICATE_KEY_FILE="./privkey.pem" # optional, defaults to `./certprivkey.pem`
|
||||||
|
# export CERTIFICATE="xxxxxx" # base64 encoded cert, eg. `base64 -i cert.pem`
|
||||||
|
# export CERTIFICATE_KEY="xxxx" # base64 encoded key, eg. `base64 -i privkey.pem`
|
||||||
|
|
||||||
# The URL that the identity provider will call back after authenticating the user
|
# The URL that the identity provider will call back after authenticating the user
|
||||||
export REDIRECT_URL="https://sso-auth.corp.example.com/oauth2/callback"
|
export REDIRECT_URL="https://sso-auth.corp.example.com/oauth2/callback"
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -9,7 +9,6 @@ require (
|
||||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||||
github.com/rs/zerolog v1.11.0
|
github.com/rs/zerolog v1.11.0
|
||||||
github.com/stretchr/testify v1.2.2 // indirect
|
github.com/stretchr/testify v1.2.2 // indirect
|
||||||
github.com/zenazn/goji v0.9.0
|
|
||||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9
|
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9
|
||||||
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 // indirect
|
golang.org/x/net v0.0.0-20181220203305-927f97764cc3 // indirect
|
||||||
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890
|
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -16,8 +16,6 @@ github.com/rs/zerolog v1.11.0 h1:DRuq/S+4k52uJzBQciUcofXx45GrMC6yrEbb/CoK6+M=
|
||||||
github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
github.com/rs/zerolog v1.11.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
||||||
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
|
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
|
||||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
github.com/zenazn/goji v0.9.0 h1:RSQQAbXGArQ0dIDEq+PI6WqN6if+5KHu6x2Cx/GXLTQ=
|
|
||||||
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
|
|
||||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0=
|
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0=
|
||||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
|
|
|
@ -14,7 +14,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrTokenRevoked signifies a token revokation or expiration error
|
// ErrTokenRevoked signifies a token revokation or expiration error
|
||||||
var ErrTokenRevoked = errors.New("Token expired or revoked")
|
var ErrTokenRevoked = errors.New("token expired or revoked")
|
||||||
|
|
||||||
var httpClient = &http.Client{
|
var httpClient = &http.Client{
|
||||||
Timeout: time.Second * 5,
|
Timeout: time.Second * 5,
|
||||||
|
|
|
@ -3,6 +3,7 @@ package log // import "github.com/pomerium/pomerium/internal/log"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
@ -102,3 +103,9 @@ func Printf(format string, v ...interface{}) {
|
||||||
func Ctx(ctx context.Context) *zerolog.Logger {
|
func Ctx(ctx context.Context) *zerolog.Logger {
|
||||||
return zerolog.Ctx(ctx)
|
return zerolog.Ctx(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FromRequest gets the logger in the request's context.
|
||||||
|
// This is a shortcut for log.Ctx(r.Context())
|
||||||
|
func FromRequest(r *http.Request) *zerolog.Logger {
|
||||||
|
return Ctx(r.Context())
|
||||||
|
}
|
||||||
|
|
2
internal/middleware/doc.go
Normal file
2
internal/middleware/doc.go
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
// Package middleware provides a standard set of middleware implementations for pomerium.
|
||||||
|
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
|
@ -1,5 +1,4 @@
|
||||||
// Package log provides a set of http.Handler helpers for zerolog.
|
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||||
package log // import "github.com/pomerium/pomerium/internal/log"
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -10,14 +9,14 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/zenazn/goji/web/mutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// FromRequest gets the logger in the request's context.
|
// FromRequest gets the logger in the request's context.
|
||||||
// This is a shortcut for log.Ctx(r.Context())
|
// This is a shortcut for log.Ctx(r.Context())
|
||||||
func FromRequest(r *http.Request) *zerolog.Logger {
|
func FromRequest(r *http.Request) *zerolog.Logger {
|
||||||
return Ctx(r.Context())
|
return log.Ctx(r.Context())
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler injects log into requests context.
|
// NewHandler injects log into requests context.
|
||||||
|
@ -172,7 +171,7 @@ func AccessHandler(f func(r *http.Request, status, size int, duration time.Durat
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
lw := mutil.WrapWriter(w)
|
lw := NewWrapResponseWriter(w, 2)
|
||||||
next.ServeHTTP(lw, r)
|
next.ServeHTTP(lw, r)
|
||||||
f(r, lw.Status(), lw.BytesWritten(), time.Since(start))
|
f(r, lw.Status(), lw.BytesWritten(), time.Since(start))
|
||||||
})
|
})
|
|
@ -1,5 +1,4 @@
|
||||||
// Package log provides a set of http.Handler helpers for zerolog.
|
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||||
package log // import "github.com/pomerium/pomerium/internal/log"
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
@ -258,3 +257,21 @@ func BenchmarkDataRace(b *testing.B) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestForwardedAddrHandler(t *testing.T) {
|
||||||
|
out := &bytes.Buffer{}
|
||||||
|
r := &http.Request{
|
||||||
|
Header: http.Header{
|
||||||
|
"X-Forwarded-For": []string{"client", "proxy1", "proxy2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := ForwardedAddrHandler("fwd_ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
l := FromRequest(r)
|
||||||
|
l.Log().Msg("")
|
||||||
|
}))
|
||||||
|
h = NewHandler(zerolog.New(out))(h)
|
||||||
|
h.ServeHTTP(nil, r)
|
||||||
|
if want, got := `{"fwd_ip":"client"}`+"\n", decodeIfBinary(out); want != got {
|
||||||
|
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,3 @@
|
||||||
// Package middleware provides a standard set of middleware implementations for pomerium.
|
|
||||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -15,91 +14,76 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetHeadersOld ensures that every response includes some basic security headers
|
|
||||||
func SetHeadersOld(h http.Handler, securityHeaders map[string]string) http.Handler {
|
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
for key, val := range securityHeaders {
|
|
||||||
rw.Header().Set(key, val)
|
|
||||||
}
|
|
||||||
h.ServeHTTP(rw, req)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetHeaders ensures that every response includes some basic security headers
|
// SetHeaders ensures that every response includes some basic security headers
|
||||||
func SetHeaders(securityHeaders map[string]string) func(next http.Handler) http.Handler {
|
func SetHeaders(securityHeaders map[string]string) func(next http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
for key, val := range securityHeaders {
|
for key, val := range securityHeaders {
|
||||||
rw.Header().Set(key, val)
|
w.Header().Set(key, val)
|
||||||
}
|
}
|
||||||
next.ServeHTTP(rw, req)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithMethods writes an error response if the method of the request is not included.
|
|
||||||
func WithMethods(f http.HandlerFunc, methods ...string) http.HandlerFunc {
|
|
||||||
methodMap := make(map[string]struct{}, len(methods))
|
|
||||||
for _, m := range methods {
|
|
||||||
methodMap[m] = struct{}{}
|
|
||||||
}
|
|
||||||
return func(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
if _, ok := methodMap[req.Method]; !ok {
|
|
||||||
httputil.ErrorResponse(rw, req, fmt.Sprintf("method %s not allowed", req.Method), http.StatusMethodNotAllowed)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
f(rw, req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateClientSecret checks the request header for the client secret and returns
|
// ValidateClientSecret checks the request header for the client secret and returns
|
||||||
// an error if it does not match the proxy client secret
|
// an error if it does not match the proxy client secret
|
||||||
func ValidateClientSecret(f http.HandlerFunc, sharedSecret string) http.HandlerFunc {
|
func ValidateClientSecret(sharedSecret string) func(next http.Handler) http.Handler {
|
||||||
return func(rw http.ResponseWriter, req *http.Request) {
|
return func(next http.Handler) http.Handler {
|
||||||
err := req.ParseForm()
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err != nil {
|
err := r.ParseForm()
|
||||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
|
if err != nil {
|
||||||
return
|
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
|
||||||
}
|
return
|
||||||
clientSecret := req.Form.Get("shared_secret")
|
}
|
||||||
// check the request header for the client secret
|
clientSecret := r.Form.Get("shared_secret")
|
||||||
if clientSecret == "" {
|
// check the request header for the client secret
|
||||||
clientSecret = req.Header.Get("X-Client-Secret")
|
if clientSecret == "" {
|
||||||
}
|
clientSecret = r.Header.Get("X-Client-Secret")
|
||||||
|
}
|
||||||
|
|
||||||
if clientSecret != sharedSecret {
|
if clientSecret != sharedSecret {
|
||||||
httputil.ErrorResponse(rw, req, "Invalid client secret", http.StatusUnauthorized)
|
httputil.ErrorResponse(w, r, "Invalid client secret", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
f(rw, req)
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateRedirectURI checks the redirect uri in the query parameters and ensures that
|
// ValidateRedirectURI checks the redirect uri in the query parameters and ensures that
|
||||||
// the url's domain is one in the list of proxy root domains.
|
// the its domain is in the list of proxy root domains.
|
||||||
func ValidateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
|
func ValidateRedirectURI(proxyRootDomains []string) func(next http.Handler) http.Handler {
|
||||||
return func(rw http.ResponseWriter, req *http.Request) {
|
return func(next http.Handler) http.Handler {
|
||||||
err := req.ParseForm()
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err != nil {
|
err := r.ParseForm()
|
||||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
if err != nil {
|
||||||
return
|
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
|
||||||
}
|
return
|
||||||
redirectURI := req.Form.Get("redirect_uri")
|
}
|
||||||
if !validRedirectURI(redirectURI, proxyRootDomains) {
|
redirectURI := r.Form.Get("redirect_uri")
|
||||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
if !ValidRedirectURI(redirectURI, proxyRootDomains) {
|
||||||
return
|
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
}
|
return
|
||||||
|
}
|
||||||
f(rw, req)
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func validRedirectURI(uri string, rootDomains []string) bool {
|
// ValidRedirectURI checks if a URL's domain is one in the list of proxy root domains.
|
||||||
|
func ValidRedirectURI(uri string, rootDomains []string) bool {
|
||||||
|
if uri == "" || len(rootDomains) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
redirectURL, err := url.Parse(uri)
|
redirectURL, err := url.Parse(uri)
|
||||||
if uri == "" || err != nil || redirectURL.Host == "" {
|
if err != nil || redirectURL.Host == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
for _, domain := range rootDomains {
|
for _, domain := range rootDomains {
|
||||||
|
if domain == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
if strings.HasSuffix(redirectURL.Hostname(), domain) {
|
if strings.HasSuffix(redirectURL.Hostname(), domain) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -109,35 +93,36 @@ func validRedirectURI(uri string, rootDomains []string) bool {
|
||||||
|
|
||||||
// ValidateSignature ensures the request is valid and has been signed with
|
// ValidateSignature ensures the request is valid and has been signed with
|
||||||
// the correspdoning client secret key
|
// the correspdoning client secret key
|
||||||
func ValidateSignature(f http.HandlerFunc, sharedSecret string) http.HandlerFunc {
|
func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler {
|
||||||
return func(rw http.ResponseWriter, req *http.Request) {
|
return func(next http.Handler) http.Handler {
|
||||||
err := req.ParseForm()
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err != nil {
|
err := r.ParseForm()
|
||||||
httputil.ErrorResponse(rw, req, err.Error(), http.StatusBadRequest)
|
if err != nil {
|
||||||
return
|
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
|
||||||
}
|
return
|
||||||
redirectURI := req.Form.Get("redirect_uri")
|
}
|
||||||
sigVal := req.Form.Get("sig")
|
redirectURI := r.Form.Get("redirect_uri")
|
||||||
timestamp := req.Form.Get("ts")
|
sigVal := r.Form.Get("sig")
|
||||||
if !validSignature(redirectURI, sigVal, timestamp, sharedSecret) {
|
timestamp := r.Form.Get("ts")
|
||||||
httputil.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
|
if !ValidSignature(redirectURI, sigVal, timestamp, sharedSecret) {
|
||||||
return
|
httputil.ErrorResponse(w, r, "Cross service signature failed to validate", http.StatusUnauthorized)
|
||||||
}
|
return
|
||||||
|
}
|
||||||
|
|
||||||
f(rw, req)
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateHost ensures that each request's host is valid
|
// ValidateHost ensures that each request's host is valid
|
||||||
func ValidateHost(mux map[string]*http.Handler) func(next http.Handler) http.Handler {
|
func ValidateHost(mux map[string]http.Handler) func(next http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
if _, ok := mux[r.Host]; !ok {
|
||||||
if _, ok := mux[req.Host]; !ok {
|
httputil.ErrorResponse(w, r, "Unknown host to route", http.StatusNotFound)
|
||||||
httputil.ErrorResponse(rw, req, "Unknown host to route", http.StatusNotFound)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
next.ServeHTTP(rw, req)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -145,25 +130,47 @@ func ValidateHost(mux map[string]*http.Handler) func(next http.Handler) http.Han
|
||||||
// RequireHTTPS reroutes a HTTP request to HTTPS
|
// RequireHTTPS reroutes a HTTP request to HTTPS
|
||||||
// todo(bdd) : this is unreliable unless behind another reverser proxy
|
// todo(bdd) : this is unreliable unless behind another reverser proxy
|
||||||
// todo(bdd) : header age seems extreme
|
// todo(bdd) : header age seems extreme
|
||||||
func RequireHTTPS(h http.Handler) http.Handler {
|
func RequireHTTPS(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
rw.Header().Set("Strict-Transport-Security", "max-age=31536000")
|
w.Header().Set("Strict-Transport-Security", "max-age=31536000")
|
||||||
// todo(bdd) : scheme and x-forwarded-proto cannot be trusted if not behind another load balancer
|
// todo(bdd) : scheme and x-forwarded-proto cannot be trusted if not behind another load balancer
|
||||||
if (req.URL.Scheme == "http" && req.Header.Get("X-Forwarded-Proto") == "http") || &req.TLS == nil {
|
if (r.URL.Scheme == "http" && r.Header.Get("X-Forwarded-Proto") == "http") || &r.TLS == nil {
|
||||||
dest := &url.URL{
|
dest := &url.URL{
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
Host: req.Host,
|
Host: r.Host,
|
||||||
Path: req.URL.Path,
|
Path: r.URL.Path,
|
||||||
RawQuery: req.URL.RawQuery,
|
RawQuery: r.URL.RawQuery,
|
||||||
}
|
}
|
||||||
http.Redirect(rw, req, dest.String(), http.StatusMovedPermanently)
|
http.Redirect(w, r, dest.String(), http.StatusMovedPermanently)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.ServeHTTP(rw, req)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
// Healthcheck endpoint middleware useful to setting up a path like
|
||||||
|
// `/ping` that load balancers or uptime testing external services
|
||||||
|
// can make a request before hitting any routes. It's also convenient
|
||||||
|
// to place this above ACL middlewares as well.
|
||||||
|
func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler {
|
||||||
|
f := func(next http.Handler) http.Handler {
|
||||||
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method == "GET" && strings.EqualFold(r.URL.Path, endpoint) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(msg))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
return http.HandlerFunc(fn)
|
||||||
|
}
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidSignature checks to see if a signature is valid. Compares hmac of
|
||||||
|
// redirect uri, timestamp, and secret and signature.
|
||||||
|
func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
||||||
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
|
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
324
internal/middleware/middleware_test.go
Normal file
324
internal/middleware/middleware_test.go
Normal file
|
@ -0,0 +1,324 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_ValidRedirectURI(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
uri string
|
||||||
|
rootDomains []string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"good url redirect", "https://example.com/redirect", []string{"example.com"}, true},
|
||||||
|
{"bad domain", "https://example.com/redirect", []string{"notexample.com"}, false},
|
||||||
|
{"malformed url", "^example.com/redirect", []string{"notexample.com"}, false},
|
||||||
|
{"empty domain list", "https://example.com/redirect", []string{}, false},
|
||||||
|
{"empty domain", "https://example.com/redirect", []string{""}, false},
|
||||||
|
{"empty url", "", []string{"example.com"}, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := ValidRedirectURI(tt.uri, tt.rootDomains); got != tt.want {
|
||||||
|
t.Errorf("ValidRedirectURI() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_ValidSignature(t *testing.T) {
|
||||||
|
goodURL := "https://example.com/redirect"
|
||||||
|
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
|
||||||
|
now := fmt.Sprint(time.Now().Unix())
|
||||||
|
rawSig := redirectURLSignature(goodURL, time.Now(), secretA)
|
||||||
|
sig := base64.URLEncoding.EncodeToString(rawSig)
|
||||||
|
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
redirectURI string
|
||||||
|
sigVal string
|
||||||
|
timestamp string
|
||||||
|
secret string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"good signature", goodURL, string(sig), now, secretA, true},
|
||||||
|
{"empty redirect url", "", string(sig), now, secretA, false},
|
||||||
|
{"bad redirect url", "https://google.com^", string(sig), now, secretA, false},
|
||||||
|
{"malformed signature", goodURL, string(sig + "^"), now, "&*&@**($&#(", false},
|
||||||
|
{"malformed timestamp", goodURL, string(sig), now + "^", secretA, false},
|
||||||
|
{"stale timestamp", goodURL, string(sig), staleTime, secretA, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := ValidSignature(tt.redirectURI, tt.sigVal, tt.timestamp, tt.secret); got != tt.want {
|
||||||
|
t.Errorf("ValidSignature() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_redirectURLSignature(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rawRedirect string
|
||||||
|
timestamp time.Time
|
||||||
|
secret string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"good signature", "https://example.com/redirect", time.Unix(1546797901, 0), "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A=", "GIDyWKjrG_7MwXwIq1o51f2pDT_rH9aLHdsHxSBEwy8="},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := redirectURLSignature(tt.rawRedirect, tt.timestamp, tt.secret)
|
||||||
|
out := base64.URLEncoding.EncodeToString(got)
|
||||||
|
if out != tt.want {
|
||||||
|
t.Errorf("redirectURLSignature() = %v, want %v", tt.want, out)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetHeaders(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
securityHeaders map[string]string
|
||||||
|
}{
|
||||||
|
{"one option", map[string]string{"X-Frame-Options": "DENY"}},
|
||||||
|
{"two options", map[string]string{"X-Frame-Options": "DENY", "A": "B"}},
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
for k, want := range tt.securityHeaders {
|
||||||
|
if got := w.Header().Get(k); want != got {
|
||||||
|
t.Errorf("want %s got %q", want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
})
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler := SetHeaders(tt.securityHeaders)(testHandler)
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRedirectURI(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
proxyRootDomains []string
|
||||||
|
redirectURI string
|
||||||
|
status int
|
||||||
|
}{
|
||||||
|
{"simple", []string{"google.com"}, "https://google.com", http.StatusOK},
|
||||||
|
{"bad match", []string{"aol.com"}, "https://google.com", http.StatusBadRequest},
|
||||||
|
{"with cname", []string{"google.com"}, "https://www.google.com", http.StatusOK},
|
||||||
|
{"with path", []string{"google.com"}, "https://www.google.com/path", http.StatusOK},
|
||||||
|
{"http", []string{"google.com"}, "http://www.google.com/path", http.StatusOK},
|
||||||
|
{"malformed, invalid hex digits", []string{"google.com"}, "%zzzzz", http.StatusBadRequest},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := &http.Request{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
URL: &url.URL{RawQuery: fmt.Sprintf("redirect_uri=%s", tt.redirectURI)},
|
||||||
|
}
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("Hi"))
|
||||||
|
})
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler := ValidateRedirectURI(tt.proxyRootDomains)(testHandler)
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != tt.status {
|
||||||
|
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
|
||||||
|
t.Errorf("%s", rr.Body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateClientSecret(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sharedSecret string
|
||||||
|
clientGetValue string
|
||||||
|
clientHeaderValue string
|
||||||
|
status int
|
||||||
|
}{
|
||||||
|
{"simple", "secret", "secret", "secret", http.StatusOK},
|
||||||
|
{"missing get param, valid header", "secret", "", "secret", http.StatusOK},
|
||||||
|
{"missing both", "secret", "", "", http.StatusUnauthorized},
|
||||||
|
{"simple bad", "bad-secret", "secret", "", http.StatusUnauthorized},
|
||||||
|
{"malformed, invalid hex digits", "secret", "%zzzzz", "", http.StatusBadRequest},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := &http.Request{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Header: http.Header{"X-Client-Secret": []string{tt.clientHeaderValue}},
|
||||||
|
URL: &url.URL{RawQuery: fmt.Sprintf("shared_secret=%s", tt.clientGetValue)},
|
||||||
|
}
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("Hi"))
|
||||||
|
})
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler := ValidateClientSecret(tt.sharedSecret)(testHandler)
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != tt.status {
|
||||||
|
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
|
||||||
|
t.Errorf("%s", rr.Body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSignature(t *testing.T) {
|
||||||
|
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
|
||||||
|
now := fmt.Sprint(time.Now().Unix())
|
||||||
|
goodURL := "https://example.com/redirect"
|
||||||
|
rawSig := redirectURLSignature(goodURL, time.Now(), secretA)
|
||||||
|
sig := base64.URLEncoding.EncodeToString(rawSig)
|
||||||
|
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sharedSecret string
|
||||||
|
redirectURI string
|
||||||
|
sig string
|
||||||
|
ts string
|
||||||
|
status int
|
||||||
|
}{
|
||||||
|
{"valid signature", secretA, goodURL, sig, now, http.StatusOK},
|
||||||
|
{"stale signature", secretA, goodURL, sig, staleTime, http.StatusUnauthorized},
|
||||||
|
{"malformed", secretA, goodURL, "%zzzzz", now, http.StatusBadRequest},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
v := url.Values{}
|
||||||
|
v.Set("redirect_uri", tt.redirectURI)
|
||||||
|
v.Set("ts", tt.ts)
|
||||||
|
v.Set("sig", tt.sig)
|
||||||
|
|
||||||
|
req := &http.Request{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
URL: &url.URL{RawQuery: v.Encode()}}
|
||||||
|
if tt.name == "malformed" {
|
||||||
|
req.URL.RawQuery = "sig=%zzzzz"
|
||||||
|
}
|
||||||
|
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("Hi"))
|
||||||
|
})
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler := ValidateSignature(tt.sharedSecret)(testHandler)
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != tt.status {
|
||||||
|
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
|
||||||
|
t.Errorf("%s", rr.Body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthCheck(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
clientPath string
|
||||||
|
expected []byte
|
||||||
|
}{
|
||||||
|
{"good", http.MethodGet, "/ping", []byte("OK")},
|
||||||
|
//tood(bdd): miss?
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, tt.clientPath, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("Hi"))
|
||||||
|
})
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler := Healthcheck(tt.clientPath, string(tt.expected))(testHandler)
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
if rr.Body.String() != string(tt.expected) {
|
||||||
|
t.Errorf("body differs. got %ss want %ss", rr.Body, tt.expected)
|
||||||
|
t.Errorf("%s", rr.Body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redirect to a fixed URL
|
||||||
|
type handlerHelper struct {
|
||||||
|
msg string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rh *handlerHelper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(rh.msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func handlerHelp(msg string) http.Handler {
|
||||||
|
return &handlerHelper{msg}
|
||||||
|
}
|
||||||
|
func TestValidateHost(t *testing.T) {
|
||||||
|
m := make(map[string]http.Handler)
|
||||||
|
m["google.com"] = handlerHelp("google")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
validHosts map[string]http.Handler
|
||||||
|
clientPath string
|
||||||
|
expected []byte
|
||||||
|
status int
|
||||||
|
}{
|
||||||
|
{"good", m, "google.com", []byte("google"), 200},
|
||||||
|
{"no route", m, "googles.com", []byte("google"), 404},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, tt.clientPath, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
var testHandler http.Handler
|
||||||
|
if tt.validHosts[tt.clientPath] != nil {
|
||||||
|
tt.validHosts[tt.clientPath].ServeHTTP(rr, req)
|
||||||
|
testHandler = tt.validHosts[tt.clientPath]
|
||||||
|
} else {
|
||||||
|
testHandler = handlerHelp("ok")
|
||||||
|
}
|
||||||
|
handler := ValidateHost(tt.validHosts)(testHandler)
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != tt.status {
|
||||||
|
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
|
||||||
|
t.Errorf("%s", rr.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
183
internal/middleware/wrap_writer.go
Normal file
183
internal/middleware/wrap_writer.go
Normal file
|
@ -0,0 +1,183 @@
|
||||||
|
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
|
||||||
|
// The original work was derived from Goji's middleware, source:
|
||||||
|
// https://github.com/zenazn/goji/tree/master/web/middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWrapResponseWriter wraps an http.ResponseWriter, returning a proxy that allows you to
|
||||||
|
// hook into various parts of the response process.
|
||||||
|
func NewWrapResponseWriter(w http.ResponseWriter, protoMajor int) WrapResponseWriter {
|
||||||
|
_, fl := w.(http.Flusher)
|
||||||
|
|
||||||
|
bw := basicWriter{ResponseWriter: w}
|
||||||
|
|
||||||
|
if protoMajor == 2 {
|
||||||
|
_, ps := w.(http.Pusher)
|
||||||
|
if fl && ps {
|
||||||
|
return &http2FancyWriter{bw}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
_, hj := w.(http.Hijacker)
|
||||||
|
_, rf := w.(io.ReaderFrom)
|
||||||
|
if fl && hj && rf {
|
||||||
|
return &httpFancyWriter{bw}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if fl {
|
||||||
|
return &flushWriter{bw}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &bw
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrapResponseWriter is a proxy around an http.ResponseWriter that allows you to hook
|
||||||
|
// into various parts of the response process.
|
||||||
|
type WrapResponseWriter interface {
|
||||||
|
http.ResponseWriter
|
||||||
|
// Status returns the HTTP status of the request, or 0 if one has not
|
||||||
|
// yet been sent.
|
||||||
|
Status() int
|
||||||
|
// BytesWritten returns the total number of bytes sent to the client.
|
||||||
|
BytesWritten() int
|
||||||
|
// Tee causes the response body to be written to the given io.Writer in
|
||||||
|
// addition to proxying the writes through. Only one io.Writer can be
|
||||||
|
// tee'd to at once: setting a second one will overwrite the first.
|
||||||
|
// Writes will be sent to the proxy before being written to this
|
||||||
|
// io.Writer. It is illegal for the tee'd writer to be modified
|
||||||
|
// concurrently with writes.
|
||||||
|
Tee(io.Writer)
|
||||||
|
// Unwrap returns the original proxied target.
|
||||||
|
Unwrap() http.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
// basicWriter wraps a http.ResponseWriter that implements the minimal
|
||||||
|
// http.ResponseWriter interface.
|
||||||
|
type basicWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
wroteHeader bool
|
||||||
|
code int
|
||||||
|
bytes int
|
||||||
|
tee io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *basicWriter) WriteHeader(code int) {
|
||||||
|
if !b.wroteHeader {
|
||||||
|
b.code = code
|
||||||
|
b.wroteHeader = true
|
||||||
|
b.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *basicWriter) Write(buf []byte) (int, error) {
|
||||||
|
b.WriteHeader(http.StatusOK)
|
||||||
|
n, err := b.ResponseWriter.Write(buf)
|
||||||
|
if b.tee != nil {
|
||||||
|
_, err2 := b.tee.Write(buf[:n])
|
||||||
|
// Prefer errors generated by the proxied writer.
|
||||||
|
if err == nil {
|
||||||
|
err = err2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.bytes += n
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *basicWriter) maybeWriteHeader() {
|
||||||
|
if !b.wroteHeader {
|
||||||
|
b.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *basicWriter) Status() int {
|
||||||
|
return b.code
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *basicWriter) BytesWritten() int {
|
||||||
|
return b.bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *basicWriter) Tee(w io.Writer) {
|
||||||
|
b.tee = w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *basicWriter) Unwrap() http.ResponseWriter {
|
||||||
|
return b.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
type flushWriter struct {
|
||||||
|
basicWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *flushWriter) Flush() {
|
||||||
|
f.wroteHeader = true
|
||||||
|
|
||||||
|
fl := f.basicWriter.ResponseWriter.(http.Flusher)
|
||||||
|
fl.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ http.Flusher = &flushWriter{}
|
||||||
|
|
||||||
|
// httpFancyWriter is a HTTP writer that additionally satisfies
|
||||||
|
// http.Flusher, http.Hijacker, and io.ReaderFrom. It exists for the common case
|
||||||
|
// of wrapping the http.ResponseWriter that package http gives you, in order to
|
||||||
|
// make the proxied object support the full method set of the proxied object.
|
||||||
|
type httpFancyWriter struct {
|
||||||
|
basicWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *httpFancyWriter) Flush() {
|
||||||
|
f.wroteHeader = true
|
||||||
|
|
||||||
|
fl := f.basicWriter.ResponseWriter.(http.Flusher)
|
||||||
|
fl.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *httpFancyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
hj := f.basicWriter.ResponseWriter.(http.Hijacker)
|
||||||
|
return hj.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *http2FancyWriter) Push(target string, opts *http.PushOptions) error {
|
||||||
|
return f.basicWriter.ResponseWriter.(http.Pusher).Push(target, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *httpFancyWriter) ReadFrom(r io.Reader) (int64, error) {
|
||||||
|
if f.basicWriter.tee != nil {
|
||||||
|
n, err := io.Copy(&f.basicWriter, r)
|
||||||
|
f.basicWriter.bytes += int(n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
rf := f.basicWriter.ResponseWriter.(io.ReaderFrom)
|
||||||
|
f.basicWriter.maybeWriteHeader()
|
||||||
|
n, err := rf.ReadFrom(r)
|
||||||
|
f.basicWriter.bytes += int(n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ http.Flusher = &httpFancyWriter{}
|
||||||
|
var _ http.Hijacker = &httpFancyWriter{}
|
||||||
|
var _ http.Pusher = &http2FancyWriter{}
|
||||||
|
var _ io.ReaderFrom = &httpFancyWriter{}
|
||||||
|
|
||||||
|
// http2FancyWriter is a HTTP2 writer that additionally satisfies
|
||||||
|
// http.Flusher, and io.ReaderFrom. It exists for the common case
|
||||||
|
// of wrapping the http.ResponseWriter that package http gives you, in order to
|
||||||
|
// make the proxied object support the full method set of the proxied object.
|
||||||
|
type http2FancyWriter struct {
|
||||||
|
basicWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *http2FancyWriter) Flush() {
|
||||||
|
f.wroteHeader = true
|
||||||
|
|
||||||
|
fl := f.basicWriter.ResponseWriter.(http.Flusher)
|
||||||
|
fl.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ http.Flusher = &http2FancyWriter{}
|
33
internal/middleware/wrap_writer_test.go
Normal file
33
internal/middleware/wrap_writer_test.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFlushWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
|
||||||
|
f := &flushWriter{basicWriter{ResponseWriter: httptest.NewRecorder()}}
|
||||||
|
f.Flush()
|
||||||
|
|
||||||
|
if !f.wroteHeader {
|
||||||
|
t.Fatal("want Flush to have set wroteHeader=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHttpFancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
|
||||||
|
f := &httpFancyWriter{basicWriter{ResponseWriter: httptest.NewRecorder()}}
|
||||||
|
f.Flush()
|
||||||
|
|
||||||
|
if !f.wroteHeader {
|
||||||
|
t.Fatal("want Flush to have set wroteHeader=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
|
||||||
|
f := &http2FancyWriter{basicWriter{ResponseWriter: httptest.NewRecorder()}}
|
||||||
|
f.Flush()
|
||||||
|
|
||||||
|
if !f.wroteHeader {
|
||||||
|
t.Fatal("want Flush to have set wroteHeader=true")
|
||||||
|
}
|
||||||
|
}
|
|
@ -46,9 +46,9 @@ func (p *Proxy) Handler() http.Handler {
|
||||||
|
|
||||||
// Middleware chain
|
// Middleware chain
|
||||||
c := middleware.NewChain()
|
c := middleware.NewChain()
|
||||||
c = c.Append(log.NewHandler(log.Logger))
|
c = c.Append(middleware.NewHandler(log.Logger))
|
||||||
c = c.Append(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
c = c.Append(middleware.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
||||||
log.FromRequest(r).Info().
|
middleware.FromRequest(r).Info().
|
||||||
Str("method", r.Method).
|
Str("method", r.Method).
|
||||||
Str("url", r.URL.String()).
|
Str("url", r.URL.String()).
|
||||||
Int("status", status).
|
Int("status", status).
|
||||||
|
@ -60,11 +60,11 @@ func (p *Proxy) Handler() http.Handler {
|
||||||
}))
|
}))
|
||||||
c = c.Append(middleware.SetHeaders(securityHeaders))
|
c = c.Append(middleware.SetHeaders(securityHeaders))
|
||||||
c = c.Append(middleware.RequireHTTPS)
|
c = c.Append(middleware.RequireHTTPS)
|
||||||
c = c.Append(log.ForwardedAddrHandler("fwd_ip"))
|
c = c.Append(middleware.ForwardedAddrHandler("fwd_ip"))
|
||||||
c = c.Append(log.RemoteAddrHandler("ip"))
|
c = c.Append(middleware.RemoteAddrHandler("ip"))
|
||||||
c = c.Append(log.UserAgentHandler("user_agent"))
|
c = c.Append(middleware.UserAgentHandler("user_agent"))
|
||||||
c = c.Append(log.RefererHandler("referer"))
|
c = c.Append(middleware.RefererHandler("referer"))
|
||||||
c = c.Append(log.RequestIDHandler("req_id", "Request-Id"))
|
c = c.Append(middleware.RequestIDHandler("req_id", "Request-Id"))
|
||||||
c = c.Append(middleware.ValidateHost(p.mux))
|
c = c.Append(middleware.ValidateHost(p.mux))
|
||||||
h := c.Then(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
h := c.Then(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Skip host validation for /ping requests because they hit the LB directly.
|
// Skip host validation for /ping requests because they hit the LB directly.
|
||||||
|
@ -260,8 +260,7 @@ func (p *Proxy) AuthenticateOnly(w http.ResponseWriter, r *http.Request) {
|
||||||
// or starting the authentication process if not.
|
// or starting the authentication process if not.
|
||||||
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||||
// Attempts to validate the user and their cookie.
|
// Attempts to validate the user and their cookie.
|
||||||
var err error
|
err := p.Authenticate(w, r)
|
||||||
err = p.Authenticate(w, r)
|
|
||||||
// If the authentication is not successful we proceed to start the OAuth Flow with
|
// If the authentication is not successful we proceed to start the OAuth Flow with
|
||||||
// OAuthStart. If successful, we proceed to proxy to the configured upstream.
|
// OAuthStart. If successful, we proceed to proxy to the configured upstream.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -387,7 +386,7 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
|
||||||
|
|
||||||
// Handle constructs a route from the given host string and matches it to the provided http.Handler and UpstreamConfig
|
// Handle constructs a route from the given host string and matches it to the provided http.Handler and UpstreamConfig
|
||||||
func (p *Proxy) Handle(host string, handler http.Handler) {
|
func (p *Proxy) Handle(host string, handler http.Handler) {
|
||||||
p.mux[host] = &handler
|
p.mux[host] = handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// router attempts to find a route for a request. If a route is successfully matched,
|
// router attempts to find a route for a request. If a route is successfully matched,
|
||||||
|
@ -396,7 +395,7 @@ func (p *Proxy) Handle(host string, handler http.Handler) {
|
||||||
func (p *Proxy) router(r *http.Request) (http.Handler, bool) {
|
func (p *Proxy) router(r *http.Request) (http.Handler, bool) {
|
||||||
route, ok := p.mux[r.Host]
|
route, ok := p.mux[r.Host]
|
||||||
if ok {
|
if ok {
|
||||||
return *route, true
|
return route, true
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
|
@ -135,7 +135,7 @@ type Proxy struct {
|
||||||
|
|
||||||
redirectURL *url.URL
|
redirectURL *url.URL
|
||||||
templates *template.Template
|
templates *template.Template
|
||||||
mux map[string]*http.Handler
|
mux map[string]http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// StateParameter holds the redirect id along with the session id.
|
// StateParameter holds the redirect id along with the session id.
|
||||||
|
@ -184,7 +184,7 @@ func New(opts *Options) (*Proxy, error) {
|
||||||
|
|
||||||
p := &Proxy{
|
p := &Proxy{
|
||||||
// these fields make up the routing mechanism
|
// these fields make up the routing mechanism
|
||||||
mux: make(map[string]*http.Handler),
|
mux: make(map[string]http.Handler),
|
||||||
// session state
|
// session state
|
||||||
cipher: cipher,
|
cipher: cipher,
|
||||||
csrfStore: cookieStore,
|
csrfStore: cookieStore,
|
||||||
|
@ -243,15 +243,12 @@ func deleteUpstreamCookies(req *http.Request, cookieName string) {
|
||||||
req.Header.Set("Cookie", strings.Join(headers, ";"))
|
req.Header.Set("Cookie", strings.Join(headers, ";"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// signRequest signs a g
|
|
||||||
func (u *UpstreamProxy) signRequest(req *http.Request) {
|
func (u *UpstreamProxy) signRequest(req *http.Request) {
|
||||||
if u.signer != nil {
|
if u.signer != nil {
|
||||||
jwt, err := u.signer.SignJWT(req.Header.Get(HeaderUserID), req.Header.Get(HeaderEmail))
|
jwt, err := u.signer.SignJWT(req.Header.Get(HeaderUserID), req.Header.Get(HeaderEmail))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
req.Header.Set(HeaderJWT, jwt)
|
req.Header.Set(HeaderJWT, jwt)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue