all: refactor handler logic

- all: prefer `FormValues` to `ParseForm` with subsequent `Form.Get`s
- all: refactor authentication stack to be checked by middleware, and accessible via request context.
- all: replace http.ServeMux with gorilla/mux’s router
- all: replace custom CSRF checks with gorilla/csrf middleware
- authenticate: extract callback path as constant.
- internal/config: implement stringer interface for policy
- internal/cryptutil: add helper func `NewBase64Key`
- internal/cryptutil: rename `GenerateKey` to `NewKey`
- internal/cryptutil: rename `GenerateRandomString` to `NewRandomStringN`
- internal/middleware: removed alice in favor of gorilla/mux
- internal/sessions: remove unused `ValidateRedirectURI` and `ValidateClientSecret`
- internal/sessions: replace custom CSRF with gorilla/csrf fork that supports custom handler protection
- internal/urlutil: add `SignedRedirectURL` to create hmac'd URLs
- internal/urlutil: add `ValidateURL` helper to parse URL options
- internal/urlutil: add `GetAbsoluteURL` which takes a request and returns its absolute URL.
- proxy: remove holdover state verification checks; we no longer are setting sessions in any proxy routes so we don’t need them.
- proxy: replace un-named http.ServeMux with named domain routes.

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2019-09-12 13:54:30 -07:00
parent a793249386
commit dc12947241
No known key found for this signature in database
GPG key ID: AEE4CF12FE86D07E
37 changed files with 1132 additions and 1384 deletions

3
.gitignore vendored
View file

@ -76,4 +76,5 @@ yarn.lock
node_modules node_modules
i18n/* i18n/*
docs/.vuepress/dist/ docs/.vuepress/dist/
.firebase/ .firebase/
.changes.md

View file

@ -87,31 +87,6 @@ https://github.com/bitly/oauth2_proxy/blob/master/LICENSE
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
alice
SPDX-License-Identifier: MIT
https://github.com/justinas/alice/blob/master/LICENSE
The MIT License (MIT)
Copyright (c) 2014 Justinas Stankevicius
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
goji goji
SPDX-License-Identifier: MIT SPDX-License-Identifier: MIT
https://github.com/zenazn/goji/blob/master/LICENSE https://github.com/zenazn/goji/blob/master/LICENSE

View file

@ -15,6 +15,8 @@ import (
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
) )
const callbackPath = "/oauth2/callback"
// ValidateOptions checks that configuration are complete and valid. // ValidateOptions checks that configuration are complete and valid.
// Returns on first error found. // Returns on first error found.
func ValidateOptions(o config.Options) error { func ValidateOptions(o config.Options) error {
@ -24,11 +26,8 @@ func ValidateOptions(o config.Options) error {
if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil { if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil {
return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %v", err) return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %v", err)
} }
if o.AuthenticateURL == nil { if err := urlutil.ValidateURL(o.AuthenticateURL); err != nil {
return errors.New("authenticate: 'AUTHENTICATE_SERVICE_URL' is required") return fmt.Errorf("authenticate: invalid 'AUTHENTICATE_SERVICE_URL': %v", err)
}
if _, err := urlutil.ParseAndValidateURL(o.AuthenticateURL.String()); err != nil {
return fmt.Errorf("authenticate: couldn't parse 'AUTHENTICATE_SERVICE_URL': %v", err)
} }
if o.ClientID == "" { if o.ClientID == "" {
return errors.New("authenticate: 'IDP_CLIENT_ID' is required") return errors.New("authenticate: 'IDP_CLIENT_ID' is required")
@ -44,8 +43,10 @@ type Authenticate struct {
SharedKey string SharedKey string
RedirectURL *url.URL RedirectURL *url.URL
cookieName string
cookieDomain string
cookieSecret []byte
templates *template.Template templates *template.Template
csrfStore sessions.CSRFStore
sessionStore sessions.SessionStore sessionStore sessions.SessionStore
cipher cryptutil.Cipher cipher cryptutil.Cipher
provider identity.Authenticator provider identity.Authenticator
@ -61,6 +62,9 @@ func New(opts config.Options) (*Authenticate, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if opts.CookieDomain == "" {
opts.CookieDomain = sessions.ParentSubdomain(opts.AuthenticateURL.String())
}
cookieStore, err := sessions.NewCookieStore( cookieStore, err := sessions.NewCookieStore(
&sessions.CookieStoreOptions{ &sessions.CookieStoreOptions{
Name: opts.CookieName, Name: opts.CookieName,
@ -74,7 +78,7 @@ func New(opts config.Options) (*Authenticate, error) {
return nil, err return nil, err
} }
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL) redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
redirectURL.Path = "/oauth2/callback" redirectURL.Path = callbackPath
provider, err := identity.New( provider, err := identity.New(
opts.Provider, opts.Provider,
&identity.Provider{ &identity.Provider{
@ -94,9 +98,11 @@ func New(opts config.Options) (*Authenticate, error) {
SharedKey: opts.SharedKey, SharedKey: opts.SharedKey,
RedirectURL: redirectURL, RedirectURL: redirectURL,
templates: templates.New(), templates: templates.New(),
csrfStore: cookieStore,
sessionStore: cookieStore, sessionStore: cookieStore,
cipher: cipher, cipher: cipher,
provider: provider, provider: provider,
cookieSecret: decodedCookieSecret,
cookieName: opts.CookieName,
cookieDomain: opts.CookieDomain,
}, nil }, nil
} }

View file

@ -10,7 +10,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/csrf"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
@ -31,24 +32,68 @@ var CSPHeaders = map[string]string{
// Handler returns the authenticate service's HTTP multiplexer, and routes. // Handler returns the authenticate service's HTTP multiplexer, and routes.
func (a *Authenticate) Handler() http.Handler { func (a *Authenticate) Handler() http.Handler {
// validation middleware chain r := httputil.NewRouter()
c := middleware.NewChain() r.Use(middleware.SetHeaders(CSPHeaders))
c = c.Append(middleware.SetHeaders(CSPHeaders)) r.Use(csrf.Protect(
mux := http.NewServeMux() a.cookieSecret,
mux.Handle("/robots.txt", c.ThenFunc(a.RobotsTxt)) csrf.Path("/"),
csrf.Domain(a.cookieDomain),
csrf.UnsafePaths([]string{"/oauth2/callback"}), // enforce CSRF on "safe" handler
csrf.FormValueName("state"), // rfc6749 section-10.12
csrf.CookieName(fmt.Sprintf("%s_csrf", a.cookieName)),
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
))
r.HandleFunc("/robots.txt", a.RobotsTxt).Methods(http.MethodGet)
// Identity Provider (IdP) endpoints // Identity Provider (IdP) endpoints
mux.Handle("/oauth2", c.ThenFunc(a.OAuthStart)) r.HandleFunc("/oauth2/callback", a.OAuthCallback).Methods(http.MethodGet)
mux.Handle("/oauth2/callback", c.ThenFunc(a.OAuthCallback)) r.HandleFunc("/api/v1/token", a.ExchangeToken)
// Proxy service endpoints // Proxy service endpoints
validationMiddlewares := c.Append( v := r.PathPrefix("/.pomerium").Subrouter()
middleware.ValidateSignature(a.SharedKey), v.Use(middleware.ValidateSignature(a.SharedKey))
middleware.ValidateRedirectURI(a.RedirectURL), v.Use(middleware.ValidateRedirectURI(a.RedirectURL))
) v.Use(sessions.RetrieveSession(a.sessionStore))
mux.Handle("/sign_in", validationMiddlewares.ThenFunc(a.SignIn)) v.Use(a.VerifySession)
mux.Handle("/sign_out", validationMiddlewares.ThenFunc(a.SignOut)) // POST
// Direct user access endpoints v.HandleFunc("/sign_in", a.SignIn)
mux.Handle("/api/v1/token", c.ThenFunc(a.ExchangeToken)) v.HandleFunc("/sign_out", a.SignOut)
return mux return r
}
// VerifySession is the middleware used to enforce a valid authentication
// session state is attached to the users's request context.
func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
state, err := sessions.FromContext(r.Context())
if errors.Is(err, sessions.ErrExpired) {
if err := a.refresh(w, r, state); err != nil {
log.FromRequest(r).Debug().Str("cause", err.Error()).Msg("authenticate: couldn't refresh session")
a.sessionStore.ClearSession(w, r)
a.redirectToIdentityProvider(w, r)
return
}
} else if err != nil {
log.FromRequest(r).Err(err).Msg("authenticate: unexpected session state")
a.sessionStore.ClearSession(w, r)
a.redirectToIdentityProvider(w, r)
return
}
next.ServeHTTP(w, r)
})
}
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) error {
newSession, err := a.provider.Refresh(r.Context(), s)
if err != nil {
return fmt.Errorf("authenticate: refresh failed: %w", err)
}
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
return fmt.Errorf("authenticate: refresh save failed: %w", err)
}
return nil
} }
// RobotsTxt handles the /robots.txt route. // RobotsTxt handles the /robots.txt route.
@ -59,87 +104,22 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "User-agent: *\nDisallow: /") fmt.Fprintf(w, "User-agent: *\nDisallow: /")
} }
func (a *Authenticate) loadExisting(w http.ResponseWriter, r *http.Request) (*sessions.State, error) {
session, err := a.sessionStore.LoadSession(r)
if err != nil {
return nil, err
}
err = session.Valid()
if err == nil {
return session, nil
} else if !errors.Is(err, sessions.ErrExpired) {
return nil, fmt.Errorf("authenticate: non-refreshable error: %w", err)
} else {
return a.refresh(w, r, session)
}
}
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (*sessions.State, error) {
newSession, err := a.provider.Refresh(r.Context(), s)
if err != nil {
return nil, fmt.Errorf("authenticate: refresh failed: %w", err)
}
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
return nil, fmt.Errorf("authenticate: refresh save failed: %w", err)
}
return newSession, nil
}
// SignIn handles to authenticating a user. // SignIn handles to authenticating a user.
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
session, err := a.loadExisting(w, r) redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
if err != nil {
log.FromRequest(r).Debug().Err(err).Msg("authenticate: need new session")
a.sessionStore.ClearSession(w, r)
a.OAuthStart(w, r)
return
}
if err := r.ParseForm(); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
state := r.Form.Get("state")
if state == "" {
httputil.ErrorResponse(w, r, httputil.Error("sign in state empty", http.StatusBadRequest, nil))
return
}
redirectURL, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri"))
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err)) httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
return return
} }
// encrypt session state as json blob http.Redirect(w, r, redirectURL.String(), http.StatusFound)
encrypted, err := sessions.MarshalSession(session, a.cipher)
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("couldn't marshal session", http.StatusInternalServerError, err))
return
}
http.Redirect(w, r, getAuthCodeRedirectURL(redirectURL, state, encrypted), http.StatusFound)
}
func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string {
// ParseQuery err handled by go's mux stack
params, _ := url.ParseQuery(redirectURL.RawQuery)
params.Set("code", authCode)
params.Set("state", state)
redirectURL.RawQuery = params.Encode()
return redirectURL.String()
} }
// SignOut signs the user out and attempts to revoke the user's identity session // SignOut signs the user out and attempts to revoke the user's identity session
// Handles both GET and POST. // Handles both GET and POST.
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil { session, err := sessions.FromContext(r.Context())
httputil.ErrorResponse(w, r, err)
return
}
redirectURI := r.Form.Get("redirect_uri")
session, err := a.sessionStore.LoadSession(r)
if err != nil { if err != nil {
log.Error().Err(err).Msg("authenticate: no session to signout, redirect and clear") httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
http.Redirect(w, r, redirectURI, http.StatusFound)
return return
} }
a.sessionStore.ClearSession(w, r) a.sessionStore.ClearSession(w, r)
@ -148,46 +128,30 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, httputil.Error("could not revoke user session", http.StatusBadRequest, err)) httputil.ErrorResponse(w, r, httputil.Error("could not revoke user session", http.StatusBadRequest, err))
return return
} }
http.Redirect(w, r, redirectURI, http.StatusFound) redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
return
}
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
} }
// OAuthStart starts the authenticate process by redirecting to the identity provider. // redirectToIdentityProvider starts the authenticate process by redirecting the
// user to their respective identity provider.
//
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest // https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest
// https://tools.ietf.org/html/rfc6749#section-4.2.1 // https://tools.ietf.org/html/rfc6749#section-4.2.1
func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http.Request) {
authRedirectURL := a.RedirectURL.ResolveReference(r.URL) redirectURL := a.RedirectURL.ResolveReference(r.URL)
nonce := csrf.Token(r)
// Nonce is the opaque, cryptographically binding value used to maintain state := fmt.Sprintf("%v:%v", nonce, redirectURL.String())
// state between the request and the callback. encodedState := base64.URLEncoding.EncodeToString([]byte(state))
// OIDC : 3.1.2.1. Authentication Request http.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound)
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
a.csrfStore.SetCSRF(w, r, nonce)
// Redirection URI to which the response will be sent. This URI MUST exactly
// match one of the Redirection URI values for the Client pre-registered at
// at your identity provider
proxyRedirectURL, err := urlutil.ParseAndValidateURL(authRedirectURL.Query().Get("redirect_uri"))
if err != nil || !middleware.SameDomain(proxyRedirectURL, a.RedirectURL) {
httputil.ErrorResponse(w, r, httputil.Error("proxy url not from the root domain", http.StatusBadRequest, err))
return
}
// get the signature and timestamp values then compare hmac
proxyRedirectSig := authRedirectURL.Query().Get("sig")
ts := authRedirectURL.Query().Get("ts")
if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, a.SharedKey) {
httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, nil))
return
}
// State is the opaque value used to maintain state between the request and
// the callback; contains both the nonce and redirect URI
state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String())))
// build the provider sign in url
signInURL := a.provider.GetSignInURL(state)
http.Redirect(w, r, signInURL, http.StatusFound)
} }
// OAuthCallback handles the callback from the identity provider. // OAuthCallback handles the callback from the identity provider.
//
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowSteps
// https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse // https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
redirect, err := a.getOAuthCallback(w, r) redirect, err := a.getOAuthCallback(w, r)
@ -195,57 +159,49 @@ func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, fmt.Errorf("oauth callback : %w", err)) httputil.ErrorResponse(w, r, fmt.Errorf("oauth callback : %w", err))
return return
} }
// redirect back to the proxy-service via sign_in
http.Redirect(w, r, redirect.String(), http.StatusFound) http.Redirect(w, r, redirect.String(), http.StatusFound)
} }
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) { func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
if err := r.ParseForm(); err != nil { // Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
return nil, httputil.Error("invalid signature", http.StatusBadRequest, err) //
// first, check if the identity provider returned an error
if idpError := r.FormValue("error"); idpError != "" {
return nil, httputil.Error(idpError, http.StatusBadRequest, fmt.Errorf("identity provider: %v", idpError))
} }
// OIDC : 3.1.2.6. Authentication Error Response // fail if no session redemption code is returned
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthError code := r.FormValue("code")
if idpError := r.Form.Get("error"); idpError != "" {
return nil, httputil.Error("provider returned an error", http.StatusBadRequest, fmt.Errorf("provider error: %v", idpError))
}
code := r.Form.Get("code")
if code == "" { if code == "" {
return nil, httputil.Error("provider didn't reply with code", http.StatusBadRequest, nil) return nil, httputil.Error("identity provider returned empty code", http.StatusBadRequest, nil)
} }
// validate the returned code with the identity provider // Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5
//
// Exchange the supplied Authorization Code for a valid user session.
session, err := a.provider.Authenticate(r.Context(), code) session, err := a.provider.Authenticate(r.Context(), code)
if err != nil { if err != nil {
return nil, fmt.Errorf("error redeeming authenticate code: %w", err) return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
} }
// state includes a csrf nonce (validated by middleware) and redirect uri
// OIDC : 3.1.2.5. Successful Authentication Response bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state"))
// Opaque value used to maintain state between the request and the callback.
bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state"))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed decoding state: %w", err) return nil, httputil.Error("malformed state", http.StatusBadRequest, err)
} }
s := strings.SplitN(string(bytes), ":", 2) // split state into its it's components (nonce:redirect_uri)
if len(s) != 2 { statePayload := strings.SplitN(string(bytes), ":", 2)
return nil, fmt.Errorf("invalid state size: %d", len(s)) if len(statePayload) != 2 {
return nil, fmt.Errorf("state malformed, size: %d", len(statePayload))
} }
// state contains the csrf nonce and redirect uri // parse redirect_uri; ignore csrf nonce (validity asserted by middleware)
nonce := s[0] redirectURL, err := urlutil.ParseAndValidateURL(statePayload[1])
redirect := s[1]
c, err := a.csrfStore.GetCSRF(r)
defer a.csrfStore.ClearCSRF(w, r)
if err != nil || c.Value != nonce {
return nil, fmt.Errorf("csrf failure: %w", err)
}
redirectURL, err := urlutil.ParseAndValidateURL(redirect)
if err != nil { if err != nil {
return nil, httputil.Error(fmt.Sprintf("invalid redirect uri %s", redirect), http.StatusBadRequest, err) return nil, httputil.Error("invalid redirect uri", http.StatusBadRequest, err)
}
// sanity check, we are redirecting back to the same subdomain right?
if !middleware.SameDomain(redirectURL, a.RedirectURL) {
return nil, httputil.Error(fmt.Sprintf("invalid redirect domain %v, %v", redirectURL, a.RedirectURL), http.StatusBadRequest, nil)
} }
// todo(bdd): if we want to be _extra_ sure, we can validate that the
// redirectURL hmac is valid. But the nonce should cover the integrity...
// OK. Looks good so let's persist our user session
if err := a.sessionStore.SaveSession(w, r, session); err != nil { if err := a.sessionStore.SaveSession(w, r, session); err != nil {
return nil, fmt.Errorf("failed saving new session: %w", err) return nil, fmt.Errorf("failed saving new session: %w", err)
} }
@ -256,11 +212,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
// and exchanges that token for a pomerium session. The provided token's // and exchanges that token for a pomerium session. The provided token's
// audience ('aud') attribute must match Pomerium's client_id. // audience ('aud') attribute must match Pomerium's client_id.
func (a *Authenticate) ExchangeToken(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) ExchangeToken(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil { code := r.FormValue("id_token")
httputil.ErrorResponse(w, r, err)
return
}
code := r.Form.Get("id_token")
if code == "" { if code == "" {
httputil.ErrorResponse(w, r, httputil.Error("missing id token", http.StatusBadRequest, nil)) httputil.ErrorResponse(w, r, httputil.Error("missing id token", http.StatusBadRequest, nil))
return return

View file

@ -21,6 +21,7 @@ func testAuthenticate() *Authenticate {
var auth Authenticate var auth Authenticate
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback") auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
auth.SharedKey = "IzY7MOZwzfOkmELXgozHDKTxoT3nOYhwkcmUVINsRww=" auth.SharedKey = "IzY7MOZwzfOkmELXgozHDKTxoT3nOYhwkcmUVINsRww="
auth.cookieSecret = []byte(auth.SharedKey)
auth.templates = templates.New() auth.templates = templates.New()
return &auth return &auth
} }
@ -51,6 +52,7 @@ func TestAuthenticate_Handler(t *testing.T) {
t.Error("handler cannot be nil") t.Error("handler cannot be nil")
} }
req := httptest.NewRequest("GET", "/robots.txt", nil) req := httptest.NewRequest("GET", "/robots.txt", nil)
req.Header.Set("Accept", "application/json")
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
h.ServeHTTP(rr, req) h.ServeHTTP(rr, req)
@ -63,6 +65,7 @@ func TestAuthenticate_Handler(t *testing.T) {
} }
func TestAuthenticate_SignIn(t *testing.T) { func TestAuthenticate_SignIn(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
state string state string
@ -76,36 +79,35 @@ func TestAuthenticate_SignIn(t *testing.T) {
{"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound}, {"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
{"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockCipher{}, http.StatusFound}, {"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockCipher{}, http.StatusFound},
{"session expired good refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound}, {"session expired good refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound},
{"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusBadRequest}, // mocking hmac is meh {"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusFound}, // mocking hmac is meh
{"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusBadRequest}, {"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound},
// {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, // {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
{"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, {"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
{"malformed form", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, {"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
{"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
{"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, {"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
// actually caught by go's handler, but we should keep the test. // actually caught by go's handler, but we should keep the test.
{"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, {"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
{"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusInternalServerError}, {"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusFound},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
a := &Authenticate{ a := &Authenticate{
sessionStore: tt.session, sessionStore: tt.session,
provider: tt.provider, provider: tt.provider,
RedirectURL: uriParse("https://some.example"), RedirectURL: uriParseHelper("https://some.example"),
csrfStore: &sessions.MockCSRFStore{},
SharedKey: "secret", SharedKey: "secret",
cipher: tt.cipher, cipher: tt.cipher,
} }
uri := &url.URL{Host: "corp.some.example", Scheme: "https", Path: "/"} uri := &url.URL{Host: "corp.some.example", Scheme: "https", Path: "/"}
if tt.name == "malformed form" { uri.RawQuery = fmt.Sprintf("%s&redirect_uri=%s", tt.state, tt.redirectURI)
uri.RawQuery = "example=%zzzzz"
} else {
uri.RawQuery = fmt.Sprintf("%s&redirect_uri=%s", tt.state, tt.redirectURI)
}
r := httptest.NewRequest(http.MethodGet, uri.String(), nil) r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
r.Header.Set("Accept", "application/json") r.Header.Set("Accept", "application/json")
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, nil)
r = r.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
a.SignIn(w, r) a.SignIn(w, r)
@ -117,61 +119,18 @@ func TestAuthenticate_SignIn(t *testing.T) {
} }
} }
type mockCipher struct{} func uriParseHelper(s string) *url.URL {
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Marshal(s interface{}) (string, error) { return "ok", nil }
func (a mockCipher) Unmarshal(s string, i interface{}) error {
if s == "unmarshal error" || s == "error" {
return errors.New("error")
}
return nil
}
func Test_getAuthCodeRedirectURL(t *testing.T) {
tests := []struct {
name string
redirectURL *url.URL
state string
authCode string
want string
}{
{"https", uriParse("https://www.pomerium.io"), "state", "auth-code", "https://www.pomerium.io?code=auth-code&state=state"},
{"http", uriParse("http://www.pomerium.io"), "state", "auth-code", "http://www.pomerium.io?code=auth-code&state=state"},
{"no subdomain", uriParse("http://pomerium.io"), "state", "auth-code", "http://pomerium.io?code=auth-code&state=state"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getAuthCodeRedirectURL(tt.redirectURL, tt.state, tt.authCode); got != tt.want {
t.Errorf("getAuthCodeRedirectURL() = %v, want %v", got, tt.want)
}
})
}
}
func uriParse(s string) *url.URL {
uri, _ := url.Parse(s) uri, _ := url.Parse(s)
return uri return uri
} }
func TestAuthenticate_SignOut(t *testing.T) { func TestAuthenticate_SignOut(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
method string method string
ctxError error
redirectURL string redirectURL string
sig string sig string
ts string ts string
@ -181,17 +140,16 @@ func TestAuthenticate_SignOut(t *testing.T) {
wantCode int wantCode int
wantBody string wantBody string
}{ }{
{"good post", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""}, {"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""},
{"failed revoke", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"}, {"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"},
{"malformed form", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusInternalServerError, ""}, {"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, ""},
{"load session error", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""}, {"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, ""},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
a := &Authenticate{ a := &Authenticate{
sessionStore: tt.sessionStore, sessionStore: tt.sessionStore,
provider: tt.provider, provider: tt.provider,
cipher: mockCipher{},
templates: templates.New(), templates: templates.New(),
} }
u, _ := url.Parse("/sign_out") u, _ := url.Parse("/sign_out")
@ -200,10 +158,11 @@ func TestAuthenticate_SignOut(t *testing.T) {
params.Add("ts", tt.ts) params.Add("ts", tt.ts)
params.Add("redirect_uri", tt.redirectURL) params.Add("redirect_uri", tt.redirectURL)
u.RawQuery = params.Encode() u.RawQuery = params.Encode()
if tt.name == "malformed form" {
u.RawQuery = "example=%zzzzz"
}
r := httptest.NewRequest(tt.method, u.String(), nil) r := httptest.NewRequest(tt.method, u.String(), nil)
state, _ := tt.sessionStore.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
a.SignOut(w, r) a.SignOut(w, r)
@ -217,64 +176,8 @@ func TestAuthenticate_SignOut(t *testing.T) {
} }
} }
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) string {
data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix()))
h := cryptutil.Hash(secret, data)
return base64.URLEncoding.EncodeToString(h)
}
func TestAuthenticate_OAuthStart(t *testing.T) {
tests := []struct {
name string
method string
redirectURLSetting string
redirectURL string
sig string
ts string
provider identity.Authenticator
csrfStore sessions.MockCSRFStore
// sessionStore sessions.SessionStore
wantCode int
}{
{"good", http.MethodGet, "https://corp.pomerium.io/", "https://corp.pomerium.io/", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusFound},
{"bad timestamp", http.MethodGet, "https://corp.pomerium.io/", "https://corp.pomerium.io/", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Add(10 * time.Hour).Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest},
{"missing redirect", http.MethodGet, "https://corp.pomerium.io/", "", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest},
{"malformed redirect", http.MethodGet, "https://corp.pomerium.io/", "https://pomerium.com%zzzzz", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest},
{"different domains", http.MethodGet, "https://corp.notpomerium.io/", "https://corp.pomerium.io/", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Authenticate{
RedirectURL: uriParse(tt.redirectURLSetting),
csrfStore: tt.csrfStore,
provider: tt.provider,
SharedKey: "secret",
cipher: mockCipher{},
}
u, _ := url.Parse("/oauth_start")
params, _ := url.ParseQuery(u.RawQuery)
params.Add("sig", tt.sig)
params.Add("ts", tt.ts)
params.Add("redirect_uri", tt.redirectURL)
u.RawQuery = params.Encode()
r := httptest.NewRequest(tt.method, u.String(), nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
a.OAuthStart(w, r)
if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v\n%v", status, tt.wantCode, w.Body.String())
}
})
}
}
func TestAuthenticate_OAuthCallback(t *testing.T) { func TestAuthenticate_OAuthCallback(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
method string method string
@ -286,24 +189,20 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
authenticateURL string authenticateURL string
session sessions.SessionStore session sessions.SessionStore
provider identity.MockProvider provider identity.MockProvider
csrfStore sessions.MockCSRFStore
want string want string
wantCode int wantCode int
}{ }{
{"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusFound}, {"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusFound},
{"get csrf error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", GetError: errors.New("error"), Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError}, {"failed authenticate", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
{"csrf nonce error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError}, {"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusInternalServerError},
{"failed authenticate", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, {"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
{"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, {"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
{"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, {"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
{"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, {"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"invalid state string", http.MethodGet, "", "code", "nonce:https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, {"bad base64 state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")) + "%", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"malformed state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, {"too many state delimeters", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io:wait")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, {"too few state delimeters", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusInternalServerError},
{"malformed form", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
{"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"different domains", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://some.example.notpomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -311,7 +210,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
a := &Authenticate{ a := &Authenticate{
RedirectURL: authURL, RedirectURL: authURL,
sessionStore: tt.session, sessionStore: tt.session,
csrfStore: tt.csrfStore,
provider: tt.provider, provider: tt.provider,
} }
u, _ := url.Parse("/oauthGet") u, _ := url.Parse("/oauthGet")
@ -322,9 +220,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
u.RawQuery = params.Encode() u.RawQuery = params.Encode()
if tt.name == "malformed form" {
u.RawQuery = "example=%zzzzz"
}
r := httptest.NewRequest(tt.method, u.String(), nil) r := httptest.NewRequest(tt.method, u.String(), nil)
r.Header.Set("Accept", "application/json") r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -339,6 +234,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
} }
func TestAuthenticate_ExchangeToken(t *testing.T) { func TestAuthenticate_ExchangeToken(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
method string method string
@ -384,3 +280,55 @@ func TestAuthenticate_ExchangeToken(t *testing.T) {
}) })
} }
} }
func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprintln(w, "RVSI FILIVS CAISAR")
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
session sessions.SessionStore
ctxError error
provider identity.Authenticator
wantStatus int
}{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusOK},
{"expired,refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
{"expired,save error", &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := Authenticate{
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
sessionStore: tt.session,
provider: tt.provider,
}
r := httptest.NewRequest("GET", "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
got := a.VerifySession(fn)
got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
})
}
}

View file

@ -14,7 +14,7 @@ import (
func ValidateOptions(o config.Options) error { func ValidateOptions(o config.Options) error {
decoded, err := base64.StdEncoding.DecodeString(o.SharedKey) decoded, err := base64.StdEncoding.DecodeString(o.SharedKey)
if err != nil { if err != nil {
return fmt.Errorf("authorize: `SHARED_SECRET` setting is invalid base64: %v", err) return fmt.Errorf("authorize: `SHARED_SECRET` malformed base64: %v", err)
} }
if len(decoded) != 32 { if len(decoded) != 32 {
return fmt.Errorf("authorize: `SHARED_SECRET` want 32 but got %d bytes", len(decoded)) return fmt.Errorf("authorize: `SHARED_SECRET` want 32 but got %d bytes", len(decoded))

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
"github.com/gorilla/mux"
"github.com/spf13/viper" "github.com/spf13/viper"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -44,9 +45,9 @@ func main() {
setupTracing(opt) setupTracing(opt)
setupHTTPRedirectServer(opt) setupHTTPRedirectServer(opt)
mux := http.NewServeMux() r := newGlobalRouter(opt)
grpcServer := setupGRPCServer(opt) grpcServer := setupGRPCServer(opt)
_, err = newAuthenticateService(*opt, mux) _, err = newAuthenticateService(*opt, r.Host(urlutil.StripPort(opt.AuthenticateURL.Host)).Subrouter())
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: authenticate") log.Fatal().Err(err).Msg("cmd/pomerium: authenticate")
} }
@ -56,7 +57,7 @@ func main() {
log.Fatal().Err(err).Msg("cmd/pomerium: authorize") log.Fatal().Err(err).Msg("cmd/pomerium: authorize")
} }
proxy, err := newProxyService(*opt, mux) proxy, err := newProxyService(*opt, r)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: proxy") log.Fatal().Err(err).Msg("cmd/pomerium: proxy")
} }
@ -70,8 +71,7 @@ func main() {
log.Info().Str("file", e.Name).Msg("cmd/pomerium: config file changed") log.Info().Str("file", e.Name).Msg("cmd/pomerium: config file changed")
opt = config.HandleConfigUpdate(*configFile, opt, []config.OptionsUpdater{authz, proxy}) opt = config.HandleConfigUpdate(*configFile, opt, []config.OptionsUpdater{authz, proxy})
}) })
srv, err := httputil.NewTLSServer(configToServerOptions(opt), r, grpcServer)
srv, err := httputil.NewTLSServer(configToServerOptions(opt), mainHandler(opt, mux), grpcServer)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: couldn't start pomerium") log.Fatal().Err(err).Msg("cmd/pomerium: couldn't start pomerium")
} }
@ -80,7 +80,7 @@ func main() {
os.Exit(0) os.Exit(0)
} }
func newAuthenticateService(opt config.Options, mux *http.ServeMux) (*authenticate.Authenticate, error) { func newAuthenticateService(opt config.Options, r *mux.Router) (*authenticate.Authenticate, error) {
if !config.IsAuthenticate(opt.Services) { if !config.IsAuthenticate(opt.Services) {
return nil, nil return nil, nil
} }
@ -88,7 +88,7 @@ func newAuthenticateService(opt config.Options, mux *http.ServeMux) (*authentica
if err != nil { if err != nil {
return nil, err return nil, err
} }
mux.Handle(urlutil.StripPort(opt.AuthenticateURL.Host)+"/", service.Handler()) r.PathPrefix("/").Handler(service.Handler())
return service, nil return service, nil
} }
@ -104,7 +104,7 @@ func newAuthorizeService(opt config.Options, rpc *grpc.Server) (*authorize.Autho
return service, nil return service, nil
} }
func newProxyService(opt config.Options, mux *http.ServeMux) (*proxy.Proxy, error) { func newProxyService(opt config.Options, r *mux.Router) (*proxy.Proxy, error) {
if !config.IsProxy(opt.Services) { if !config.IsProxy(opt.Services) {
return nil, nil return nil, nil
} }
@ -112,15 +112,15 @@ func newProxyService(opt config.Options, mux *http.ServeMux) (*proxy.Proxy, erro
if err != nil { if err != nil {
return nil, err return nil, err
} }
mux.Handle("/", service.Handler()) r.PathPrefix("/").Handler(service.Handler())
return service, nil return service, nil
} }
func mainHandler(o *config.Options, mux http.Handler) http.Handler { func newGlobalRouter(o *config.Options) *mux.Router {
c := middleware.NewChain() mux := httputil.NewRouter()
c = c.Append(metrics.HTTPMetricsHandler(o.Services)) mux.Use(metrics.HTTPMetricsHandler(o.Services))
c = c.Append(log.NewHandler(log.Logger)) mux.Use(log.NewHandler(log.Logger))
c = c.Append(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) { mux.Use(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
log.FromRequest(r).Debug(). log.FromRequest(r).Debug().
Dur("duration", duration). Dur("duration", duration).
Int("size", size). Int("size", size).
@ -133,15 +133,15 @@ func mainHandler(o *config.Options, mux http.Handler) http.Handler {
Msg("http-request") Msg("http-request")
})) }))
if len(o.Headers) != 0 { if len(o.Headers) != 0 {
c = c.Append(middleware.SetHeaders(o.Headers)) mux.Use(middleware.SetHeaders(o.Headers))
} }
c = c.Append(log.ForwardedAddrHandler("fwd_ip")) mux.Use(log.ForwardedAddrHandler("fwd_ip"))
c = c.Append(log.RemoteAddrHandler("ip")) mux.Use(log.RemoteAddrHandler("ip"))
c = c.Append(log.UserAgentHandler("user_agent")) mux.Use(log.UserAgentHandler("user_agent"))
c = c.Append(log.RefererHandler("referer")) mux.Use(log.RefererHandler("referer"))
c = c.Append(log.RequestIDHandler("req_id", "Request-Id")) mux.Use(log.RequestIDHandler("req_id", "Request-Id"))
c = c.Append(middleware.Healthcheck("/ping", version.UserAgent())) mux.Use(middleware.Healthcheck("/ping", version.UserAgent()))
return c.Then(mux) return mux
} }
func configToServerOptions(opt *config.Options) *httputil.ServerOptions { func configToServerOptions(opt *config.Options) *httputil.ServerOptions {

View file

@ -21,7 +21,7 @@ import (
) )
func Test_newAuthenticateService(t *testing.T) { func Test_newAuthenticateService(t *testing.T) {
mux := http.NewServeMux() mux := httputil.NewRouter()
tests := []struct { tests := []struct {
name string name string
@ -127,7 +127,7 @@ func Test_newProxyeService(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
mux := http.NewServeMux() mux := httputil.NewRouter()
testOpts, err := config.NewOptions("https://authenticate.example", "https://authorize.example") testOpts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -161,7 +161,7 @@ func Test_newProxyeService(t *testing.T) {
} }
} }
func Test_mainHandler(t *testing.T) { func Test_newGlobalRouter(t *testing.T) {
o := config.Options{ o := config.Options{
Services: "all", Services: "all",
Headers: map[string]string{ Headers: map[string]string{
@ -172,7 +172,6 @@ func Test_mainHandler(t *testing.T) {
"Content-Security-Policy": "default-src 'none'; style-src 'self' 'sha256-pSTVzZsFAqd2U3QYu+BoBDtuJWaPM/+qMy/dBRrhb5Y='; img-src 'self';", "Content-Security-Policy": "default-src 'none'; style-src 'self' 'sha256-pSTVzZsFAqd2U3QYu+BoBDtuJWaPM/+qMy/dBRrhb5Y='; img-src 'self';",
"Referrer-Policy": "Same-origin", "Referrer-Policy": "Same-origin",
}} }}
mux := http.NewServeMux()
req := httptest.NewRequest(http.MethodGet, "/404", nil) req := httptest.NewRequest(http.MethodGet, "/404", nil)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -181,8 +180,9 @@ func Test_mainHandler(t *testing.T) {
io.WriteString(w, `OK`) io.WriteString(w, `OK`)
}) })
mux.Handle("/404", h) out := newGlobalRouter(&o)
out := mainHandler(&o, mux) out.Handle("/404", h)
out.ServeHTTP(rr, req) out.ServeHTTP(rr, req)
expected := fmt.Sprintf("OK") expected := fmt.Sprintf("OK")
body := rr.Body.String() body := rr.Body.String()

View file

@ -5,6 +5,23 @@
### New ### New
- Add ability to override HTTPS backend's TLS Server Name. [GH-297](https://github.com/pomerium/pomerium/pull/297) - Add ability to override HTTPS backend's TLS Server Name. [GH-297](https://github.com/pomerium/pomerium/pull/297)
- Add ability to set pomerium's encrypted session in a auth bearer token, or query param.
### Security
- Under certain circumstances, where debug logging was enabled, pomerium's shared secret could be leaked to http access logs as a query param.
### Fixed
- Fixed an issue where CSRF would fail if multiple tabs were open. [GH-306](https://github.com/pomerium/pomerium/issues/306)
### Changed
- Authenticate service no longer uses gRPC.
### Removed
- Removed `AUTHENTICATE_INTERNAL_URL`/`authenticate_internal_url` which is no longer used.
## v0.3.0 ## v0.3.0

View file

@ -228,8 +228,8 @@ Each unit work is called a Span in a trace. Spans include metadata about the wor
| Config Key | Description | Required | | Config Key | Description | Required |
| :--------------- | :---------------------------------------------------------------- | -------- | | :--------------- | :---------------------------------------------------------------- | -------- |
| tracing_provider | The name of the tracing provider. (e.g. jaeger) | ✅ | | tracing_provider | The name of the tracing provider. (e.g. jaeger) | ✅ |
| tracing_debug | Will disable [sampling](https://opencensus.io/tracing/sampling/). | ❌ | | tracing_debug | Will disable [sampling](https://opencensus.io/tracing/sampling/). | ❌ |
### Jaeger ### Jaeger
@ -243,8 +243,8 @@ Each unit work is called a Span in a trace. Spans include metadata about the wor
| Config Key | Description | Required | | Config Key | Description | Required |
| :-------------------------------- | :------------------------------------------ | -------- | | :-------------------------------- | :------------------------------------------ | -------- |
| tracing_jaeger_collector_endpoint | Url to the Jaeger HTTP Thrift collector. | ✅ | | tracing_jaeger_collector_endpoint | Url to the Jaeger HTTP Thrift collector. | ✅ |
| tracing_jaeger_agent_endpoint | Send spans to jaeger-agent at this address. | ✅ | | tracing_jaeger_agent_endpoint | Send spans to jaeger-agent at this address. | ✅ |
#### Example #### Example
@ -478,11 +478,11 @@ Authenticate Service URL is the externally accessible URL for the authenticate s
- Config File Key: `authorize_service_url` - Config File Key: `authorize_service_url`
- Type: `URL` - Type: `URL`
- Required - Required
- Example: `https://access.corp.example.com` or `https://pomerium-authorize-service.default.svc.cluster.local` - Example: `https://authorize.corp.example.com` or `https://pomerium-authorize-service.default.svc.cluster.local`
Authorize Service URL is the location of the internally accessible authorize service. NOTE: Unlike authenticate, authorize has no publicly accessible http handlers so this setting is purely for gRPC communication. Authorize Service URL is the location of the internally accessible authorize service. NOTE: Unlike authenticate, authorize has no publicly accessible http handlers so this setting is purely for gRPC communication.
If your load balancer does not support gRPC pass-through you'll need to set this value to an internally routable location (`https://pomerium-authorize-service.default.svc.cluster.local`) instead of an externally routable one (`https://access.corp.example.com`). If your load balancer does not support gRPC pass-through you'll need to set this value to an internally routable location (`https://pomerium-authorize-service.default.svc.cluster.local`) instead of an externally routable one (`https://authorize.corp.example.com`).
## Override Certificate Name ## Override Certificate Name

4
go.mod
View file

@ -9,10 +9,12 @@ require (
github.com/fsnotify/fsnotify v1.4.7 github.com/fsnotify/fsnotify v1.4.7
github.com/golang/mock v1.3.1 github.com/golang/mock v1.3.1
github.com/golang/protobuf v1.3.1 github.com/golang/protobuf v1.3.1
github.com/google/go-cmp v0.3.0 github.com/google/go-cmp v0.3.1
github.com/gorilla/mux v1.6.2
github.com/magiconair/properties v1.8.1 // indirect github.com/magiconair/properties v1.8.1 // indirect
github.com/mitchellh/hashstructure v1.0.0 github.com/mitchellh/hashstructure v1.0.0
github.com/pelletier/go-toml v1.4.0 // indirect github.com/pelletier/go-toml v1.4.0 // indirect
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30
github.com/pomerium/go-oidc v2.0.0+incompatible github.com/pomerium/go-oidc v2.0.0+incompatible
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/prometheus/client_golang v0.9.3 github.com/prometheus/client_golang v0.9.3

8
go.sum
View file

@ -65,11 +65,16 @@ github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/mux v1.6.2 h1:Pgr17XVTNXAk3q/r4CpKzC5xBM/qW1uVLV+IhRZpIIk=
github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0 h1:Iju5GlWwrvL6UBg4zJJt3btmonfrMlCDdsejg4CZE7c= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0 h1:Iju5GlWwrvL6UBg4zJJt3btmonfrMlCDdsejg4CZE7c=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
@ -115,9 +120,12 @@ github.com/pelletier/go-toml v1.4.0 h1:u3Z1r+oOXJIkxqw34zVhyPgjBsm6X2wn21NWs/HfS
github.com/pelletier/go-toml v1.4.0/go.mod h1:PN7xzY2wHTK0K9p34ErDQMlFxa51Fk0OUruD3k1mMwo= github.com/pelletier/go-toml v1.4.0/go.mod h1:PN7xzY2wHTK0K9p34ErDQMlFxa51Fk0OUruD3k1mMwo=
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30 h1:jggCv6hZvcxjGa3gqkYY2EUuOkITI9Znugz/f3QJfRQ=
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0=
github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o= github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o=
github.com/pomerium/go-oidc v2.0.0+incompatible/go.mod h1:DRsGVw6MOgxbfq4Y57jKOE8lbEfayxeiY0A8/4vxjBM= github.com/pomerium/go-oidc v2.0.0+incompatible/go.mod h1:DRsGVw6MOgxbfq4Y57jKOE8lbEfayxeiY0A8/4vxjBM=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU= github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=

View file

@ -217,7 +217,7 @@ func (o *Options) Validate() error {
// shared key must be set for all modes other than "all" // shared key must be set for all modes other than "all"
if o.SharedKey == "" { if o.SharedKey == "" {
if o.Services == "all" { if o.Services == "all" {
o.SharedKey = cryptutil.GenerateRandomString(32) o.SharedKey = cryptutil.NewBase64Key()
} else { } else {
return errors.New("shared-key cannot be empty") return errors.New("shared-key cannot be empty")
} }

View file

@ -116,3 +116,9 @@ func (p *Policy) Validate() error {
return nil return nil
} }
func (p *Policy) String() string {
if p.Source == nil || p.Destination == nil {
return fmt.Sprintf("%s → %s", p.From, p.To)
}
return fmt.Sprintf("%s → %s", p.Source.String(), p.Destination.String())
}

View file

@ -1,4 +1,4 @@
package config // import "github.com/pomerium/pomerium/internal/config" package config
import ( import (
"testing" "testing"
@ -44,3 +44,28 @@ func Test_Validate(t *testing.T) {
}) })
} }
} }
func TestPolicy_String(t *testing.T) {
t.Parallel()
tests := []struct {
name string
From string
To string
want string
}{
{"good", "https://pomerium.io", "https://localhost", "https://pomerium.io → https://localhost"},
{"failed to validate", "https://pomerium.io", "localhost", "https://pomerium.io → localhost"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Policy{
From: tt.From,
To: tt.To,
}
p.Validate()
if got := p.String(); got != tt.want {
t.Errorf("Policy.String() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -13,20 +13,27 @@ import (
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
) )
// DefaultKeySize is the default key size in bytes.
const DefaultKeySize = 32 const DefaultKeySize = 32
// GenerateKey generates a random 32-byte key. // NewKey generates a random 32-byte key.
// //
// Panics if source of randomness fails. // Panics if source of randomness fails.
func GenerateKey() []byte { func NewKey() []byte {
return randomBytes(DefaultKeySize) return randomBytes(DefaultKeySize)
} }
// GenerateRandomString returns base64 encoded securely generated random string // NewBase64Key generates a random base64 encoded 32-byte key.
// of a given set of bytes.
// //
// Panics if source of randomness fails. // Panics if source of randomness fails.
func GenerateRandomString(c int) string { func NewBase64Key() string {
return NewRandomStringN(DefaultKeySize)
}
// NewRandomStringN returns base64 encoded random string of a given num of bytes.
//
// Panics if source of randomness fails.
func NewRandomStringN(c int) string {
return base64.StdEncoding.EncodeToString(randomBytes(c)) return base64.StdEncoding.EncodeToString(randomBytes(c))
} }

View file

@ -1,4 +1,4 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil" package cryptutil
import ( import (
"crypto/rand" "crypto/rand"
@ -13,7 +13,7 @@ import (
func TestEncodeAndDecodeAccessToken(t *testing.T) { func TestEncodeAndDecodeAccessToken(t *testing.T) {
plaintext := []byte("my plain text value") plaintext := []byte("my plain text value")
key := GenerateKey() key := NewKey()
c, err := NewCipher(key) c, err := NewCipher(key)
if err != nil { if err != nil {
t.Fatalf("unexpected err: %v", err) t.Fatalf("unexpected err: %v", err)
@ -47,7 +47,7 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
} }
func TestMarshalAndUnmarshalStruct(t *testing.T) { func TestMarshalAndUnmarshalStruct(t *testing.T) {
key := GenerateKey() key := NewKey()
c, err := NewCipher(key) c, err := NewCipher(key)
if err != nil { if err != nil {
@ -102,7 +102,7 @@ func TestMarshalAndUnmarshalStruct(t *testing.T) {
} }
func TestCipherDataRace(t *testing.T) { func TestCipherDataRace(t *testing.T) {
cipher, err := NewCipher(GenerateKey()) cipher, err := NewCipher(NewKey())
if err != nil { if err != nil {
t.Fatalf("unexpected generating cipher err: %v", err) t.Fatalf("unexpected generating cipher err: %v", err)
} }
@ -183,21 +183,21 @@ func TestGenerateRandomString(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
o := GenerateRandomString(tt.c) o := NewRandomStringN(tt.c)
b, err := base64.StdEncoding.DecodeString(o) b, err := base64.StdEncoding.DecodeString(o)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
got := len(b) got := len(b)
if got != tt.want { if got != tt.want {
t.Errorf("GenerateRandomString() = %d, want %d", got, tt.want) t.Errorf("NewRandomStringN() = %d, want %d", got, tt.want)
} }
}) })
} }
} }
func TestXChaCha20Cipher_Marshal(t *testing.T) { func TestXChaCha20Cipher_Marshal(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
s interface{} s interface{}
@ -225,7 +225,7 @@ func TestXChaCha20Cipher_Marshal(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewCipher(GenerateKey()) c, err := NewCipher(NewKey())
if err != nil { if err != nil {
t.Fatalf("unexpected err: %v", err) t.Fatalf("unexpected err: %v", err)
} }
@ -239,15 +239,15 @@ func TestXChaCha20Cipher_Marshal(t *testing.T) {
} }
func TestNewCipher(t *testing.T) { func TestNewCipher(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
secret []byte secret []byte
wantErr bool wantErr bool
}{ }{
{"simple 32 byte key", GenerateKey(), false}, {"simple 32 byte key", NewKey(), false},
{"key too short", []byte("what is entropy"), true}, {"key too short", []byte("what is entropy"), true},
{"key too long", []byte(GenerateRandomString(33)), true}, {"key too long", []byte(NewRandomStringN(33)), true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -261,16 +261,16 @@ func TestNewCipher(t *testing.T) {
} }
func TestNewCipherFromBase64(t *testing.T) { func TestNewCipherFromBase64(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
s string s string
wantErr bool wantErr bool
}{ }{
{"simple 32 byte key", base64.StdEncoding.EncodeToString(GenerateKey()), false}, {"simple 32 byte key", base64.StdEncoding.EncodeToString(NewKey()), false},
{"key too short", base64.StdEncoding.EncodeToString([]byte("what is entropy")), true}, {"key too short", base64.StdEncoding.EncodeToString([]byte("what is entropy")), true},
{"key too long", GenerateRandomString(33), true}, {"key too long", NewRandomStringN(33), true},
{"bad base 64", string(GenerateKey()), true}, {"bad base 64", string(NewKey()), true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -282,3 +282,26 @@ func TestNewCipherFromBase64(t *testing.T) {
}) })
} }
} }
func TestNewBase64Key(t *testing.T) {
t.Parallel()
tests := []struct {
name string
want int
}{
{"simple", 32},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := NewBase64Key()
b, err := base64.StdEncoding.DecodeString(o)
if err != nil {
t.Error(err)
}
got := len(b)
if got != tt.want {
t.Errorf("NewBase64Key() = %d, want %d", got, tt.want)
}
})
}
}

View file

@ -0,0 +1,19 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
"net/http"
"github.com/gorilla/mux"
"github.com/pomerium/csrf"
)
// NewRouter returns a new router instance.
func NewRouter() *mux.Router {
return mux.NewRouter()
}
// CSRFFailureHandler sets a HTTP 403 Forbidden status and writes the
// CSRF failure reason to the response.
func CSRFFailureHandler(w http.ResponseWriter, r *http.Request) {
ErrorResponse(w, r, Error("CSRF Failure", http.StatusForbidden, csrf.FailureReason(r)))
}

View file

@ -0,0 +1,37 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestCSRFFailureHandler(t *testing.T) {
tests := []struct {
name string
wantBody string
wantStatus int
}{
{"basic csrf failure", "{\"error\":\"CSRF Failure\"}\n", http.StatusForbidden},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
CSRFFailureHandler(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %s", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %s", diff)
}
})
}
}

View file

@ -1,109 +0,0 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import "net/http"
// Constructor is a type alias for func(http.Handler) http.Handler
type Constructor func(http.Handler) http.Handler
// Chain acts as a list of http.Handler constructors.
// Chain is effectively immutable:
// once created, it will always hold
// the same set of constructors in the same order.
type Chain struct {
constructors []Constructor
}
// NewChain creates a new chain,
// memorizing the given list of middleware constructors.
// New serves no other function,
// constructors are only called upon a call to Then().
func NewChain(constructors ...Constructor) Chain {
return Chain{append([]Constructor(nil), constructors...)}
}
// Then chains the middleware and returns the final http.Handler.
// NewChain(m1, m2, m3).Then(h)
// is equivalent to:
// m1(m2(m3(h)))
// When the request comes in, it will be passed to m1, then m2, then m3
// and finally, the given handler
// (assuming every middleware calls the following one).
//
// A chain can be safely reused by calling Then() several times.
// stdStack := middleware.NewChain(ratelimitHandler, csrfHandler)
// indexPipe = stdStack.Then(indexHandler)
// authPipe = stdStack.Then(authHandler)
// Note that constructors are called on every call to Then()
// and thus several instances of the same middleware will be created
// when a chain is reused in this way.
// For proper middleware, this should cause no problems.
//
// Then() treats nil as http.DefaultServeMux.
func (c Chain) Then(h http.Handler) http.Handler {
if h == nil {
h = http.DefaultServeMux
}
for i := range c.constructors {
h = c.constructors[len(c.constructors)-1-i](h)
}
return h
}
// ThenFunc works identically to Then, but takes
// a HandlerFunc instead of a Handler.
//
// The following two statements are equivalent:
// c.Then(http.HandlerFunc(fn))
// c.ThenFunc(fn)
//
// ThenFunc provides all the guarantees of Then.
func (c Chain) ThenFunc(fn http.HandlerFunc) http.Handler {
if fn == nil {
return c.Then(nil)
}
return c.Then(fn)
}
// Append extends a chain, adding the specified constructors
// as the last ones in the request flow.
//
// Append returns a new chain, leaving the original one untouched.
//
// stdChain := middleware.NewChain(m1, m2)
// extChain := stdChain.Append(m3, m4)
// // requests in stdChain go m1 -> m2
// // requests in extChain go m1 -> m2 -> m3 -> m4
func (c Chain) Append(constructors ...Constructor) Chain {
newCons := make([]Constructor, 0, len(c.constructors)+len(constructors))
newCons = append(newCons, c.constructors...)
newCons = append(newCons, constructors...)
return Chain{newCons}
}
// Extend extends a chain by adding the specified chain
// as the last one in the request flow.
//
// Extend returns a new chain, leaving the original one untouched.
//
// stdChain := middleware.NewChain(m1, m2)
// ext1Chain := middleware.NewChain(m3, m4)
// ext2Chain := stdChain.Extend(ext1Chain)
// // requests in stdChain go m1 -> m2
// // requests in ext1Chain go m3 -> m4
// // requests in ext2Chain go m1 -> m2 -> m3 -> m4
//
// Another example:
// aHtmlAfterNosurf := middleware.NewChain(m2)
// aHtml := middleware.NewChain(m1, func(h http.Handler) http.Handler {
// csrf := nosurf.NewChain(h)
// csrf.SetFailureHandler(aHtmlAfterNosurf.ThenFunc(csrfFail))
// return csrf
// }).Extend(aHtmlAfterNosurf)
// // requests to aHtml hitting nosurfs success handler go m1 -> nosurf -> m2 -> target-handler
// // requests to aHtml hitting nosurfs failure handler go m1 -> nosurf -> m2 -> csrfFail
func (c Chain) Extend(chain Chain) Chain {
return c.Append(chain.constructors...)
}

View file

@ -1,177 +0,0 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
)
// A constructor for middleware
// that writes its own "tag" into the RW and does nothing else.
// Useful in checking if a chain is behaving in the right order.
func tagMiddleware(tag string) Constructor {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(tag))
h.ServeHTTP(w, r)
})
}
}
// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer),
// but the best we can do.
func funcsEqual(f1, f2 interface{}) bool {
val1 := reflect.ValueOf(f1)
val2 := reflect.ValueOf(f2)
return val1.Pointer() == val2.Pointer()
}
var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("app\n"))
})
func TestNew(t *testing.T) {
c1 := func(h http.Handler) http.Handler {
return nil
}
c2 := func(h http.Handler) http.Handler {
return http.StripPrefix("potato", nil)
}
slice := []Constructor{c1, c2}
chain := NewChain(slice...)
for k := range slice {
if !funcsEqual(chain.constructors[k], slice[k]) {
t.Error("New does not add constructors correctly")
}
}
}
func TestThenWorksWithNoMiddleware(t *testing.T) {
if !funcsEqual(NewChain().Then(testApp), testApp) {
t.Error("Then does not work with no middleware")
}
}
func TestThenTreatsNilAsDefaultServeMux(t *testing.T) {
if NewChain().Then(nil) != http.DefaultServeMux {
t.Error("Then does not treat nil as DefaultServeMux")
}
}
func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) {
if NewChain().ThenFunc(nil) != http.DefaultServeMux {
t.Error("ThenFunc does not treat nil as DefaultServeMux")
}
}
func TestThenFuncConstructsHandlerFunc(t *testing.T) {
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
chained := NewChain().ThenFunc(fn)
rec := httptest.NewRecorder()
chained.ServeHTTP(rec, (*http.Request)(nil))
if reflect.TypeOf(chained) != reflect.TypeOf(http.HandlerFunc(nil)) {
t.Error("ThenFunc does not construct HandlerFunc")
}
}
func TestThenOrdersHandlersCorrectly(t *testing.T) {
t1 := tagMiddleware("t1\n")
t2 := tagMiddleware("t2\n")
t3 := tagMiddleware("t3\n")
chained := NewChain(t1, t2, t3).Then(testApp)
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
chained.ServeHTTP(w, r)
if w.Body.String() != "t1\nt2\nt3\napp\n" {
t.Error("Then does not order handlers correctly")
}
}
func TestAppendAddsHandlersCorrectly(t *testing.T) {
chain := NewChain(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
newChain := chain.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
if len(chain.constructors) != 2 {
t.Error("chain should have 2 constructors")
}
if len(newChain.constructors) != 4 {
t.Error("newChain should have 4 constructors")
}
chained := newChain.Then(testApp)
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
chained.ServeHTTP(w, r)
if w.Body.String() != "t1\nt2\nt3\nt4\napp\n" {
t.Error("Append does not add handlers correctly")
}
}
func TestAppendRespectsImmutability(t *testing.T) {
chain := NewChain(tagMiddleware(""))
newChain := chain.Append(tagMiddleware(""))
if &chain.constructors[0] == &newChain.constructors[0] {
t.Error("Apppend does not respect immutability")
}
}
func TestExtendAddsHandlersCorrectly(t *testing.T) {
chain1 := NewChain(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
chain2 := NewChain(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
newChain := chain1.Extend(chain2)
if len(chain1.constructors) != 2 {
t.Error("chain1 should contain 2 constructors")
}
if len(chain2.constructors) != 2 {
t.Error("chain2 should contain 2 constructors")
}
if len(newChain.constructors) != 4 {
t.Error("newChain should contain 4 constructors")
}
chained := newChain.Then(testApp)
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
chained.ServeHTTP(w, r)
if w.Body.String() != "t1\nt2\nt3\nt4\napp\n" {
t.Error("Extend does not add handlers in correctly")
}
}
func TestExtendRespectsImmutability(t *testing.T) {
chain := NewChain(tagMiddleware(""))
newChain := chain.Extend(NewChain(tagMiddleware("")))
if &chain.constructors[0] == &newChain.constructors[0] {
t.Error("Extend does not respect immutability")
}
}

View file

@ -1,2 +1,2 @@
// Package middleware provides a standard set of middleware implementations for pomerium. // Package middleware provides a standard set of middleware for pomerium.
package middleware // import "github.com/pomerium/pomerium/internal/middleware" package middleware // import "github.com/pomerium/pomerium/internal/middleware"

View file

@ -79,12 +79,10 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
func (cs *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { func (cs *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
domain := req.Host domain := req.Host
if name == cs.csrfName() { if cs.CookieDomain != "" {
domain = req.Host
} else if cs.CookieDomain != "" {
domain = cs.CookieDomain domain = cs.CookieDomain
} else { } else {
domain = splitDomain(domain) domain = ParentSubdomain(domain)
} }
if h, _, err := net.SplitHostPort(domain); err == nil { if h, _, err := net.SplitHostPort(domain); err == nil {
@ -105,19 +103,11 @@ func (cs *CookieStore) makeCookie(req *http.Request, name string, value string,
return c return c
} }
func (cs *CookieStore) csrfName() string {
return fmt.Sprintf("%s_csrf", cs.Name)
}
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time. // makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return cs.makeCookie(req, cs.Name, value, expiration, now) return cs.makeCookie(req, cs.Name, value, expiration, now)
} }
func (cs *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return cs.makeCookie(req, cs.csrfName(), value, expiration, now)
}
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) { func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
if len(cookie.String()) <= MaxChunkSize { if len(cookie.String()) <= MaxChunkSize {
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
@ -134,7 +124,6 @@ func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
nc.Name = fmt.Sprintf("%s_%d", cookie.Name, i) nc.Name = fmt.Sprintf("%s_%d", cookie.Name, i)
nc.Value = c nc.Value = c
} }
fmt.Println(i)
http.SetCookie(w, &nc) http.SetCookie(w, &nc)
} }
} }
@ -150,25 +139,6 @@ func chunk(s string, size int) []string {
return ss return ss
} }
// ClearCSRF clears the CSRF cookie from the request
func (cs *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, cs.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
}
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
func (cs *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(w, cs.makeCSRFCookie(req, val, cs.CookieExpire, time.Now()))
}
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
func (cs *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
c, err := req.Cookie(cs.csrfName())
if err != nil {
return nil, ErrEmptyCSRF // ErrNoCookie is confusing in this context
}
return c, nil
}
// ClearSession clears the session cookie from a request // ClearSession clears the session cookie from a request
func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) { func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now())) http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
@ -235,7 +205,8 @@ func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *
return nil return nil
} }
func splitDomain(s string) string { // ParentSubdomain returns the parent subdomain.
func ParentSubdomain(s string) string {
if strings.Count(s, ".") < 2 { if strings.Count(s, ".") < 2 {
return "" return ""
} }

View file

@ -1,4 +1,4 @@
package sessions package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import ( import (
"crypto/rand" "crypto/rand"
@ -38,7 +38,7 @@ func (a mockCipher) Unmarshal(s string, i interface{}) error {
return nil return nil
} }
func TestNewCookieStore(t *testing.T) { func TestNewCookieStore(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey()) cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -111,7 +111,7 @@ func TestNewCookieStore(t *testing.T) {
} }
func TestCookieStore_makeCookie(t *testing.T) { func TestCookieStore_makeCookie(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey()) cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -155,62 +155,13 @@ func TestCookieStore_makeCookie(t *testing.T) {
if diff := cmp.Diff(s.makeSessionCookie(r, tt.value, tt.expiration, now), tt.want); diff != "" { if diff := cmp.Diff(s.makeSessionCookie(r, tt.value, tt.expiration, now), tt.want); diff != "" {
t.Errorf("CookieStore.makeSessionCookie() = \n%s", diff) t.Errorf("CookieStore.makeSessionCookie() = \n%s", diff)
} }
got := s.makeCSRFCookie(r, tt.value, tt.expiration, now)
tt.wantCSRF.Name = "_pomerium_csrf"
if !reflect.DeepEqual(got, tt.wantCSRF) {
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.wantCSRF)
}
w := httptest.NewRecorder()
want := "new-csrf"
s.SetCSRF(w, r, want)
found := false
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name+"_csrf" && cookie.Value == want {
found = true
break
}
}
if !found {
t.Error("SetCSRF failed")
}
w = httptest.NewRecorder()
s.ClearCSRF(w, r)
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name+"_csrf" && cookie.Value == want {
t.Error("clear csrf failed")
break
}
}
w = httptest.NewRecorder()
want = "new-session"
s.setSessionCookie(w, r, want)
found = false
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name && cookie.Value == want {
found = true
break
}
}
if !found {
t.Error("SetCSRF failed")
}
w = httptest.NewRecorder()
s.ClearSession(w, r)
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name && cookie.Value == want {
t.Error("clear csrf failed")
break
}
}
}) })
} }
} }
func TestCookieStore_SaveSession(t *testing.T) { func TestCookieStore_SaveSession(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey()) cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -265,38 +216,6 @@ func TestCookieStore_SaveSession(t *testing.T) {
} }
} }
func TestMockCSRFStore(t *testing.T) {
tests := []struct {
name string
mockCSRF *MockCSRFStore
newCSRFValue string
wantErr bool
}{
{"basic",
&MockCSRFStore{
ResponseCSRF: "ok",
Cookie: &http.Cookie{Name: "hi"}},
"newcsrf",
false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ms := tt.mockCSRF
ms.SetCSRF(nil, nil, tt.newCSRFValue)
ms.ClearCSRF(nil, nil)
got, err := ms.GetCSRF(nil)
if (err != nil) != tt.wantErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.mockCSRF.Cookie) {
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Cookie)
}
})
}
}
func TestMockSessionStore(t *testing.T) { func TestMockSessionStore(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -341,7 +260,7 @@ func TestMockSessionStore(t *testing.T) {
} }
} }
func Test_splitDomain(t *testing.T) { func Test_ParentSubdomain(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
s string s string
@ -354,8 +273,8 @@ func Test_splitDomain(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.s, func(t *testing.T) { t.Run(tt.s, func(t *testing.T) {
if got := splitDomain(tt.s); got != tt.want { if got := ParentSubdomain(tt.s); got != tt.want {
t.Errorf("splitDomain() = %v, want %v", got, tt.want) t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want)
} }
}) })
} }

View file

@ -0,0 +1,130 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"context"
"errors"
"net/http"
"strings"
)
// Context keys
var (
SessionCtxKey = &contextKey{"Session"}
ErrorCtxKey = &contextKey{"Error"}
)
// Library errors
var (
ErrExpired = errors.New("internal/sessions: session is expired")
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
ErrMalformed = errors.New("internal/sessions: session is malformed")
)
// RetrieveSession http middleware handler will verify a auth session from a http request.
//
// RetrieveSession will search for a auth session in a http request, in the order:
// 1. `pomerium_session` URI query parameter
// 2. `Authorization: BEARER` request header
// 3. Cookie `_pomerium` value
func RetrieveSession(s SessionStore) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return retrieve(s, TokenFromQuery, TokenFromHeader, TokenFromCookie)(next)
}
}
func retrieve(s SessionStore, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
hfn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, err := retrieveFromRequest(s, r, findTokenFns...)
ctx = NewContext(ctx, token, err)
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(hfn)
}
}
func retrieveFromRequest(s SessionStore, r *http.Request, findTokenFns ...func(r *http.Request) string) (*State, error) {
var tokenStr string
var err error
// Extract token string from the request by calling token find functions in
// the order they where provided. Further extraction stops if a function
// returns a non-empty string.
for _, fn := range findTokenFns {
tokenStr = fn(r)
if tokenStr != "" {
break
}
}
if tokenStr == "" {
return nil, ErrNoSessionFound
}
state, err := s.LoadSession(r)
if err != nil {
return nil, ErrMalformed
}
err = state.Valid()
if err != nil {
// a little unusual but we want to return the expired state too
return state, err
}
// Valid!
return state, nil
}
// NewContext sets context values for the user session state and error.
func NewContext(ctx context.Context, t *State, err error) context.Context {
ctx = context.WithValue(ctx, SessionCtxKey, t)
ctx = context.WithValue(ctx, ErrorCtxKey, err)
return ctx
}
// FromContext retrieves context values for the user session state and error.
func FromContext(ctx context.Context) (*State, error) {
state, _ := ctx.Value(SessionCtxKey).(*State)
err, _ := ctx.Value(ErrorCtxKey).(error)
return state, err
}
// TokenFromCookie tries to retrieve the token string from a cookie named
// "_pomerium".
func TokenFromCookie(r *http.Request) string {
cookie, err := r.Cookie("_pomerium")
if err != nil {
return ""
}
return cookie.Value
}
// TokenFromHeader tries to retrieve the token string from the
// "Authorization" request header: "Authorization: BEARER T".
func TokenFromHeader(r *http.Request) string {
// Get token from authorization header.
bearer := r.Header.Get("Authorization")
if len(bearer) > 7 && strings.EqualFold(bearer[0:6], "BEARER") {
return bearer[7:]
}
return ""
}
// TokenFromQuery tries to retrieve the token string from the "pomerium_session" URI
// query parameter.
// todo(bdd) : document setting session code as queryparam
func TokenFromQuery(r *http.Request) string {
// Get token from query param named "pomerium_session".
return r.URL.Query().Get("pomerium_session")
}
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation. This technique
// for defining context keys was copied from Go 1.7's new use of context in net/http.
type contextKey struct {
name string
}
func (k *contextKey) String() string {
return "SessionStore context value " + k.name
}

View file

@ -0,0 +1,133 @@
package sessions
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/google/go-cmp/cmp"
)
func TestNewContext(t *testing.T) {
tests := []struct {
name string
ctx context.Context
t *State
err error
want context.Context
}{
{"simple", context.Background(), &State{Email: "bdd@pomerium.io"}, nil, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctxOut := NewContext(tt.ctx, tt.t, tt.err)
stateOut, errOut := FromContext(ctxOut)
if diff := cmp.Diff(tt.t, stateOut); diff != "" {
t.Errorf("NewContext() = %s", diff)
}
if diff := cmp.Diff(tt.err, errOut); diff != "" {
t.Errorf("NewContext() = %s", diff)
}
})
}
}
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
// s SessionStore
state State
cookie bool
header bool
param bool
wantBody string
wantStatus int
}{
{"good cookie session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, true, false, false, http.StatusText(http.StatusOK), http.StatusOK},
{"expired cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth header session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, false, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth query param session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"no session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewCipherFromBase64(cryptutil.NewBase64Key())
if err != nil {
t.Fatal(err)
}
encSession, err := MarshalSession(&tt.state, cipher)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tt.name, "malformed") {
// add some garbage to the end of the string
encSession += cryptutil.NewBase64Key()
fmt.Println(encSession)
}
cs, err := NewCookieStore(&CookieStoreOptions{
Name: "_pomerium",
CookieCipher: cipher,
})
if err != nil {
t.Fatal(err)
}
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
if tt.cookie {
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: encSession})
} else if tt.header {
r.Header.Set("Authorization", "Bearer "+encSession)
} else if tt.param {
q := r.URL.Query()
q.Add("pomerium_session", encSession)
r.URL.RawQuery = q.Encode()
}
got := RetrieveSession(cs)(testAuthorizer((fnh)))
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
})
}
}

View file

@ -4,28 +4,6 @@ import (
"net/http" "net/http"
) )
// MockCSRFStore is a mock implementation of the CSRF store interface
type MockCSRFStore struct {
ResponseCSRF string
Cookie *http.Cookie
GetError error
}
// SetCSRF sets the ResponseCSRF string to a val
func (ms MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
ms.ResponseCSRF = val
}
// ClearCSRF clears the ResponseCSRF string
func (ms MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
ms.ResponseCSRF = ""
}
// GetCSRF returns the cookie and error
func (ms MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
return ms.Cookie, ms.GetError
}
// MockSessionStore is a mock implementation of the SessionStore interface // MockSessionStore is a mock implementation of the SessionStore interface
type MockSessionStore struct { type MockSessionStore struct {
ResponseSession string ResponseSession string

View file

@ -10,9 +10,6 @@ import (
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
) )
// ErrExpired is an error for a expired sessions.
var ErrExpired = fmt.Errorf("internal/sessions: expired session")
// State is our object that keeps track of a user's session state // State is our object that keeps track of a user's session state
type State struct { type State struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`

View file

@ -12,7 +12,7 @@ import (
) )
func TestStateSerialization(t *testing.T) { func TestStateSerialization(t *testing.T) {
secret := cryptutil.GenerateKey() secret := cryptutil.NewKey()
c, err := cryptutil.NewCipher(secret) c, err := cryptutil.NewCipher(secret)
if err != nil { if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err) t.Fatalf("expected to be able to create cipher: %v", err)
@ -123,7 +123,7 @@ func TestState_Impersonating(t *testing.T) {
} }
func TestMarshalSession(t *testing.T) { func TestMarshalSession(t *testing.T) {
secret := cryptutil.GenerateKey() secret := cryptutil.NewKey()
c, err := cryptutil.NewCipher(secret) c, err := cryptutil.NewCipher(secret)
if err != nil { if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err) t.Fatalf("expected to be able to create cipher: %v", err)

View file

@ -8,16 +8,6 @@ import (
// ErrEmptySession is an error for an empty sessions. // ErrEmptySession is an error for an empty sessions.
var ErrEmptySession = errors.New("internal/sessions: empty session") var ErrEmptySession = errors.New("internal/sessions: empty session")
// ErrEmptyCSRF is an error for an empty sessions.
var ErrEmptyCSRF = errors.New("internal/sessions: empty csrf")
// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie
type CSRFStore interface {
SetCSRF(http.ResponseWriter, *http.Request, string)
GetCSRF(*http.Request) (*http.Cookie, error)
ClearCSRF(http.ResponseWriter, *http.Request)
}
// SessionStore has the functions for setting, getting, and clearing the Session cookie // SessionStore has the functions for setting, getting, and clearing the Session cookie
type SessionStore interface { type SessionStore interface {
ClearSession(http.ResponseWriter, *http.Request) ClearSession(http.ResponseWriter, *http.Request)

View file

@ -306,7 +306,7 @@ func New() *template.Template {
</svg> </svg>
<form method="POST" action="{{.SignoutURL}}"> <form method="POST" action="{{.SignoutURL}}">
<section> <section>
<h2>Session</h2> <h2>Current user</h2>
<p class="message">Your current session details.</p> <p class="message">Your current session details.</p>
<fieldset> <fieldset>
<label> <label>
@ -334,10 +334,22 @@ func New() *template.Template {
</fieldset> </fieldset>
</section> </section>
<div class="flex"> <div class="flex">
<button class="button half" type="submit">Sign Out</button> {{ .csrfField }}
<a href="/.pomerium/refresh" class="button half">Refresh</a> <button class="button full" type="submit">Sign Out</button>
</div> </div>
</form> </form>
<section>
<h2>Refresh Identity</h2>
<p class="message">Pomerium will automatically refresh your user session. However, if your group memberships have recently changed and haven't taken effect yet, you can refresh your session manually.</p>
<form method="POST" action="/.pomerium/refresh">
<div class="flex">
{{ .csrfField }}
<button class="button full" type="submit">Refresh</button>
</div>
</form>
</section>
{{if .IsAdmin}} {{if .IsAdmin}}
<form method="POST" action="/.pomerium/impersonate"> <form method="POST" action="/.pomerium/impersonate">
<section> <section>
@ -355,7 +367,7 @@ func New() *template.Template {
</fieldset> </fieldset>
</section> </section>
<div class="flex"> <div class="flex">
<input name="csrf" type="hidden" value="{{.CSRF}}"> {{ .csrfField }}
<button class="button full" type="submit">Impersonate session</button> <button class="button full" type="submit">Impersonate session</button>
</div> </div>
</form> </form>

View file

@ -1,9 +1,14 @@
package urlutil // import "github.com/pomerium/pomerium/internal/urlutil" package urlutil // import "github.com/pomerium/pomerium/internal/urlutil"
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"strings" "strings"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
) )
// StripPort returns a host, without any port number. // StripPort returns a host, without any port number.
@ -32,18 +37,73 @@ func ParseAndValidateURL(rawurl string) (*url.URL, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if u.Scheme == "" { if err := ValidateURL(u); err != nil {
return nil, fmt.Errorf("%s url does contain a valid scheme. Did you mean https://%s?", rawurl, rawurl) return nil, err
}
if u.Host == "" {
return nil, fmt.Errorf("%s url does contain a valid hostname", rawurl)
} }
return u, nil return u, nil
} }
// ValidateURL wraps standard library's default url.Parse because
// it's much more lenient about what type of urls it accepts than pomerium.
func ValidateURL(u *url.URL) error {
if u == nil {
return fmt.Errorf("nil url")
}
if u.Scheme == "" {
return fmt.Errorf("%s url does contain a valid scheme. Did you mean https://%s?", u.String(), u.String())
}
if u.Host == "" {
return fmt.Errorf("%s url does contain a valid hostname", u.String())
}
return nil
}
func DeepCopy(u *url.URL) (*url.URL, error) { func DeepCopy(u *url.URL) (*url.URL, error) {
if u == nil { if u == nil {
return nil, nil return nil, nil
} }
return ParseAndValidateURL(u.String()) return ParseAndValidateURL(u.String())
} }
// testTimeNow can be used in tests to set a specific int64 time
var testTimeNow int64
// timestamp returns the current timestamp, in seconds.
//
// For testing purposes, the function that generates the timestamp can be
// overridden. If not set, it will return time.Now().UTC().Unix().
func timestamp() int64 {
if testTimeNow == 0 {
return time.Now().UTC().Unix()
}
return testTimeNow
}
// SignedRedirectURL takes a destination URL and adds redirect_uri to it's
// query params, along with a timestamp and an keyed signature.
func SignedRedirectURL(key string, destination, urlToSign *url.URL) *url.URL {
now := timestamp()
rawURL := urlToSign.String()
params, _ := url.ParseQuery(destination.RawQuery) // handled by incoming mux
params.Set("redirect_uri", rawURL)
params.Set("ts", fmt.Sprint(now))
params.Set("sig", hmacURL(key, rawURL, now))
destination.RawQuery = params.Encode()
return destination
}
// hmacURL takes a redirect url string and timestamp and returns the base64
// encoded HMAC result.
func hmacURL(key, data string, timestamp int64) string {
h := cryptutil.Hash(key, []byte(fmt.Sprint(data, timestamp)))
return base64.URLEncoding.EncodeToString(h)
}
// GetAbsoluteURL returns the current handler's absolute url.
// https://stackoverflow.com/a/23152483
func GetAbsoluteURL(r *http.Request) *url.URL {
u := r.URL
u.Scheme = "https"
u.Host = r.Host
return u
}

View file

@ -1,6 +1,7 @@
package urlutil package urlutil
import ( import (
"net/http"
"net/url" "net/url"
"reflect" "reflect"
"testing" "testing"
@ -35,7 +36,7 @@ func Test_StripPort(t *testing.T) {
} }
func TestParseAndValidateURL(t *testing.T) { func TestParseAndValidateURL(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
rawurl string rawurl string
@ -63,7 +64,7 @@ func TestParseAndValidateURL(t *testing.T) {
} }
func TestDeepCopy(t *testing.T) { func TestDeepCopy(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
u *url.URL u *url.URL
@ -87,3 +88,90 @@ func TestDeepCopy(t *testing.T) {
}) })
} }
} }
func TestValidateURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
u *url.URL
wantErr bool
}{
{"good", &url.URL{Scheme: "https", Host: "some.example"}, false},
{"nil", nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := ValidateURL(tt.u); (err != nil) != tt.wantErr {
t.Errorf("ValidateURL() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestSignedRedirectURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
mockedTime int64
key string
destination *url.URL
urlToSign *url.URL
want *url.URL
}{
{"good", 2, "hunter42", &url.URL{Host: "pomerium.io", Scheme: "https://"}, &url.URL{Host: "pomerium.io", Scheme: "https://", Path: "/ok"}, &url.URL{Host: "pomerium.io", Scheme: "https://", RawQuery: "redirect_uri=https%3A%2F%2F%3A%2F%2Fpomerium.io%2Fok&sig=7jdo1XFcmuhjBHnpfVhll5cXflYByeMnbp5kRz87CVQ%3D&ts=2"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testTimeNow = tt.mockedTime
got := SignedRedirectURL(tt.key, tt.destination, tt.urlToSign)
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("SignedRedirectURL() = diff %v", diff)
}
})
}
}
func Test_timestamp(t *testing.T) {
t.Parallel()
tests := []struct {
name string
dontWant int64
}{
{"if unset should never return", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testTimeNow = tt.dontWant
if got := timestamp(); got == tt.dontWant {
t.Errorf("timestamp() = %v, dontWant %v", got, tt.dontWant)
}
})
}
}
func parseURLHelper(s string) *url.URL {
u, _ := url.Parse(s)
return u
}
func TestGetAbsoluteURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
u *url.URL
want *url.URL
}{
{"add https", parseURLHelper("http://pomerium.io"), parseURLHelper("https://pomerium.io")},
{"missing scheme", parseURLHelper("https://pomerium.io"), parseURLHelper("https://pomerium.io")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := http.Request{URL: tt.u, Host: tt.u.Host}
got := GetAbsoluteURL(&r)
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("GetAbsoluteURL() = %v", diff)
}
})
}
}

View file

@ -1,15 +1,15 @@
package proxy // import "github.com/pomerium/pomerium/proxy" package proxy // import "github.com/pomerium/pomerium/proxy"
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time" "time"
"github.com/pomerium/csrf"
"github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
@ -18,34 +18,55 @@ import (
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
) )
// StateParameter holds the redirect id along with the session id.
type StateParameter struct {
SessionID string `json:"session_id"`
RedirectURI string `json:"redirect_uri"`
}
// Handler returns the proxy service's ServeMux // Handler returns the proxy service's ServeMux
func (p *Proxy) Handler() http.Handler { func (p *Proxy) Handler() http.Handler {
// validation middleware chain r := httputil.NewRouter().StrictSlash(true)
validate := middleware.NewChain() r.Use(middleware.ValidateHost(func(host string) bool {
validate = validate.Append(middleware.ValidateHost(func(host string) bool {
_, ok := p.routeConfigs[host] _, ok := p.routeConfigs[host]
return ok return ok
})) }))
mux := http.NewServeMux() r.Use(csrf.Protect(
mux.HandleFunc("/robots.txt", p.RobotsTxt) p.cookieSecret,
mux.HandleFunc("/.pomerium", p.UserDashboard) csrf.Path("/"),
mux.HandleFunc("/.pomerium/impersonate", p.Impersonate) // POST csrf.Domain(p.cookieDomain),
mux.HandleFunc("/.pomerium/sign_out", p.SignOut) csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieName)),
// handlers with validation csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
mux.Handle("/.pomerium/callback", validate.ThenFunc(p.AuthenticateCallback)) ))
mux.Handle("/.pomerium/refresh", validate.ThenFunc(p.ForceRefresh)) r.HandleFunc("/robots.txt", p.RobotsTxt)
mux.Handle("/", validate.ThenFunc(p.Proxy)) // requires authN not authZ
return mux r.Use(sessions.RetrieveSession(p.sessionStore))
r.Use(p.VerifySession)
r.HandleFunc("/.pomerium/", p.UserDashboard).Methods(http.MethodGet)
r.HandleFunc("/.pomerium/impersonate", p.Impersonate).Methods(http.MethodPost)
r.HandleFunc("/.pomerium/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost)
r.HandleFunc("/.pomerium/refresh", p.ForceRefresh).Methods(http.MethodPost)
r.PathPrefix("/").HandlerFunc(p.Proxy)
return r
}
// VerifySession is the middleware used to enforce a valid authentication
// session state is attached to the users's request context.
func (p *Proxy) VerifySession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
state, err := sessions.FromContext(r.Context())
if err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to session state error")
p.authenticate(w, r)
return
}
if err := state.Valid(); err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to invalid session")
p.authenticate(w, r)
return
}
next.ServeHTTP(w, r)
})
} }
// RobotsTxt sets the User-Agent header in the response to be "Disallow" // RobotsTxt sets the User-Agent header in the response to be "Disallow"
func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) { func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "User-agent: *\nDisallow: /") fmt.Fprintf(w, "User-agent: *\nDisallow: /")
} }
@ -55,110 +76,18 @@ func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
// the local session state. // the local session state.
func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) { func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"} redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
switch r.Method { if uri, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")); err == nil && uri.String() != "" {
case http.MethodPost: redirectURL = uri
if err := r.ParseForm(); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
uri, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri"))
if err == nil && uri.String() != "" {
redirectURL = uri
}
default:
uri, err := urlutil.ParseAndValidateURL(r.URL.Query().Get("redirect_uri"))
if err == nil && uri.String() != "" {
redirectURL = uri
}
} }
http.Redirect(w, r, p.GetSignOutURL(p.authenticateURL, redirectURL).String(), http.StatusFound) uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL)
http.Redirect(w, r, uri.String(), http.StatusFound)
} }
// OAuthStart begins the authenticate flow, encrypting the redirect url // Authenticate begins the authenticate flow, encrypting the redirect url
// in a request to the provider's sign in endpoint. // in a request to the provider's sign in endpoint.
func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) { func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request) {
state := &StateParameter{ uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
SessionID: fmt.Sprintf("%x", cryptutil.GenerateKey()), http.Redirect(w, r, uri.String(), http.StatusFound)
RedirectURI: r.URL.String(),
}
// Encrypt CSRF + redirect_uri and store in csrf session. Validated on callback.
csrfState, err := p.cipher.Marshal(state)
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
p.csrfStore.SetCSRF(w, r, csrfState)
paramState, err := p.cipher.Marshal(state)
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
// Sanity check. The encrypted payload of local and remote state should
// never match as each encryption round uses a cryptographic nonce.
// if paramState == csrfState {
// httputil.ErrorResponse(w, r, httputil.Error("encrypted state should not match", http.StatusBadRequest, nil))
// return
// }
signinURL := p.GetSignInURL(p.authenticateURL, p.GetRedirectURL(r.Host), paramState)
// Redirect the user to the authenticate service along with the encrypted
// state which contains a redirect uri back to the proxy and a nonce
http.Redirect(w, r, signinURL.String(), http.StatusFound)
}
// AuthenticateCallback checks the state parameter to make sure it matches the
// local csrf state then redirects the user back to the original intended route.
func (p *Proxy) AuthenticateCallback(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
// Encrypted CSRF passed from authenticate service
remoteStateEncrypted := r.Form.Get("state")
var remoteStatePlain StateParameter
if err := p.cipher.Unmarshal(remoteStateEncrypted, &remoteStatePlain); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
c, err := p.csrfStore.GetCSRF(r)
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
p.csrfStore.ClearCSRF(w, r)
localStateEncrypted := c.Value
var localStatePlain StateParameter
err = p.cipher.Unmarshal(localStateEncrypted, &localStatePlain)
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
// assert no nonce reuse
if remoteStateEncrypted == localStateEncrypted {
p.sessionStore.ClearSession(w, r)
httputil.ErrorResponse(w, r,
httputil.Error("local and remote state", http.StatusBadRequest,
fmt.Errorf("possible nonce-reuse / replay attack")))
return
}
// Decrypted remote and local state struct (inc. nonce) must match
if remoteStatePlain.SessionID != localStatePlain.SessionID {
p.sessionStore.ClearSession(w, r)
httputil.ErrorResponse(w, r, httputil.Error("CSRF mismatch", http.StatusBadRequest, nil))
return
}
// This is the redirect back to the original requested application
http.Redirect(w, r, remoteStatePlain.RedirectURI, http.StatusFound)
} }
// shouldSkipAuthentication contains conditions for skipping authentication. // shouldSkipAuthentication contains conditions for skipping authentication.
@ -189,17 +118,6 @@ func isCORSPreflight(r *http.Request) bool {
r.Header.Get("Origin") != "" r.Header.Get("Origin") != ""
} }
func (p *Proxy) loadExistingSession(r *http.Request) (*sessions.State, error) {
s, err := p.sessionStore.LoadSession(r)
if err != nil {
return nil, fmt.Errorf("proxy: invalid session: %w", err)
}
if err := s.Valid(); err != nil {
return nil, fmt.Errorf("proxy: invalid state: %w", err)
}
return s, nil
}
// Proxy authenticates a request, either proxying the request if it is authenticated, // Proxy authenticates a request, either proxying the request if it is authenticated,
// or starting the authenticate service for validation if not. // or starting the authenticate service for validation if not.
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
@ -214,11 +132,10 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
route.ServeHTTP(w, r) route.ServeHTTP(w, r)
return return
} }
s, err := sessions.FromContext(r.Context())
s, err := p.loadExistingSession(r) if err != nil || s == nil {
if err != nil { log.Debug().Err(err).Msg("proxy: couldn't get session from context")
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting") p.authenticate(w, r)
p.OAuthStart(w, r)
return return
} }
authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s) authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s)
@ -226,7 +143,7 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
} else if !authorized { } else if !authorized {
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not authorized for this route", s.Email), http.StatusForbidden, nil)) httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not authorized for this route", s.RequestEmail()), http.StatusForbidden, nil))
return return
} }
r.Header.Set(HeaderUserID, s.User) r.Header.Set(HeaderUserID, s.User)
@ -240,62 +157,41 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
// It also contains certain administrative actions like user impersonation. // It also contains certain administrative actions like user impersonation.
// Nota bene: This endpoint does authentication, not authorization. // Nota bene: This endpoint does authentication, not authorization.
func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) { func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
session, err := p.loadExistingSession(r) session, err := sessions.FromContext(r.Context())
if err != nil { if err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting") httputil.ErrorResponse(w, r, err)
p.OAuthStart(w, r)
return return
} }
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/.pomerium/sign_out"}
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session) isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
} }
//todo(bdd): make sign out redirect a configuration option so that
// CSRF value used to mitigate replay attacks. // admins can set to whatever their corporate homepage is
csrf := &StateParameter{SessionID: fmt.Sprintf("%x", cryptutil.GenerateKey())} redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
csrfCookie, err := p.cipher.Marshal(csrf) signoutURL := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL)
if err != nil { templates.New().ExecuteTemplate(w, "dashboard.html", map[string]interface{}{
httputil.ErrorResponse(w, r, err) "Email": session.Email,
return "User": session.User,
} "Groups": session.Groups,
p.csrfStore.SetCSRF(w, r, csrfCookie) "RefreshDeadline": time.Until(session.RefreshDeadline).Round(time.Second).String(),
"SignoutURL": signoutURL.String(),
t := struct { "IsAdmin": isAdmin,
Email string "ImpersonateEmail": session.ImpersonateEmail,
User string "ImpersonateGroup": strings.Join(session.ImpersonateGroups, ","),
Groups []string "csrfField": csrf.TemplateField(r),
RefreshDeadline string })
SignoutURL string
IsAdmin bool
ImpersonateEmail string
ImpersonateGroup string
CSRF string
}{
Email: session.Email,
User: session.User,
Groups: session.Groups,
RefreshDeadline: time.Until(session.RefreshDeadline).Round(time.Second).String(),
SignoutURL: p.GetSignOutURL(p.authenticateURL, redirectURL).String(),
IsAdmin: isAdmin,
ImpersonateEmail: session.ImpersonateEmail,
ImpersonateGroup: strings.Join(session.ImpersonateGroups, ","),
CSRF: csrf.SessionID,
}
templates.New().ExecuteTemplate(w, "dashboard.html", t)
} }
// ForceRefresh redeems and extends an existing authenticated oidc session with // ForceRefresh redeems and extends an existing authenticated oidc session with
// the underlying identity provider. All session details including groups, // the underlying identity provider. All session details including groups,
// timeouts, will be renewed. // timeouts, will be renewed.
func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) { func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) {
session, err := p.loadExistingSession(r) session, err := sessions.FromContext(r.Context())
if err != nil { if err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting") httputil.ErrorResponse(w, r, err)
p.OAuthStart(w, r)
return return
} }
iss, err := session.IssuedAt() iss, err := session.IssuedAt()
@ -324,49 +220,25 @@ func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) {
// to the user's current user sessions state if the user is currently an // to the user's current user sessions state if the user is currently an
// administrative user. Requests are redirected back to the user dashboard. // administrative user. Requests are redirected back to the user dashboard.
func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) { func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost { session, err := sessions.FromContext(r.Context())
if err := r.ParseForm(); err != nil { if err != nil {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
}
session, err := p.loadExistingSession(r)
if err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
p.OAuthStart(w, r)
return
}
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
if err != nil || !isAdmin {
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not an administrator", session.Email), http.StatusForbidden, err))
return
}
// CSRF check -- did this request originate from our form?
c, err := p.csrfStore.GetCSRF(r)
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
p.csrfStore.ClearCSRF(w, r)
encryptedCSRF := c.Value
var decryptedCSRF StateParameter
if err = p.cipher.Unmarshal(encryptedCSRF, decryptedCSRF); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
if decryptedCSRF.SessionID != r.FormValue("csrf") {
httputil.ErrorResponse(w, r, httputil.Error("CSRF mismatch", http.StatusBadRequest, nil))
return
}
// OK to impersonation
session.ImpersonateEmail = r.FormValue("email")
session.ImpersonateGroups = strings.Split(r.FormValue("group"), ",")
if err := p.sessionStore.SaveSession(w, r, session); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
} }
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
if err != nil || !isAdmin {
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not an administrator", session.RequestEmail()), http.StatusForbidden, err))
return
}
// OK to impersonation
session.ImpersonateEmail = r.FormValue("email")
session.ImpersonateGroups = strings.Split(r.FormValue("group"), ",")
if err := p.sessionStore.SaveSession(w, r, session); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
http.Redirect(w, r, "/.pomerium", http.StatusFound) http.Redirect(w, r, "/.pomerium", http.StatusFound)
} }
@ -391,48 +263,3 @@ func (p *Proxy) policy(r *http.Request) (*config.Policy, bool) {
} }
return nil, false return nil, false
} }
// GetRedirectURL returns the redirect url for a single reverse proxy host. HTTPS is set explicitly.
func (p *Proxy) GetRedirectURL(host string) *url.URL {
u := p.redirectURL
u.Scheme = "https"
u.Host = host
return u
}
// signRedirectURL takes a redirect url string and timestamp and returns the base64
// encoded HMAC result.
func (p *Proxy) signRedirectURL(rawRedirect string, timestamp time.Time) string {
data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix()))
h := cryptutil.Hash(p.SharedKey, data)
return base64.URLEncoding.EncodeToString(h)
}
// GetSignInURL with typical oauth parameters
func (p *Proxy) GetSignInURL(authenticateURL, redirectURL *url.URL, state string) *url.URL {
a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_in"})
now := time.Now()
rawRedirect := redirectURL.String()
params, _ := url.ParseQuery(a.RawQuery) // handled by ServeMux
params.Set("redirect_uri", rawRedirect)
params.Set("shared_secret", p.SharedKey)
params.Set("response_type", "code")
params.Add("state", state)
params.Set("ts", fmt.Sprint(now.Unix()))
params.Set("sig", p.signRedirectURL(rawRedirect, now))
a.RawQuery = params.Encode()
return a
}
// GetSignOutURL creates and returns the sign out URL, given a redirectURL
func (p *Proxy) GetSignOutURL(authenticateURL, redirectURL *url.URL) *url.URL {
a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_out"})
now := time.Now()
rawRedirect := redirectURL.String()
params, _ := url.ParseQuery(a.RawQuery) // handled by ServeMux
params.Add("redirect_uri", rawRedirect)
params.Set("ts", fmt.Sprint(now.Unix()))
params.Set("sig", p.signRedirectURL(rawRedirect, now))
a.RawQuery = params.Encode()
return a
}

View file

@ -7,45 +7,17 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"reflect"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/proxy/clients" "github.com/pomerium/pomerium/proxy/clients"
) )
type mockCipher struct{}
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Marshal(s interface{}) (string, error) {
if s == "error" {
return "", errors.New("error")
}
return "ok", nil
}
func (a mockCipher) Unmarshal(s string, i interface{}) error {
if s == "unmarshal error" || s == "error" {
return errors.New("error")
}
return nil
}
func TestProxy_RobotsTxt(t *testing.T) { func TestProxy_RobotsTxt(t *testing.T) {
proxy := Proxy{} proxy := Proxy{}
req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil) req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil)
@ -60,94 +32,6 @@ func TestProxy_RobotsTxt(t *testing.T) {
} }
} }
func TestProxy_GetRedirectURL(t *testing.T) {
tests := []struct {
name string
host string
want *url.URL
}{
{"google", "google.com", &url.URL{Scheme: "https", Host: "google.com", Path: "/.pomerium/callback"}},
{"pomerium", "pomerium.io", &url.URL{Scheme: "https", Host: "pomerium.io", Path: "/.pomerium/callback"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Proxy{redirectURL: &url.URL{Path: "/.pomerium/callback"}}
if got := p.GetRedirectURL(tt.host); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Proxy.GetRedirectURL() = %v, want %v", got, tt.want)
}
})
}
}
func TestProxy_signRedirectURL(t *testing.T) {
tests := []struct {
name string
rawRedirect string
timestamp time.Time
want string
}{
{"pomerium", "https://pomerium.io/.pomerium/callback", fixedDate, "wq3rAjRGN96RXS8TAzH-uxQTD0XgY_8ZYEKMiOLD5P4="},
{"google", "https://google.com/.pomerium/callback", fixedDate, "7EYHZObq167CuyuPm5CqOtkU4zg5dFeUCs7W7QOrgNQ="},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Proxy{}
if got := p.signRedirectURL(tt.rawRedirect, tt.timestamp); got != tt.want {
t.Errorf("Proxy.signRedirectURL() = %v, want %v", got, tt.want)
}
})
}
}
func TestProxy_GetSignOutURL(t *testing.T) {
tests := []struct {
name string
authenticate string
redirect string
wantPrefix string
}{
{"good", "https://auth.corp.pomerium.io", "https://hello.corp.pomerium.io", "https://auth.corp.pomerium.io/sign_out?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
authenticateURL, _ := url.Parse(tt.authenticate)
redirectURL, _ := url.Parse(tt.redirect)
p := &Proxy{}
// signature is ignored as it is tested above. Avoids testing time.Now
if got := p.GetSignOutURL(authenticateURL, redirectURL); !strings.HasPrefix(got.String(), tt.wantPrefix) {
t.Errorf("Proxy.GetSignOutURL() = %v, wantPrefix %v", got.String(), tt.wantPrefix)
}
})
}
}
func TestProxy_GetSignInURL(t *testing.T) {
tests := []struct {
name string
authenticate string
redirect string
state string
wantPrefix string
}{
{"good", "https://auth.corp.pomerium.io", "https://hello.corp.pomerium.io", "example_state", "https://auth.corp.pomerium.io/sign_in?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io&response_type=code&shared_secret=shared-secret"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Proxy{SharedKey: "shared-secret"}
authenticateURL, _ := url.Parse(tt.authenticate)
redirectURL, _ := url.Parse(tt.redirect)
if got := p.GetSignInURL(authenticateURL, redirectURL, tt.state); !strings.HasPrefix(got.String(), tt.wantPrefix) {
t.Errorf("Proxy.GetSignOutURL() = %v, wantPrefix %v", got.String(), tt.wantPrefix)
}
})
}
}
func TestProxy_Signout(t *testing.T) { func TestProxy_Signout(t *testing.T) {
opts := testOptions(t) opts := testOptions(t)
err := ValidateOptions(opts) err := ValidateOptions(opts)
@ -171,7 +55,7 @@ func TestProxy_Signout(t *testing.T) {
} }
} }
func TestProxy_OAuthStart(t *testing.T) { func TestProxy_authenticate(t *testing.T) {
proxy, err := New(testOptions(t)) proxy, err := New(testOptions(t))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -179,18 +63,19 @@ func TestProxy_OAuthStart(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/oauth-start", nil) req := httptest.NewRequest(http.MethodGet, "/oauth-start", nil)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
proxy.OAuthStart(rr, req) proxy.authenticate(rr, req)
// expect oauth redirect // expect oauth redirect
if status := rr.Code; status != http.StatusFound { if status := rr.Code; status != http.StatusFound {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound) t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
} }
// expected url // expected url
expected := `<a href="https://authenticate.example/sign_in` expected := `<a href="https://authenticate.example/.pomerium/sign_in`
body := rr.Body.String() body := rr.Body.String()
if !strings.HasPrefix(body, expected) { if !strings.HasPrefix(body, expected) {
t.Errorf("handler returned unexpected body: got %v want %v", body, expected) t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
} }
} }
func TestProxy_Handler(t *testing.T) { func TestProxy_Handler(t *testing.T) {
proxy, err := New(testOptions(t)) proxy, err := New(testOptions(t))
if err != nil { if err != nil {
@ -226,7 +111,6 @@ func TestProxy_router(t *testing.T) {
{"good corp", "https://corp.example.com", policies, nil, true}, {"good corp", "https://corp.example.com", policies, nil, true},
{"good with slash", "https://corp.example.com/", policies, nil, true}, {"good with slash", "https://corp.example.com/", policies, nil, true},
{"good with path", "https://corp.example.com/123", policies, nil, true}, {"good with path", "https://corp.example.com/123", policies, nil, true},
// {"multiple", "https://corp.example.com/", map[string]string{"corp.unrelated.com": "unrelated.com", "corp.example.com": "example.com"}, nil, true},
{"no policies", "https://notcorp.example.com/123", []config.Policy{}, nil, false}, {"no policies", "https://notcorp.example.com/123", []config.Policy{}, nil, false},
{"bad corp", "https://notcorp.example.com/123", policies, nil, false}, {"bad corp", "https://notcorp.example.com/123", policies, nil, false},
{"bad sub-sub", "https://notcorp.corp.example.com/123", policies, nil, false}, {"bad sub-sub", "https://notcorp.corp.example.com/123", policies, nil, false},
@ -254,11 +138,15 @@ func TestProxy_Proxy(t *testing.T) {
goodSession := &sessions.State{ goodSession := &sessions.State{
AccessToken: "AccessToken", AccessToken: "AccessToken",
RefreshToken: "RefreshToken", RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second), RefreshDeadline: time.Now().Add(20 * time.Second),
} }
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprintln(w, "RVSI FILIVS CAISAR") fmt.Fprintln(w, "RVSI FILIVS CAISAR")
w.WriteHeader(http.StatusOK)
})) }))
defer ts.Close() defer ts.Close()
@ -285,25 +173,24 @@ func TestProxy_Proxy(t *testing.T) {
authorizer clients.Authorizer authorizer clients.Authorizer
wantStatus int wantStatus int
}{ }{
{"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, {"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(20 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK}, {"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
{"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, {"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, {"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
// same request as above, but with cors_allow_preflight=false in the policy // same request as above, but with cors_allow_preflight=false in the policy
{"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, {"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
// cors allowed, but the request is missing proper headers // cors allowed, but the request is missing proper headers
{"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, {"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
// redirect to start auth process // redirect to start auth process
{"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound}, {"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
{"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, {"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
{"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError}, {"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError},
// authenticate errors // authenticate errors
{"session error, redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: errors.New("weird"), Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound}, {"session expired,redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: sessions.ErrExpired}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
{"session expired,redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: sessions.ErrExpired}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound}, {"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
{"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
{"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound}, {"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound},
{"no http found (no session),redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound}, {"no http found (no session),redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
{"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound}, {"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
} }
for _, tt := range tests { for _, tt := range tests {
@ -319,18 +206,17 @@ func TestProxy_Proxy(t *testing.T) {
p.cipher = &cryptutil.MockCipher{MarshalResponse: "foo"} p.cipher = &cryptutil.MockCipher{MarshalResponse: "foo"}
p.sessionStore = tt.session p.sessionStore = tt.session
p.AuthorizeClient = tt.authorizer p.AuthorizeClient = tt.authorizer
r := httptest.NewRequest(tt.method, tt.host, nil) r := httptest.NewRequest(tt.method, tt.host, nil)
r.Header = tt.header r.Header = tt.header
r.Header.Set("Accept", "application/json") r.Header.Set("Accept", "application/json")
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, nil)
r = r.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
p.Proxy(w, r) p.Proxy(w, r)
if status := w.Code; status != tt.wantStatus { if status := w.Code; status != tt.wantStatus {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", w.Body.String())
t.Errorf("\n%+v", opts)
t.Errorf("\n%+v", ts.URL)
t.Errorf("handler returned wrong status code: got %v want %v \n body %s", status, tt.wantStatus, w.Body.String()) t.Errorf("handler returned wrong status code: got %v want %v \n body %s", status, tt.wantStatus, w.Body.String())
} }
@ -342,6 +228,7 @@ func TestProxy_UserDashboard(t *testing.T) {
opts := testOptions(t) opts := testOptions(t)
tests := []struct { tests := []struct {
name string name string
ctxError error
options config.Options options config.Options
method string method string
cipher cryptutil.Cipher cipher cryptutil.Cipher
@ -351,11 +238,10 @@ func TestProxy_UserDashboard(t *testing.T) {
wantAdminForm bool wantAdminForm bool
wantStatus int wantStatus int
}{ }{
{"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusOK}, {"good", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusOK},
{"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthorize{}, false, http.StatusFound}, {"session context error", errors.New("error"), opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
{"can't save csrf", opts, http.MethodGet, &cryptutil.MockCipher{MarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{}, false, http.StatusInternalServerError}, {"want admin form good admin authorization", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
{"want admin form good admin authorization", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK}, {"is admin but authorization fails", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
{"is admin but authorization fails", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
} }
for _, tt := range tests { for _, tt := range tests {
@ -369,6 +255,11 @@ func TestProxy_UserDashboard(t *testing.T) {
p.AuthorizeClient = tt.authorizer p.AuthorizeClient = tt.authorizer
r := httptest.NewRequest(tt.method, "/", nil) r := httptest.NewRequest(tt.method, "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json") r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -395,6 +286,7 @@ func TestProxy_ForceRefresh(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
ctxError error
options config.Options options config.Options
method string method string
cipher cryptutil.Cipher cipher cryptutil.Cipher
@ -402,12 +294,12 @@ func TestProxy_ForceRefresh(t *testing.T) {
authorizer clients.Authorizer authorizer clients.Authorizer
wantStatus int wantStatus int
}{ }{
{"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound}, {"good", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
{"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthorize{}, http.StatusFound}, {"cannot load session", errors.New("error"), opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
{"bad id token", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthorize{}, http.StatusInternalServerError}, {"bad id token", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
{"issue date too soon", timeSinceError, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusBadRequest}, {"issue date too soon", nil, timeSinceError, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusBadRequest},
{"refresh failure", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound}, {"refresh failure", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
{"can't save refreshed session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound}, {"can't save refreshed session", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -420,6 +312,12 @@ func TestProxy_ForceRefresh(t *testing.T) {
p.AuthorizeClient = tt.authorizer p.AuthorizeClient = tt.authorizer
r := httptest.NewRequest(tt.method, "/", nil) r := httptest.NewRequest(tt.method, "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
p.ForceRefresh(w, r) p.ForceRefresh(w, r)
if status := w.Code; status != tt.wantStatus { if status := w.Code; status != tt.wantStatus {
@ -431,32 +329,29 @@ func TestProxy_ForceRefresh(t *testing.T) {
} }
func TestProxy_Impersonate(t *testing.T) { func TestProxy_Impersonate(t *testing.T) {
t.Parallel()
opts := testOptions(t) opts := testOptions(t)
tests := []struct { tests := []struct {
name string name string
malformed bool malformed bool
options config.Options options config.Options
ctxError error
method string method string
email string email string
groups string groups string
csrf string csrf string
cipher cryptutil.Cipher cipher cryptutil.Cipher
sessionStore sessions.SessionStore sessionStore sessions.SessionStore
csrfStore sessions.CSRFStore
authorizer clients.Authorizer authorizer clients.Authorizer
wantStatus int wantStatus int
}{ }{
{"good", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, {"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"session load error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, {"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
// {"non admin users rejected", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden}, {"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"non admin users rejected on error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden}, {"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
{"csrf from store retrieve failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}, GetError: errors.New("err")}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, {"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
{"can't decrypt csrf value", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{UnmarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, {"save session failure", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"decrypted csrf mismatch", false, opts, http.MethodPost, "user@blah.com", "", "CSRF!", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusBadRequest}, {"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"save session failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"malformed", true, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"groups", false, opts, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -466,17 +361,18 @@ func TestProxy_Impersonate(t *testing.T) {
} }
p.cipher = tt.cipher p.cipher = tt.cipher
p.sessionStore = tt.sessionStore p.sessionStore = tt.sessionStore
p.csrfStore = tt.csrfStore
p.AuthorizeClient = tt.authorizer p.AuthorizeClient = tt.authorizer
postForm := url.Values{} postForm := url.Values{}
postForm.Add("email", tt.email) postForm.Add("email", tt.email)
postForm.Add("group", tt.groups) postForm.Add("group", tt.groups)
postForm.Set("csrf", tt.csrf) postForm.Set("csrf", tt.csrf)
uri := &url.URL{Path: "/"} uri := &url.URL{Path: "/"}
if tt.malformed {
uri.RawQuery = "email=%zzzzz"
}
r := httptest.NewRequest(tt.method, uri.String(), bytes.NewBufferString(postForm.Encode())) r := httptest.NewRequest(tt.method, uri.String(), bytes.NewBufferString(postForm.Encode()))
state, _ := tt.sessionStore.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -489,50 +385,8 @@ func TestProxy_Impersonate(t *testing.T) {
} }
} }
func TestProxy_OAuthCallback(t *testing.T) {
tests := []struct {
name string
csrf sessions.MockCSRFStore
session sessions.MockSessionStore
params map[string]string
wantCode int
}{
{"good", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusFound},
{"state err", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "error"}, http.StatusInternalServerError},
{"csrf err", sessions.MockCSRFStore{GetError: errors.New("error")}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
{"unmarshal err", sessions.MockCSRFStore{Cookie: &http.Cookie{Name: "something_csrf", Value: "unmarshal error"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
{"malformed", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proxy, err := New(testOptions(t))
if err != nil {
t.Fatal(err)
}
proxy.sessionStore = &tt.session
proxy.csrfStore = tt.csrf
proxy.cipher = mockCipher{}
// proxy.Csrf
req := httptest.NewRequest(http.MethodPost, "/.pomerium/callback", nil)
q := req.URL.Query()
for k, v := range tt.params {
q.Add(k, v)
}
req.URL.RawQuery = q.Encode()
if tt.name == "malformed" {
req.URL.RawQuery = "email=%zzzzz"
}
w := httptest.NewRecorder()
proxy.AuthenticateCallback(w, req)
if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
}
})
}
}
func TestProxy_SignOut(t *testing.T) { func TestProxy_SignOut(t *testing.T) {
t.Parallel()
tests := []struct { tests := []struct {
name string name string
verb string verb string
@ -542,7 +396,6 @@ func TestProxy_SignOut(t *testing.T) {
{"good post", http.MethodPost, "https://test.example", http.StatusFound}, {"good post", http.MethodPost, "https://test.example", http.StatusFound},
{"good get", http.MethodGet, "https://test.example", http.StatusFound}, {"good get", http.MethodGet, "https://test.example", http.StatusFound},
{"good empty default", http.MethodGet, "", http.StatusFound}, {"good empty default", http.MethodGet, "", http.StatusFound},
{"malformed", http.MethodPost, "", http.StatusInternalServerError},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -554,9 +407,6 @@ func TestProxy_SignOut(t *testing.T) {
postForm := url.Values{} postForm := url.Values{}
postForm.Add("redirect_uri", tt.redirectURL) postForm.Add("redirect_uri", tt.redirectURL)
uri := &url.URL{Path: "/"} uri := &url.URL{Path: "/"}
if tt.name == "malformed" {
uri.RawQuery = "redirect_uri=%zzzzz"
}
query, _ := url.ParseQuery(uri.RawQuery) query, _ := url.ParseQuery(uri.RawQuery)
if tt.verb == http.MethodGet { if tt.verb == http.MethodGet {
@ -576,3 +426,56 @@ func TestProxy_SignOut(t *testing.T) {
}) })
} }
} }
func uriParseHelper(s string) *url.URL {
uri, _ := url.Parse(s)
return uri
}
func TestProxy_VerifySession(t *testing.T) {
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprintln(w, "RVSI FILIVS CAISAR")
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
session sessions.SessionStore
ctxError error
provider identity.Authenticator
wantStatus int
}{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := Proxy{
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
sessionStore: tt.session,
}
r := httptest.NewRequest("GET", "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
got := a.VerifySession(fn)
got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
})
}
}

View file

@ -2,6 +2,7 @@ package proxy // import "github.com/pomerium/pomerium/proxy"
import ( import (
"crypto/tls" "crypto/tls"
"encoding/base64"
"fmt" "fmt"
"html/template" "html/template"
stdlog "log" stdlog "log"
@ -12,11 +13,11 @@ import (
"github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
pom_httputil "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/templates" "github.com/pomerium/pomerium/internal/templates"
"github.com/pomerium/pomerium/internal/tripper" "github.com/pomerium/pomerium/internal/tripper"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
@ -32,6 +33,11 @@ const (
HeaderEmail = "x-pomerium-authenticated-user-email" HeaderEmail = "x-pomerium-authenticated-user-email"
// HeaderGroups is the header key containing the user's groups. // HeaderGroups is the header key containing the user's groups.
HeaderGroups = "x-pomerium-authenticated-user-groups" HeaderGroups = "x-pomerium-authenticated-user-groups"
// signinURL is the path to authenticate's sign in endpoint
signinURL = "/.pomerium/sign_in"
// signoutURL is the path to authenticate's sign out endpoint
signoutURL = "/.pomerium/sign_out"
) )
// ValidateOptions checks that proper configuration settings are set to create // ValidateOptions checks that proper configuration settings are set to create
@ -40,23 +46,21 @@ func ValidateOptions(o config.Options) error {
if _, err := cryptutil.NewCipherFromBase64(o.SharedKey); err != nil { if _, err := cryptutil.NewCipherFromBase64(o.SharedKey); err != nil {
return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %v", err) return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %v", err)
} }
if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil { if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil {
return fmt.Errorf("proxy: invalid 'COOKIE_SECRET': %v", err) return fmt.Errorf("proxy: invalid 'COOKIE_SECRET': %v", err)
} }
if o.AuthenticateURL == nil {
return fmt.Errorf("proxy: missing 'AUTHENTICATE_SERVICE_URL'") if err := urlutil.ValidateURL(o.AuthenticateURL); err != nil {
}
if _, err := urlutil.ParseAndValidateURL(o.AuthenticateURL.String()); err != nil {
return fmt.Errorf("proxy: invalid 'AUTHENTICATE_SERVICE_URL': %v", err) return fmt.Errorf("proxy: invalid 'AUTHENTICATE_SERVICE_URL': %v", err)
} }
if o.AuthorizeURL == nil {
return fmt.Errorf("proxy: missing 'AUTHORIZE_SERVICE_URL'") if err := urlutil.ValidateURL(o.AuthorizeURL); err != nil {
}
if _, err := urlutil.ParseAndValidateURL(o.AuthorizeURL.String()); err != nil {
return fmt.Errorf("proxy: invalid 'AUTHORIZE_SERVICE_URL': %v", err) return fmt.Errorf("proxy: invalid 'AUTHORIZE_SERVICE_URL': %v", err)
} }
if len(o.SigningKey) != 0 { if len(o.SigningKey) != 0 {
if _, err := cryptutil.NewES256Signer(o.SigningKey, "localhost"); err != nil { if _, err := cryptutil.NewES256Signer(o.SigningKey, ""); err != nil {
return fmt.Errorf("proxy: invalid 'SIGNING_KEY': %v", err) return fmt.Errorf("proxy: invalid 'SIGNING_KEY': %v", err)
} }
} }
@ -66,17 +70,19 @@ func ValidateOptions(o config.Options) error {
// Proxy stores all the information associated with proxying a request. // Proxy stores all the information associated with proxying a request.
type Proxy struct { type Proxy struct {
// SharedKey used to mutually authenticate service communication // SharedKey used to mutually authenticate service communication
SharedKey string SharedKey string
authenticateURL *url.URL authenticateURL *url.URL
authorizeURL *url.URL authenticateSigninURL *url.URL
authenticateSignoutURL *url.URL
authorizeURL *url.URL
AuthorizeClient clients.Authorizer AuthorizeClient clients.Authorizer
cipher cryptutil.Cipher cipher cryptutil.Cipher
cookieName string cookieName string
csrfStore sessions.CSRFStore cookieDomain string
cookieSecret []byte
defaultUpstreamTimeout time.Duration defaultUpstreamTimeout time.Duration
redirectURL *url.URL
refreshCooldown time.Duration refreshCooldown time.Duration
routeConfigs map[string]*routeConfig routeConfigs map[string]*routeConfig
sessionStore sessions.SessionStore sessionStore sessions.SessionStore
@ -95,10 +101,17 @@ func New(opts config.Options) (*Proxy, error) {
if err := ValidateOptions(opts); err != nil { if err := ValidateOptions(opts); err != nil {
return nil, err return nil, err
} }
decodedCookieSecret, err := base64.StdEncoding.DecodeString(opts.CookieSecret)
if err != nil {
return nil, err
}
cipher, err := cryptutil.NewCipherFromBase64(opts.CookieSecret) cipher, err := cryptutil.NewCipherFromBase64(opts.CookieSecret)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if opts.CookieDomain == "" {
opts.CookieDomain = sessions.ParentSubdomain(opts.AuthenticateURL.String())
}
cookieStore, err := sessions.NewCookieStore( cookieStore, err := sessions.NewCookieStore(
&sessions.CookieStoreOptions{ &sessions.CookieStoreOptions{
@ -116,13 +129,12 @@ func New(opts config.Options) (*Proxy, error) {
p := &Proxy{ p := &Proxy{
SharedKey: opts.SharedKey, SharedKey: opts.SharedKey,
routeConfigs: make(map[string]*routeConfig), routeConfigs: make(map[string]*routeConfig),
cipher: cipher, cipher: cipher,
cookieSecret: decodedCookieSecret,
cookieDomain: opts.CookieDomain,
cookieName: opts.CookieName, cookieName: opts.CookieName,
csrfStore: cookieStore,
defaultUpstreamTimeout: opts.DefaultUpstreamTimeout, defaultUpstreamTimeout: opts.DefaultUpstreamTimeout,
redirectURL: &url.URL{Path: "/.pomerium/callback"},
refreshCooldown: opts.RefreshCooldown, refreshCooldown: opts.RefreshCooldown,
sessionStore: cookieStore, sessionStore: cookieStore,
signingKey: opts.SigningKey, signingKey: opts.SigningKey,
@ -130,6 +142,9 @@ func New(opts config.Options) (*Proxy, error) {
} }
// DeepCopy urls to avoid accidental mutation, err checked in validate func // DeepCopy urls to avoid accidental mutation, err checked in validate func
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL) p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL})
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL) p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
if err := p.UpdatePolicies(&opts); err != nil { if err := p.UpdatePolicies(&opts); err != nil {
@ -172,24 +187,24 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error {
if policy.TLSSkipVerify { if policy.TLSSkipVerify {
tlsClientConfig.InsecureSkipVerify = true tlsClientConfig.InsecureSkipVerify = true
isCustomClientConfig = true isCustomClientConfig = true
log.Warn().Str("to", policy.Source.String()).Msg("proxy: tls skip verify") log.Warn().Str("policy", policy.String()).Msg("proxy: tls skip verify")
} }
if policy.RootCAs != nil { if policy.RootCAs != nil {
tlsClientConfig.RootCAs = policy.RootCAs tlsClientConfig.RootCAs = policy.RootCAs
isCustomClientConfig = true isCustomClientConfig = true
log.Debug().Str("to", policy.Source.String()).Msg("proxy: custom root ca") log.Debug().Str("policy", policy.String()).Msg("proxy: custom root ca")
} }
if policy.ClientCertificate != nil { if policy.ClientCertificate != nil {
tlsClientConfig.Certificates = []tls.Certificate{*policy.ClientCertificate} tlsClientConfig.Certificates = []tls.Certificate{*policy.ClientCertificate}
isCustomClientConfig = true isCustomClientConfig = true
log.Debug().Str("to", policy.Source.String()).Msg("proxy: client certs enabled") log.Debug().Str("policy", policy.String()).Msg("proxy: client certs enabled")
} }
if policy.TLSServerName != "" { if policy.TLSServerName != "" {
tlsClientConfig.ServerName = policy.TLSServerName tlsClientConfig.ServerName = policy.TLSServerName
isCustomClientConfig = true isCustomClientConfig = true
log.Debug().Str("to", policy.Source.String()).Msgf("proxy: tls hostname override to: %s", policy.TLSServerName) log.Debug().Str("policy", policy.String()).Msgf("proxy: tls hostname override to: %s", policy.TLSServerName)
} }
// We avoid setting a custom client config unless we have to as // We avoid setting a custom client config unless we have to as
@ -212,19 +227,6 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error {
return nil return nil
} }
// UpstreamProxy stores information for proxying the request to the upstream.
type UpstreamProxy struct {
name string
handler http.Handler
}
// ServeHTTP handles the second (reverse-proxying) leg of pomerium's request flow
func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), fmt.Sprintf("%s%s", r.Host, r.URL.Path))
defer span.End()
u.handler.ServeHTTP(w, r.WithContext(ctx))
}
// NewReverseProxy returns a new ReverseProxy that routes URLs to the scheme, host, and // NewReverseProxy returns a new ReverseProxy that routes URLs to the scheme, host, and
// base path provided in target. NewReverseProxy rewrites the Host header. // base path provided in target. NewReverseProxy rewrites the Host header.
func NewReverseProxy(to *url.URL) *httputil.ReverseProxy { func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
@ -242,22 +244,17 @@ func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
return proxy return proxy
} }
// newReverseProxyHandler applies handler specific options to a given route. // each route has a custom set of middleware applied to the reverse proxy
func (p *Proxy) newReverseProxyHandler(rp *httputil.ReverseProxy, route *config.Policy) (handler http.Handler, err error) { func (p *Proxy) newReverseProxyHandler(rp http.Handler, route *config.Policy) (http.Handler, error) {
handler = &UpstreamProxy{ r := pom_httputil.NewRouter()
name: route.Destination.Host, r.Use(middleware.StripPomeriumCookie(p.cookieName))
handler: rp,
}
c := middleware.NewChain()
c = c.Append(middleware.StripPomeriumCookie(p.cookieName))
// if signing key is set, add signer to middleware // if signing key is set, add signer to middleware
if len(p.signingKey) != 0 { if len(p.signingKey) != 0 {
signer, err := cryptutil.NewES256Signer(p.signingKey, route.Source.Host) signer, err := cryptutil.NewES256Signer(p.signingKey, route.Source.Host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c = c.Append(middleware.SignRequest(signer, HeaderUserID, HeaderEmail, HeaderGroups, HeaderJWT)) r.Use(middleware.SignRequest(signer, HeaderUserID, HeaderEmail, HeaderGroups, HeaderJWT))
} }
// websockets cannot use the non-hijackable timeout-handler // websockets cannot use the non-hijackable timeout-handler
if !route.AllowWebsockets { if !route.AllowWebsockets {
@ -265,11 +262,16 @@ func (p *Proxy) newReverseProxyHandler(rp *httputil.ReverseProxy, route *config.
if route.UpstreamTimeout != 0 { if route.UpstreamTimeout != 0 {
timeout = route.UpstreamTimeout timeout = route.UpstreamTimeout
} }
timeoutMsg := fmt.Sprintf("%s failed to respond within the %s timeout period", route.Destination.Host, timeout) timeoutMsg := fmt.Sprintf("%s timed out in %s", route.Destination.Host, timeout)
handler = http.TimeoutHandler(handler, timeout, timeoutMsg) rp = http.TimeoutHandler(rp, timeout, timeoutMsg)
} }
// todo(bdd) : fix cors
return c.Then(handler), nil // if route.CORSAllowPreflight {
// r.Use(cors.Default().Handler)
// }
r.Host(route.Destination.Host)
r.PathPrefix("/").Handler(rp)
return r, nil
} }
// UpdateOptions updates internal structures based on config.Options // UpdateOptions updates internal structures based on config.Options

View file

@ -12,8 +12,6 @@ import (
"github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/config"
) )
var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
func newTestOptions(t *testing.T) *config.Options { func newTestOptions(t *testing.T) *config.Options {
opts, err := config.NewOptions("https://authenticate.example", "https://authorize.example") opts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
if err != nil { if err != nil {
@ -285,10 +283,11 @@ func Test_UpdateOptions(t *testing.T) {
goodClientCertPolicies := testOptions(t) goodClientCertPolicies := testOptions(t)
goodClientCertPolicies.Policies = []config.Policy{ goodClientCertPolicies.Policies = []config.Policy{
{To: "http://foo.example", From: "http://bar.example", {To: "http://foo.example", From: "http://bar.example",
TLSClientKey: "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=", TLSClientCert: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=", TLSClientKey: "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=", TLSClientCert: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="},
},
} }
goodClientCertPolicies.Validate() goodClientCertPolicies.Validate()
customServerName := testOptions(t)
customServerName.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", TLSServerName: "test"}}
tests := []struct { tests := []struct {
name string name string
originalOptions config.Options originalOptions config.Options
@ -301,14 +300,13 @@ func Test_UpdateOptions(t *testing.T) {
{"good no change", good, good, "", "https://corp.example.example", false, true}, {"good no change", good, good, "", "https://corp.example.example", false, true},
{"changed", good, newPolicies, "", "https://bar.example", false, true}, {"changed", good, newPolicies, "", "https://bar.example", false, true},
{"changed and missing", good, newPolicies, "", "https://corp.example.example", false, false}, {"changed and missing", good, newPolicies, "", "https://corp.example.example", false, false},
// todo(bdd): not sure what intent of this test is?
{"bad signing key", good, newPolicies, "^bad base 64", "https://corp.example.example", true, false}, {"bad signing key", good, newPolicies, "^bad base 64", "https://corp.example.example", true, false},
{"bad change bad policy url", good, badNewPolicy, "", "https://bar.example", true, false}, {"bad change bad policy url", good, badNewPolicy, "", "https://bar.example", true, false},
// todo: stand up a test server using self signed certificates
{"disable tls verification", good, disableTLSPolicies, "", "https://bar.example", false, true}, {"disable tls verification", good, disableTLSPolicies, "", "https://bar.example", false, true},
{"custom root ca", good, customCAPolicies, "", "https://bar.example", false, true}, {"custom root ca", good, customCAPolicies, "", "https://bar.example", false, true},
{"bad custom root ca base64", good, badCustomCAPolicies, "", "https://bar.example", true, false}, {"bad custom root ca base64", good, badCustomCAPolicies, "", "https://bar.example", true, false},
{"good client certs", good, goodClientCertPolicies, "", "https://bar.example", false, true}, {"good client certs", good, goodClientCertPolicies, "", "https://bar.example", false, true},
{"custom server name", customServerName, customServerName, "", "https://bar.example", false, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {