all: refactor handler logic

- all: prefer `FormValues` to `ParseForm` with subsequent `Form.Get`s
- all: refactor authentication stack to be checked by middleware, and accessible via request context.
- all: replace http.ServeMux with gorilla/mux’s router
- all: replace custom CSRF checks with gorilla/csrf middleware
- authenticate: extract callback path as constant.
- internal/config: implement stringer interface for policy
- internal/cryptutil: add helper func `NewBase64Key`
- internal/cryptutil: rename `GenerateKey` to `NewKey`
- internal/cryptutil: rename `GenerateRandomString` to `NewRandomStringN`
- internal/middleware: removed alice in favor of gorilla/mux
- internal/sessions: remove unused `ValidateRedirectURI` and `ValidateClientSecret`
- internal/sessions: replace custom CSRF with gorilla/csrf fork that supports custom handler protection
- internal/urlutil: add `SignedRedirectURL` to create hmac'd URLs
- internal/urlutil: add `ValidateURL` helper to parse URL options
- internal/urlutil: add `GetAbsoluteURL` which takes a request and returns its absolute URL.
- proxy: remove holdover state verification checks; we no longer are setting sessions in any proxy routes so we don’t need them.
- proxy: replace un-named http.ServeMux with named domain routes.

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2019-09-12 13:54:30 -07:00
parent a793249386
commit dc12947241
No known key found for this signature in database
GPG key ID: AEE4CF12FE86D07E
37 changed files with 1132 additions and 1384 deletions

3
.gitignore vendored
View file

@ -76,4 +76,5 @@ yarn.lock
node_modules
i18n/*
docs/.vuepress/dist/
.firebase/
.firebase/
.changes.md

View file

@ -87,31 +87,6 @@ https://github.com/bitly/oauth2_proxy/blob/master/LICENSE
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
alice
SPDX-License-Identifier: MIT
https://github.com/justinas/alice/blob/master/LICENSE
The MIT License (MIT)
Copyright (c) 2014 Justinas Stankevicius
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
goji
SPDX-License-Identifier: MIT
https://github.com/zenazn/goji/blob/master/LICENSE

View file

@ -15,6 +15,8 @@ import (
"github.com/pomerium/pomerium/internal/urlutil"
)
const callbackPath = "/oauth2/callback"
// ValidateOptions checks that configuration are complete and valid.
// Returns on first error found.
func ValidateOptions(o config.Options) error {
@ -24,11 +26,8 @@ func ValidateOptions(o config.Options) error {
if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil {
return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %v", err)
}
if o.AuthenticateURL == nil {
return errors.New("authenticate: 'AUTHENTICATE_SERVICE_URL' is required")
}
if _, err := urlutil.ParseAndValidateURL(o.AuthenticateURL.String()); err != nil {
return fmt.Errorf("authenticate: couldn't parse 'AUTHENTICATE_SERVICE_URL': %v", err)
if err := urlutil.ValidateURL(o.AuthenticateURL); err != nil {
return fmt.Errorf("authenticate: invalid 'AUTHENTICATE_SERVICE_URL': %v", err)
}
if o.ClientID == "" {
return errors.New("authenticate: 'IDP_CLIENT_ID' is required")
@ -44,8 +43,10 @@ type Authenticate struct {
SharedKey string
RedirectURL *url.URL
cookieName string
cookieDomain string
cookieSecret []byte
templates *template.Template
csrfStore sessions.CSRFStore
sessionStore sessions.SessionStore
cipher cryptutil.Cipher
provider identity.Authenticator
@ -61,6 +62,9 @@ func New(opts config.Options) (*Authenticate, error) {
if err != nil {
return nil, err
}
if opts.CookieDomain == "" {
opts.CookieDomain = sessions.ParentSubdomain(opts.AuthenticateURL.String())
}
cookieStore, err := sessions.NewCookieStore(
&sessions.CookieStoreOptions{
Name: opts.CookieName,
@ -74,7 +78,7 @@ func New(opts config.Options) (*Authenticate, error) {
return nil, err
}
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
redirectURL.Path = "/oauth2/callback"
redirectURL.Path = callbackPath
provider, err := identity.New(
opts.Provider,
&identity.Provider{
@ -94,9 +98,11 @@ func New(opts config.Options) (*Authenticate, error) {
SharedKey: opts.SharedKey,
RedirectURL: redirectURL,
templates: templates.New(),
csrfStore: cookieStore,
sessionStore: cookieStore,
cipher: cipher,
provider: provider,
cookieSecret: decodedCookieSecret,
cookieName: opts.CookieName,
cookieDomain: opts.CookieDomain,
}, nil
}

View file

@ -10,7 +10,8 @@ import (
"strings"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/csrf"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
@ -31,24 +32,68 @@ var CSPHeaders = map[string]string{
// Handler returns the authenticate service's HTTP multiplexer, and routes.
func (a *Authenticate) Handler() http.Handler {
// validation middleware chain
c := middleware.NewChain()
c = c.Append(middleware.SetHeaders(CSPHeaders))
mux := http.NewServeMux()
mux.Handle("/robots.txt", c.ThenFunc(a.RobotsTxt))
r := httputil.NewRouter()
r.Use(middleware.SetHeaders(CSPHeaders))
r.Use(csrf.Protect(
a.cookieSecret,
csrf.Path("/"),
csrf.Domain(a.cookieDomain),
csrf.UnsafePaths([]string{"/oauth2/callback"}), // enforce CSRF on "safe" handler
csrf.FormValueName("state"), // rfc6749 section-10.12
csrf.CookieName(fmt.Sprintf("%s_csrf", a.cookieName)),
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
))
r.HandleFunc("/robots.txt", a.RobotsTxt).Methods(http.MethodGet)
// Identity Provider (IdP) endpoints
mux.Handle("/oauth2", c.ThenFunc(a.OAuthStart))
mux.Handle("/oauth2/callback", c.ThenFunc(a.OAuthCallback))
r.HandleFunc("/oauth2/callback", a.OAuthCallback).Methods(http.MethodGet)
r.HandleFunc("/api/v1/token", a.ExchangeToken)
// Proxy service endpoints
validationMiddlewares := c.Append(
middleware.ValidateSignature(a.SharedKey),
middleware.ValidateRedirectURI(a.RedirectURL),
)
mux.Handle("/sign_in", validationMiddlewares.ThenFunc(a.SignIn))
mux.Handle("/sign_out", validationMiddlewares.ThenFunc(a.SignOut)) // POST
// Direct user access endpoints
mux.Handle("/api/v1/token", c.ThenFunc(a.ExchangeToken))
return mux
v := r.PathPrefix("/.pomerium").Subrouter()
v.Use(middleware.ValidateSignature(a.SharedKey))
v.Use(middleware.ValidateRedirectURI(a.RedirectURL))
v.Use(sessions.RetrieveSession(a.sessionStore))
v.Use(a.VerifySession)
v.HandleFunc("/sign_in", a.SignIn)
v.HandleFunc("/sign_out", a.SignOut)
return r
}
// VerifySession is the middleware used to enforce a valid authentication
// session state is attached to the users's request context.
func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
state, err := sessions.FromContext(r.Context())
if errors.Is(err, sessions.ErrExpired) {
if err := a.refresh(w, r, state); err != nil {
log.FromRequest(r).Debug().Str("cause", err.Error()).Msg("authenticate: couldn't refresh session")
a.sessionStore.ClearSession(w, r)
a.redirectToIdentityProvider(w, r)
return
}
} else if err != nil {
log.FromRequest(r).Err(err).Msg("authenticate: unexpected session state")
a.sessionStore.ClearSession(w, r)
a.redirectToIdentityProvider(w, r)
return
}
next.ServeHTTP(w, r)
})
}
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) error {
newSession, err := a.provider.Refresh(r.Context(), s)
if err != nil {
return fmt.Errorf("authenticate: refresh failed: %w", err)
}
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
return fmt.Errorf("authenticate: refresh save failed: %w", err)
}
return nil
}
// RobotsTxt handles the /robots.txt route.
@ -59,87 +104,22 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
}
func (a *Authenticate) loadExisting(w http.ResponseWriter, r *http.Request) (*sessions.State, error) {
session, err := a.sessionStore.LoadSession(r)
if err != nil {
return nil, err
}
err = session.Valid()
if err == nil {
return session, nil
} else if !errors.Is(err, sessions.ErrExpired) {
return nil, fmt.Errorf("authenticate: non-refreshable error: %w", err)
} else {
return a.refresh(w, r, session)
}
}
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (*sessions.State, error) {
newSession, err := a.provider.Refresh(r.Context(), s)
if err != nil {
return nil, fmt.Errorf("authenticate: refresh failed: %w", err)
}
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
return nil, fmt.Errorf("authenticate: refresh save failed: %w", err)
}
return newSession, nil
}
// SignIn handles to authenticating a user.
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
session, err := a.loadExisting(w, r)
if err != nil {
log.FromRequest(r).Debug().Err(err).Msg("authenticate: need new session")
a.sessionStore.ClearSession(w, r)
a.OAuthStart(w, r)
return
}
if err := r.ParseForm(); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
state := r.Form.Get("state")
if state == "" {
httputil.ErrorResponse(w, r, httputil.Error("sign in state empty", http.StatusBadRequest, nil))
return
}
redirectURL, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri"))
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
return
}
// encrypt session state as json blob
encrypted, err := sessions.MarshalSession(session, a.cipher)
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("couldn't marshal session", http.StatusInternalServerError, err))
return
}
http.Redirect(w, r, getAuthCodeRedirectURL(redirectURL, state, encrypted), http.StatusFound)
}
func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string {
// ParseQuery err handled by go's mux stack
params, _ := url.ParseQuery(redirectURL.RawQuery)
params.Set("code", authCode)
params.Set("state", state)
redirectURL.RawQuery = params.Encode()
return redirectURL.String()
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
}
// SignOut signs the user out and attempts to revoke the user's identity session
// Handles both GET and POST.
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
redirectURI := r.Form.Get("redirect_uri")
session, err := a.sessionStore.LoadSession(r)
session, err := sessions.FromContext(r.Context())
if err != nil {
log.Error().Err(err).Msg("authenticate: no session to signout, redirect and clear")
http.Redirect(w, r, redirectURI, http.StatusFound)
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
return
}
a.sessionStore.ClearSession(w, r)
@ -148,46 +128,30 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, httputil.Error("could not revoke user session", http.StatusBadRequest, err))
return
}
http.Redirect(w, r, redirectURI, http.StatusFound)
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
return
}
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
}
// OAuthStart starts the authenticate process by redirecting to the identity provider.
// redirectToIdentityProvider starts the authenticate process by redirecting the
// user to their respective identity provider.
//
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest
// https://tools.ietf.org/html/rfc6749#section-4.2.1
func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
authRedirectURL := a.RedirectURL.ResolveReference(r.URL)
// Nonce is the opaque, cryptographically binding value used to maintain
// state between the request and the callback.
// OIDC : 3.1.2.1. Authentication Request
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
a.csrfStore.SetCSRF(w, r, nonce)
// Redirection URI to which the response will be sent. This URI MUST exactly
// match one of the Redirection URI values for the Client pre-registered at
// at your identity provider
proxyRedirectURL, err := urlutil.ParseAndValidateURL(authRedirectURL.Query().Get("redirect_uri"))
if err != nil || !middleware.SameDomain(proxyRedirectURL, a.RedirectURL) {
httputil.ErrorResponse(w, r, httputil.Error("proxy url not from the root domain", http.StatusBadRequest, err))
return
}
// get the signature and timestamp values then compare hmac
proxyRedirectSig := authRedirectURL.Query().Get("sig")
ts := authRedirectURL.Query().Get("ts")
if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, a.SharedKey) {
httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, nil))
return
}
// State is the opaque value used to maintain state between the request and
// the callback; contains both the nonce and redirect URI
state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String())))
// build the provider sign in url
signInURL := a.provider.GetSignInURL(state)
http.Redirect(w, r, signInURL, http.StatusFound)
func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http.Request) {
redirectURL := a.RedirectURL.ResolveReference(r.URL)
nonce := csrf.Token(r)
state := fmt.Sprintf("%v:%v", nonce, redirectURL.String())
encodedState := base64.URLEncoding.EncodeToString([]byte(state))
http.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound)
}
// OAuthCallback handles the callback from the identity provider.
//
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowSteps
// https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
redirect, err := a.getOAuthCallback(w, r)
@ -195,57 +159,49 @@ func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, fmt.Errorf("oauth callback : %w", err))
return
}
// redirect back to the proxy-service via sign_in
http.Redirect(w, r, redirect.String(), http.StatusFound)
}
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
if err := r.ParseForm(); err != nil {
return nil, httputil.Error("invalid signature", http.StatusBadRequest, err)
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
//
// first, check if the identity provider returned an error
if idpError := r.FormValue("error"); idpError != "" {
return nil, httputil.Error(idpError, http.StatusBadRequest, fmt.Errorf("identity provider: %v", idpError))
}
// OIDC : 3.1.2.6. Authentication Error Response
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthError
if idpError := r.Form.Get("error"); idpError != "" {
return nil, httputil.Error("provider returned an error", http.StatusBadRequest, fmt.Errorf("provider error: %v", idpError))
}
code := r.Form.Get("code")
// fail if no session redemption code is returned
code := r.FormValue("code")
if code == "" {
return nil, httputil.Error("provider didn't reply with code", http.StatusBadRequest, nil)
return nil, httputil.Error("identity provider returned empty code", http.StatusBadRequest, nil)
}
// validate the returned code with the identity provider
// Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5
//
// Exchange the supplied Authorization Code for a valid user session.
session, err := a.provider.Authenticate(r.Context(), code)
if err != nil {
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
}
// OIDC : 3.1.2.5. Successful Authentication Response
// Opaque value used to maintain state between the request and the callback.
bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state"))
// state includes a csrf nonce (validated by middleware) and redirect uri
bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state"))
if err != nil {
return nil, fmt.Errorf("failed decoding state: %w", err)
return nil, httputil.Error("malformed state", http.StatusBadRequest, err)
}
s := strings.SplitN(string(bytes), ":", 2)
if len(s) != 2 {
return nil, fmt.Errorf("invalid state size: %d", len(s))
// split state into its it's components (nonce:redirect_uri)
statePayload := strings.SplitN(string(bytes), ":", 2)
if len(statePayload) != 2 {
return nil, fmt.Errorf("state malformed, size: %d", len(statePayload))
}
// state contains the csrf nonce and redirect uri
nonce := s[0]
redirect := s[1]
c, err := a.csrfStore.GetCSRF(r)
defer a.csrfStore.ClearCSRF(w, r)
if err != nil || c.Value != nonce {
return nil, fmt.Errorf("csrf failure: %w", err)
}
redirectURL, err := urlutil.ParseAndValidateURL(redirect)
// parse redirect_uri; ignore csrf nonce (validity asserted by middleware)
redirectURL, err := urlutil.ParseAndValidateURL(statePayload[1])
if err != nil {
return nil, httputil.Error(fmt.Sprintf("invalid redirect uri %s", redirect), http.StatusBadRequest, err)
}
// sanity check, we are redirecting back to the same subdomain right?
if !middleware.SameDomain(redirectURL, a.RedirectURL) {
return nil, httputil.Error(fmt.Sprintf("invalid redirect domain %v, %v", redirectURL, a.RedirectURL), http.StatusBadRequest, nil)
return nil, httputil.Error("invalid redirect uri", http.StatusBadRequest, err)
}
// todo(bdd): if we want to be _extra_ sure, we can validate that the
// redirectURL hmac is valid. But the nonce should cover the integrity...
// OK. Looks good so let's persist our user session
if err := a.sessionStore.SaveSession(w, r, session); err != nil {
return nil, fmt.Errorf("failed saving new session: %w", err)
}
@ -256,11 +212,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
// and exchanges that token for a pomerium session. The provided token's
// audience ('aud') attribute must match Pomerium's client_id.
func (a *Authenticate) ExchangeToken(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
code := r.Form.Get("id_token")
code := r.FormValue("id_token")
if code == "" {
httputil.ErrorResponse(w, r, httputil.Error("missing id token", http.StatusBadRequest, nil))
return

View file

@ -21,6 +21,7 @@ func testAuthenticate() *Authenticate {
var auth Authenticate
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
auth.SharedKey = "IzY7MOZwzfOkmELXgozHDKTxoT3nOYhwkcmUVINsRww="
auth.cookieSecret = []byte(auth.SharedKey)
auth.templates = templates.New()
return &auth
}
@ -51,6 +52,7 @@ func TestAuthenticate_Handler(t *testing.T) {
t.Error("handler cannot be nil")
}
req := httptest.NewRequest("GET", "/robots.txt", nil)
req.Header.Set("Accept", "application/json")
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
@ -63,6 +65,7 @@ func TestAuthenticate_Handler(t *testing.T) {
}
func TestAuthenticate_SignIn(t *testing.T) {
t.Parallel()
tests := []struct {
name string
state string
@ -76,36 +79,35 @@ func TestAuthenticate_SignIn(t *testing.T) {
{"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
{"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockCipher{}, http.StatusFound},
{"session expired good refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound},
{"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusBadRequest}, // mocking hmac is meh
{"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusBadRequest},
{"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusFound}, // mocking hmac is meh
{"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound},
// {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
{"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
{"malformed form", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
{"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
{"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
{"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
{"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
// actually caught by go's handler, but we should keep the test.
{"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
{"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusInternalServerError},
{"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
{"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Authenticate{
sessionStore: tt.session,
provider: tt.provider,
RedirectURL: uriParse("https://some.example"),
csrfStore: &sessions.MockCSRFStore{},
RedirectURL: uriParseHelper("https://some.example"),
SharedKey: "secret",
cipher: tt.cipher,
}
uri := &url.URL{Host: "corp.some.example", Scheme: "https", Path: "/"}
if tt.name == "malformed form" {
uri.RawQuery = "example=%zzzzz"
} else {
uri.RawQuery = fmt.Sprintf("%s&redirect_uri=%s", tt.state, tt.redirectURI)
}
uri.RawQuery = fmt.Sprintf("%s&redirect_uri=%s", tt.state, tt.redirectURI)
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
r.Header.Set("Accept", "application/json")
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, nil)
r = r.WithContext(ctx)
w := httptest.NewRecorder()
a.SignIn(w, r)
@ -117,61 +119,18 @@ func TestAuthenticate_SignIn(t *testing.T) {
}
}
type mockCipher struct{}
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Marshal(s interface{}) (string, error) { return "ok", nil }
func (a mockCipher) Unmarshal(s string, i interface{}) error {
if s == "unmarshal error" || s == "error" {
return errors.New("error")
}
return nil
}
func Test_getAuthCodeRedirectURL(t *testing.T) {
tests := []struct {
name string
redirectURL *url.URL
state string
authCode string
want string
}{
{"https", uriParse("https://www.pomerium.io"), "state", "auth-code", "https://www.pomerium.io?code=auth-code&state=state"},
{"http", uriParse("http://www.pomerium.io"), "state", "auth-code", "http://www.pomerium.io?code=auth-code&state=state"},
{"no subdomain", uriParse("http://pomerium.io"), "state", "auth-code", "http://pomerium.io?code=auth-code&state=state"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getAuthCodeRedirectURL(tt.redirectURL, tt.state, tt.authCode); got != tt.want {
t.Errorf("getAuthCodeRedirectURL() = %v, want %v", got, tt.want)
}
})
}
}
func uriParse(s string) *url.URL {
func uriParseHelper(s string) *url.URL {
uri, _ := url.Parse(s)
return uri
}
func TestAuthenticate_SignOut(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
ctxError error
redirectURL string
sig string
ts string
@ -181,17 +140,16 @@ func TestAuthenticate_SignOut(t *testing.T) {
wantCode int
wantBody string
}{
{"good post", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""},
{"failed revoke", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"},
{"malformed form", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusInternalServerError, ""},
{"load session error", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""},
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""},
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"},
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, ""},
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Authenticate{
sessionStore: tt.sessionStore,
provider: tt.provider,
cipher: mockCipher{},
templates: templates.New(),
}
u, _ := url.Parse("/sign_out")
@ -200,10 +158,11 @@ func TestAuthenticate_SignOut(t *testing.T) {
params.Add("ts", tt.ts)
params.Add("redirect_uri", tt.redirectURL)
u.RawQuery = params.Encode()
if tt.name == "malformed form" {
u.RawQuery = "example=%zzzzz"
}
r := httptest.NewRequest(tt.method, u.String(), nil)
state, _ := tt.sessionStore.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
w := httptest.NewRecorder()
a.SignOut(w, r)
@ -217,64 +176,8 @@ func TestAuthenticate_SignOut(t *testing.T) {
}
}
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) string {
data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix()))
h := cryptutil.Hash(secret, data)
return base64.URLEncoding.EncodeToString(h)
}
func TestAuthenticate_OAuthStart(t *testing.T) {
tests := []struct {
name string
method string
redirectURLSetting string
redirectURL string
sig string
ts string
provider identity.Authenticator
csrfStore sessions.MockCSRFStore
// sessionStore sessions.SessionStore
wantCode int
}{
{"good", http.MethodGet, "https://corp.pomerium.io/", "https://corp.pomerium.io/", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusFound},
{"bad timestamp", http.MethodGet, "https://corp.pomerium.io/", "https://corp.pomerium.io/", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Add(10 * time.Hour).Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest},
{"missing redirect", http.MethodGet, "https://corp.pomerium.io/", "", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest},
{"malformed redirect", http.MethodGet, "https://corp.pomerium.io/", "https://pomerium.com%zzzzz", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest},
{"different domains", http.MethodGet, "https://corp.notpomerium.io/", "https://corp.pomerium.io/", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Authenticate{
RedirectURL: uriParse(tt.redirectURLSetting),
csrfStore: tt.csrfStore,
provider: tt.provider,
SharedKey: "secret",
cipher: mockCipher{},
}
u, _ := url.Parse("/oauth_start")
params, _ := url.ParseQuery(u.RawQuery)
params.Add("sig", tt.sig)
params.Add("ts", tt.ts)
params.Add("redirect_uri", tt.redirectURL)
u.RawQuery = params.Encode()
r := httptest.NewRequest(tt.method, u.String(), nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
a.OAuthStart(w, r)
if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v\n%v", status, tt.wantCode, w.Body.String())
}
})
}
}
func TestAuthenticate_OAuthCallback(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
@ -286,24 +189,20 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
authenticateURL string
session sessions.SessionStore
provider identity.MockProvider
csrfStore sessions.MockCSRFStore
want string
wantCode int
}{
{"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusFound},
{"get csrf error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", GetError: errors.New("error"), Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError},
{"csrf nonce error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError},
{"failed authenticate", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
{"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
{"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
{"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
{"invalid state string", http.MethodGet, "", "code", "nonce:https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
{"malformed state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
{"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
{"malformed form", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
{"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"different domains", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://some.example.notpomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusFound},
{"failed authenticate", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
{"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusInternalServerError},
{"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
{"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
{"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
{"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad base64 state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")) + "%", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"too many state delimeters", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io:wait")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"too few state delimeters", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -311,7 +210,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
a := &Authenticate{
RedirectURL: authURL,
sessionStore: tt.session,
csrfStore: tt.csrfStore,
provider: tt.provider,
}
u, _ := url.Parse("/oauthGet")
@ -322,9 +220,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
u.RawQuery = params.Encode()
if tt.name == "malformed form" {
u.RawQuery = "example=%zzzzz"
}
r := httptest.NewRequest(tt.method, u.String(), nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
@ -339,6 +234,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
}
func TestAuthenticate_ExchangeToken(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
@ -384,3 +280,55 @@ func TestAuthenticate_ExchangeToken(t *testing.T) {
})
}
}
func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprintln(w, "RVSI FILIVS CAISAR")
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
session sessions.SessionStore
ctxError error
provider identity.Authenticator
wantStatus int
}{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusOK},
{"expired,refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
{"expired,save error", &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := Authenticate{
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
sessionStore: tt.session,
provider: tt.provider,
}
r := httptest.NewRequest("GET", "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
got := a.VerifySession(fn)
got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
})
}
}

View file

@ -14,7 +14,7 @@ import (
func ValidateOptions(o config.Options) error {
decoded, err := base64.StdEncoding.DecodeString(o.SharedKey)
if err != nil {
return fmt.Errorf("authorize: `SHARED_SECRET` setting is invalid base64: %v", err)
return fmt.Errorf("authorize: `SHARED_SECRET` malformed base64: %v", err)
}
if len(decoded) != 32 {
return fmt.Errorf("authorize: `SHARED_SECRET` want 32 but got %d bytes", len(decoded))

View file

@ -8,6 +8,7 @@ import (
"time"
"github.com/fsnotify/fsnotify"
"github.com/gorilla/mux"
"github.com/spf13/viper"
"google.golang.org/grpc"
@ -44,9 +45,9 @@ func main() {
setupTracing(opt)
setupHTTPRedirectServer(opt)
mux := http.NewServeMux()
r := newGlobalRouter(opt)
grpcServer := setupGRPCServer(opt)
_, err = newAuthenticateService(*opt, mux)
_, err = newAuthenticateService(*opt, r.Host(urlutil.StripPort(opt.AuthenticateURL.Host)).Subrouter())
if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: authenticate")
}
@ -56,7 +57,7 @@ func main() {
log.Fatal().Err(err).Msg("cmd/pomerium: authorize")
}
proxy, err := newProxyService(*opt, mux)
proxy, err := newProxyService(*opt, r)
if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: proxy")
}
@ -70,8 +71,7 @@ func main() {
log.Info().Str("file", e.Name).Msg("cmd/pomerium: config file changed")
opt = config.HandleConfigUpdate(*configFile, opt, []config.OptionsUpdater{authz, proxy})
})
srv, err := httputil.NewTLSServer(configToServerOptions(opt), mainHandler(opt, mux), grpcServer)
srv, err := httputil.NewTLSServer(configToServerOptions(opt), r, grpcServer)
if err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium: couldn't start pomerium")
}
@ -80,7 +80,7 @@ func main() {
os.Exit(0)
}
func newAuthenticateService(opt config.Options, mux *http.ServeMux) (*authenticate.Authenticate, error) {
func newAuthenticateService(opt config.Options, r *mux.Router) (*authenticate.Authenticate, error) {
if !config.IsAuthenticate(opt.Services) {
return nil, nil
}
@ -88,7 +88,7 @@ func newAuthenticateService(opt config.Options, mux *http.ServeMux) (*authentica
if err != nil {
return nil, err
}
mux.Handle(urlutil.StripPort(opt.AuthenticateURL.Host)+"/", service.Handler())
r.PathPrefix("/").Handler(service.Handler())
return service, nil
}
@ -104,7 +104,7 @@ func newAuthorizeService(opt config.Options, rpc *grpc.Server) (*authorize.Autho
return service, nil
}
func newProxyService(opt config.Options, mux *http.ServeMux) (*proxy.Proxy, error) {
func newProxyService(opt config.Options, r *mux.Router) (*proxy.Proxy, error) {
if !config.IsProxy(opt.Services) {
return nil, nil
}
@ -112,15 +112,15 @@ func newProxyService(opt config.Options, mux *http.ServeMux) (*proxy.Proxy, erro
if err != nil {
return nil, err
}
mux.Handle("/", service.Handler())
r.PathPrefix("/").Handler(service.Handler())
return service, nil
}
func mainHandler(o *config.Options, mux http.Handler) http.Handler {
c := middleware.NewChain()
c = c.Append(metrics.HTTPMetricsHandler(o.Services))
c = c.Append(log.NewHandler(log.Logger))
c = c.Append(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
func newGlobalRouter(o *config.Options) *mux.Router {
mux := httputil.NewRouter()
mux.Use(metrics.HTTPMetricsHandler(o.Services))
mux.Use(log.NewHandler(log.Logger))
mux.Use(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
log.FromRequest(r).Debug().
Dur("duration", duration).
Int("size", size).
@ -133,15 +133,15 @@ func mainHandler(o *config.Options, mux http.Handler) http.Handler {
Msg("http-request")
}))
if len(o.Headers) != 0 {
c = c.Append(middleware.SetHeaders(o.Headers))
mux.Use(middleware.SetHeaders(o.Headers))
}
c = c.Append(log.ForwardedAddrHandler("fwd_ip"))
c = c.Append(log.RemoteAddrHandler("ip"))
c = c.Append(log.UserAgentHandler("user_agent"))
c = c.Append(log.RefererHandler("referer"))
c = c.Append(log.RequestIDHandler("req_id", "Request-Id"))
c = c.Append(middleware.Healthcheck("/ping", version.UserAgent()))
return c.Then(mux)
mux.Use(log.ForwardedAddrHandler("fwd_ip"))
mux.Use(log.RemoteAddrHandler("ip"))
mux.Use(log.UserAgentHandler("user_agent"))
mux.Use(log.RefererHandler("referer"))
mux.Use(log.RequestIDHandler("req_id", "Request-Id"))
mux.Use(middleware.Healthcheck("/ping", version.UserAgent()))
return mux
}
func configToServerOptions(opt *config.Options) *httputil.ServerOptions {

View file

@ -21,7 +21,7 @@ import (
)
func Test_newAuthenticateService(t *testing.T) {
mux := http.NewServeMux()
mux := httputil.NewRouter()
tests := []struct {
name string
@ -127,7 +127,7 @@ func Test_newProxyeService(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mux := http.NewServeMux()
mux := httputil.NewRouter()
testOpts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
if err != nil {
t.Fatal(err)
@ -161,7 +161,7 @@ func Test_newProxyeService(t *testing.T) {
}
}
func Test_mainHandler(t *testing.T) {
func Test_newGlobalRouter(t *testing.T) {
o := config.Options{
Services: "all",
Headers: map[string]string{
@ -172,7 +172,6 @@ func Test_mainHandler(t *testing.T) {
"Content-Security-Policy": "default-src 'none'; style-src 'self' 'sha256-pSTVzZsFAqd2U3QYu+BoBDtuJWaPM/+qMy/dBRrhb5Y='; img-src 'self';",
"Referrer-Policy": "Same-origin",
}}
mux := http.NewServeMux()
req := httptest.NewRequest(http.MethodGet, "/404", nil)
rr := httptest.NewRecorder()
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -181,8 +180,9 @@ func Test_mainHandler(t *testing.T) {
io.WriteString(w, `OK`)
})
mux.Handle("/404", h)
out := mainHandler(&o, mux)
out := newGlobalRouter(&o)
out.Handle("/404", h)
out.ServeHTTP(rr, req)
expected := fmt.Sprintf("OK")
body := rr.Body.String()

View file

@ -5,6 +5,23 @@
### New
- Add ability to override HTTPS backend's TLS Server Name. [GH-297](https://github.com/pomerium/pomerium/pull/297)
- Add ability to set pomerium's encrypted session in a auth bearer token, or query param.
### Security
- Under certain circumstances, where debug logging was enabled, pomerium's shared secret could be leaked to http access logs as a query param.
### Fixed
- Fixed an issue where CSRF would fail if multiple tabs were open. [GH-306](https://github.com/pomerium/pomerium/issues/306)
### Changed
- Authenticate service no longer uses gRPC.
### Removed
- Removed `AUTHENTICATE_INTERNAL_URL`/`authenticate_internal_url` which is no longer used.
## v0.3.0

View file

@ -228,8 +228,8 @@ Each unit work is called a Span in a trace. Spans include metadata about the wor
| Config Key | Description | Required |
| :--------------- | :---------------------------------------------------------------- | -------- |
| tracing_provider | The name of the tracing provider. (e.g. jaeger) | ✅ |
| tracing_debug | Will disable [sampling](https://opencensus.io/tracing/sampling/). | ❌ |
| tracing_provider | The name of the tracing provider. (e.g. jaeger) | ✅ |
| tracing_debug | Will disable [sampling](https://opencensus.io/tracing/sampling/). | ❌ |
### Jaeger
@ -243,8 +243,8 @@ Each unit work is called a Span in a trace. Spans include metadata about the wor
| Config Key | Description | Required |
| :-------------------------------- | :------------------------------------------ | -------- |
| tracing_jaeger_collector_endpoint | Url to the Jaeger HTTP Thrift collector. | ✅ |
| tracing_jaeger_agent_endpoint | Send spans to jaeger-agent at this address. | ✅ |
| tracing_jaeger_collector_endpoint | Url to the Jaeger HTTP Thrift collector. | ✅ |
| tracing_jaeger_agent_endpoint | Send spans to jaeger-agent at this address. | ✅ |
#### Example
@ -478,11 +478,11 @@ Authenticate Service URL is the externally accessible URL for the authenticate s
- Config File Key: `authorize_service_url`
- Type: `URL`
- Required
- Example: `https://access.corp.example.com` or `https://pomerium-authorize-service.default.svc.cluster.local`
- Example: `https://authorize.corp.example.com` or `https://pomerium-authorize-service.default.svc.cluster.local`
Authorize Service URL is the location of the internally accessible authorize service. NOTE: Unlike authenticate, authorize has no publicly accessible http handlers so this setting is purely for gRPC communication.
If your load balancer does not support gRPC pass-through you'll need to set this value to an internally routable location (`https://pomerium-authorize-service.default.svc.cluster.local`) instead of an externally routable one (`https://access.corp.example.com`).
If your load balancer does not support gRPC pass-through you'll need to set this value to an internally routable location (`https://pomerium-authorize-service.default.svc.cluster.local`) instead of an externally routable one (`https://authorize.corp.example.com`).
## Override Certificate Name

4
go.mod
View file

@ -9,10 +9,12 @@ require (
github.com/fsnotify/fsnotify v1.4.7
github.com/golang/mock v1.3.1
github.com/golang/protobuf v1.3.1
github.com/google/go-cmp v0.3.0
github.com/google/go-cmp v0.3.1
github.com/gorilla/mux v1.6.2
github.com/magiconair/properties v1.8.1 // indirect
github.com/mitchellh/hashstructure v1.0.0
github.com/pelletier/go-toml v1.4.0 // indirect
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30
github.com/pomerium/go-oidc v2.0.0+incompatible
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/prometheus/client_golang v0.9.3

8
go.sum
View file

@ -65,11 +65,16 @@ github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/mux v1.6.2 h1:Pgr17XVTNXAk3q/r4CpKzC5xBM/qW1uVLV+IhRZpIIk=
github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0 h1:Iju5GlWwrvL6UBg4zJJt3btmonfrMlCDdsejg4CZE7c=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
@ -115,9 +120,12 @@ github.com/pelletier/go-toml v1.4.0 h1:u3Z1r+oOXJIkxqw34zVhyPgjBsm6X2wn21NWs/HfS
github.com/pelletier/go-toml v1.4.0/go.mod h1:PN7xzY2wHTK0K9p34ErDQMlFxa51Fk0OUruD3k1mMwo=
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30 h1:jggCv6hZvcxjGa3gqkYY2EUuOkITI9Znugz/f3QJfRQ=
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0=
github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o=
github.com/pomerium/go-oidc v2.0.0+incompatible/go.mod h1:DRsGVw6MOgxbfq4Y57jKOE8lbEfayxeiY0A8/4vxjBM=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=

View file

@ -217,7 +217,7 @@ func (o *Options) Validate() error {
// shared key must be set for all modes other than "all"
if o.SharedKey == "" {
if o.Services == "all" {
o.SharedKey = cryptutil.GenerateRandomString(32)
o.SharedKey = cryptutil.NewBase64Key()
} else {
return errors.New("shared-key cannot be empty")
}

View file

@ -116,3 +116,9 @@ func (p *Policy) Validate() error {
return nil
}
func (p *Policy) String() string {
if p.Source == nil || p.Destination == nil {
return fmt.Sprintf("%s → %s", p.From, p.To)
}
return fmt.Sprintf("%s → %s", p.Source.String(), p.Destination.String())
}

View file

@ -1,4 +1,4 @@
package config // import "github.com/pomerium/pomerium/internal/config"
package config
import (
"testing"
@ -44,3 +44,28 @@ func Test_Validate(t *testing.T) {
})
}
}
func TestPolicy_String(t *testing.T) {
t.Parallel()
tests := []struct {
name string
From string
To string
want string
}{
{"good", "https://pomerium.io", "https://localhost", "https://pomerium.io → https://localhost"},
{"failed to validate", "https://pomerium.io", "localhost", "https://pomerium.io → localhost"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Policy{
From: tt.From,
To: tt.To,
}
p.Validate()
if got := p.String(); got != tt.want {
t.Errorf("Policy.String() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -13,20 +13,27 @@ import (
"golang.org/x/crypto/chacha20poly1305"
)
// DefaultKeySize is the default key size in bytes.
const DefaultKeySize = 32
// GenerateKey generates a random 32-byte key.
// NewKey generates a random 32-byte key.
//
// Panics if source of randomness fails.
func GenerateKey() []byte {
func NewKey() []byte {
return randomBytes(DefaultKeySize)
}
// GenerateRandomString returns base64 encoded securely generated random string
// of a given set of bytes.
// NewBase64Key generates a random base64 encoded 32-byte key.
//
// Panics if source of randomness fails.
func GenerateRandomString(c int) string {
func NewBase64Key() string {
return NewRandomStringN(DefaultKeySize)
}
// NewRandomStringN returns base64 encoded random string of a given num of bytes.
//
// Panics if source of randomness fails.
func NewRandomStringN(c int) string {
return base64.StdEncoding.EncodeToString(randomBytes(c))
}

View file

@ -1,4 +1,4 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
package cryptutil
import (
"crypto/rand"
@ -13,7 +13,7 @@ import (
func TestEncodeAndDecodeAccessToken(t *testing.T) {
plaintext := []byte("my plain text value")
key := GenerateKey()
key := NewKey()
c, err := NewCipher(key)
if err != nil {
t.Fatalf("unexpected err: %v", err)
@ -47,7 +47,7 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
}
func TestMarshalAndUnmarshalStruct(t *testing.T) {
key := GenerateKey()
key := NewKey()
c, err := NewCipher(key)
if err != nil {
@ -102,7 +102,7 @@ func TestMarshalAndUnmarshalStruct(t *testing.T) {
}
func TestCipherDataRace(t *testing.T) {
cipher, err := NewCipher(GenerateKey())
cipher, err := NewCipher(NewKey())
if err != nil {
t.Fatalf("unexpected generating cipher err: %v", err)
}
@ -183,21 +183,21 @@ func TestGenerateRandomString(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := GenerateRandomString(tt.c)
o := NewRandomStringN(tt.c)
b, err := base64.StdEncoding.DecodeString(o)
if err != nil {
t.Error(err)
}
got := len(b)
if got != tt.want {
t.Errorf("GenerateRandomString() = %d, want %d", got, tt.want)
t.Errorf("NewRandomStringN() = %d, want %d", got, tt.want)
}
})
}
}
func TestXChaCha20Cipher_Marshal(t *testing.T) {
t.Parallel()
tests := []struct {
name string
s interface{}
@ -225,7 +225,7 @@ func TestXChaCha20Cipher_Marshal(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewCipher(GenerateKey())
c, err := NewCipher(NewKey())
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
@ -239,15 +239,15 @@ func TestXChaCha20Cipher_Marshal(t *testing.T) {
}
func TestNewCipher(t *testing.T) {
t.Parallel()
tests := []struct {
name string
secret []byte
wantErr bool
}{
{"simple 32 byte key", GenerateKey(), false},
{"simple 32 byte key", NewKey(), false},
{"key too short", []byte("what is entropy"), true},
{"key too long", []byte(GenerateRandomString(33)), true},
{"key too long", []byte(NewRandomStringN(33)), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -261,16 +261,16 @@ func TestNewCipher(t *testing.T) {
}
func TestNewCipherFromBase64(t *testing.T) {
t.Parallel()
tests := []struct {
name string
s string
wantErr bool
}{
{"simple 32 byte key", base64.StdEncoding.EncodeToString(GenerateKey()), false},
{"simple 32 byte key", base64.StdEncoding.EncodeToString(NewKey()), false},
{"key too short", base64.StdEncoding.EncodeToString([]byte("what is entropy")), true},
{"key too long", GenerateRandomString(33), true},
{"bad base 64", string(GenerateKey()), true},
{"key too long", NewRandomStringN(33), true},
{"bad base 64", string(NewKey()), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -282,3 +282,26 @@ func TestNewCipherFromBase64(t *testing.T) {
})
}
}
func TestNewBase64Key(t *testing.T) {
t.Parallel()
tests := []struct {
name string
want int
}{
{"simple", 32},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := NewBase64Key()
b, err := base64.StdEncoding.DecodeString(o)
if err != nil {
t.Error(err)
}
got := len(b)
if got != tt.want {
t.Errorf("NewBase64Key() = %d, want %d", got, tt.want)
}
})
}
}

View file

@ -0,0 +1,19 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
"net/http"
"github.com/gorilla/mux"
"github.com/pomerium/csrf"
)
// NewRouter returns a new router instance.
func NewRouter() *mux.Router {
return mux.NewRouter()
}
// CSRFFailureHandler sets a HTTP 403 Forbidden status and writes the
// CSRF failure reason to the response.
func CSRFFailureHandler(w http.ResponseWriter, r *http.Request) {
ErrorResponse(w, r, Error("CSRF Failure", http.StatusForbidden, csrf.FailureReason(r)))
}

View file

@ -0,0 +1,37 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestCSRFFailureHandler(t *testing.T) {
tests := []struct {
name string
wantBody string
wantStatus int
}{
{"basic csrf failure", "{\"error\":\"CSRF Failure\"}\n", http.StatusForbidden},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
CSRFFailureHandler(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %s", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %s", diff)
}
})
}
}

View file

@ -1,109 +0,0 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import "net/http"
// Constructor is a type alias for func(http.Handler) http.Handler
type Constructor func(http.Handler) http.Handler
// Chain acts as a list of http.Handler constructors.
// Chain is effectively immutable:
// once created, it will always hold
// the same set of constructors in the same order.
type Chain struct {
constructors []Constructor
}
// NewChain creates a new chain,
// memorizing the given list of middleware constructors.
// New serves no other function,
// constructors are only called upon a call to Then().
func NewChain(constructors ...Constructor) Chain {
return Chain{append([]Constructor(nil), constructors...)}
}
// Then chains the middleware and returns the final http.Handler.
// NewChain(m1, m2, m3).Then(h)
// is equivalent to:
// m1(m2(m3(h)))
// When the request comes in, it will be passed to m1, then m2, then m3
// and finally, the given handler
// (assuming every middleware calls the following one).
//
// A chain can be safely reused by calling Then() several times.
// stdStack := middleware.NewChain(ratelimitHandler, csrfHandler)
// indexPipe = stdStack.Then(indexHandler)
// authPipe = stdStack.Then(authHandler)
// Note that constructors are called on every call to Then()
// and thus several instances of the same middleware will be created
// when a chain is reused in this way.
// For proper middleware, this should cause no problems.
//
// Then() treats nil as http.DefaultServeMux.
func (c Chain) Then(h http.Handler) http.Handler {
if h == nil {
h = http.DefaultServeMux
}
for i := range c.constructors {
h = c.constructors[len(c.constructors)-1-i](h)
}
return h
}
// ThenFunc works identically to Then, but takes
// a HandlerFunc instead of a Handler.
//
// The following two statements are equivalent:
// c.Then(http.HandlerFunc(fn))
// c.ThenFunc(fn)
//
// ThenFunc provides all the guarantees of Then.
func (c Chain) ThenFunc(fn http.HandlerFunc) http.Handler {
if fn == nil {
return c.Then(nil)
}
return c.Then(fn)
}
// Append extends a chain, adding the specified constructors
// as the last ones in the request flow.
//
// Append returns a new chain, leaving the original one untouched.
//
// stdChain := middleware.NewChain(m1, m2)
// extChain := stdChain.Append(m3, m4)
// // requests in stdChain go m1 -> m2
// // requests in extChain go m1 -> m2 -> m3 -> m4
func (c Chain) Append(constructors ...Constructor) Chain {
newCons := make([]Constructor, 0, len(c.constructors)+len(constructors))
newCons = append(newCons, c.constructors...)
newCons = append(newCons, constructors...)
return Chain{newCons}
}
// Extend extends a chain by adding the specified chain
// as the last one in the request flow.
//
// Extend returns a new chain, leaving the original one untouched.
//
// stdChain := middleware.NewChain(m1, m2)
// ext1Chain := middleware.NewChain(m3, m4)
// ext2Chain := stdChain.Extend(ext1Chain)
// // requests in stdChain go m1 -> m2
// // requests in ext1Chain go m3 -> m4
// // requests in ext2Chain go m1 -> m2 -> m3 -> m4
//
// Another example:
// aHtmlAfterNosurf := middleware.NewChain(m2)
// aHtml := middleware.NewChain(m1, func(h http.Handler) http.Handler {
// csrf := nosurf.NewChain(h)
// csrf.SetFailureHandler(aHtmlAfterNosurf.ThenFunc(csrfFail))
// return csrf
// }).Extend(aHtmlAfterNosurf)
// // requests to aHtml hitting nosurfs success handler go m1 -> nosurf -> m2 -> target-handler
// // requests to aHtml hitting nosurfs failure handler go m1 -> nosurf -> m2 -> csrfFail
func (c Chain) Extend(chain Chain) Chain {
return c.Append(chain.constructors...)
}

View file

@ -1,177 +0,0 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
)
// A constructor for middleware
// that writes its own "tag" into the RW and does nothing else.
// Useful in checking if a chain is behaving in the right order.
func tagMiddleware(tag string) Constructor {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(tag))
h.ServeHTTP(w, r)
})
}
}
// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer),
// but the best we can do.
func funcsEqual(f1, f2 interface{}) bool {
val1 := reflect.ValueOf(f1)
val2 := reflect.ValueOf(f2)
return val1.Pointer() == val2.Pointer()
}
var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("app\n"))
})
func TestNew(t *testing.T) {
c1 := func(h http.Handler) http.Handler {
return nil
}
c2 := func(h http.Handler) http.Handler {
return http.StripPrefix("potato", nil)
}
slice := []Constructor{c1, c2}
chain := NewChain(slice...)
for k := range slice {
if !funcsEqual(chain.constructors[k], slice[k]) {
t.Error("New does not add constructors correctly")
}
}
}
func TestThenWorksWithNoMiddleware(t *testing.T) {
if !funcsEqual(NewChain().Then(testApp), testApp) {
t.Error("Then does not work with no middleware")
}
}
func TestThenTreatsNilAsDefaultServeMux(t *testing.T) {
if NewChain().Then(nil) != http.DefaultServeMux {
t.Error("Then does not treat nil as DefaultServeMux")
}
}
func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) {
if NewChain().ThenFunc(nil) != http.DefaultServeMux {
t.Error("ThenFunc does not treat nil as DefaultServeMux")
}
}
func TestThenFuncConstructsHandlerFunc(t *testing.T) {
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
chained := NewChain().ThenFunc(fn)
rec := httptest.NewRecorder()
chained.ServeHTTP(rec, (*http.Request)(nil))
if reflect.TypeOf(chained) != reflect.TypeOf(http.HandlerFunc(nil)) {
t.Error("ThenFunc does not construct HandlerFunc")
}
}
func TestThenOrdersHandlersCorrectly(t *testing.T) {
t1 := tagMiddleware("t1\n")
t2 := tagMiddleware("t2\n")
t3 := tagMiddleware("t3\n")
chained := NewChain(t1, t2, t3).Then(testApp)
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
chained.ServeHTTP(w, r)
if w.Body.String() != "t1\nt2\nt3\napp\n" {
t.Error("Then does not order handlers correctly")
}
}
func TestAppendAddsHandlersCorrectly(t *testing.T) {
chain := NewChain(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
newChain := chain.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
if len(chain.constructors) != 2 {
t.Error("chain should have 2 constructors")
}
if len(newChain.constructors) != 4 {
t.Error("newChain should have 4 constructors")
}
chained := newChain.Then(testApp)
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
chained.ServeHTTP(w, r)
if w.Body.String() != "t1\nt2\nt3\nt4\napp\n" {
t.Error("Append does not add handlers correctly")
}
}
func TestAppendRespectsImmutability(t *testing.T) {
chain := NewChain(tagMiddleware(""))
newChain := chain.Append(tagMiddleware(""))
if &chain.constructors[0] == &newChain.constructors[0] {
t.Error("Apppend does not respect immutability")
}
}
func TestExtendAddsHandlersCorrectly(t *testing.T) {
chain1 := NewChain(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
chain2 := NewChain(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
newChain := chain1.Extend(chain2)
if len(chain1.constructors) != 2 {
t.Error("chain1 should contain 2 constructors")
}
if len(chain2.constructors) != 2 {
t.Error("chain2 should contain 2 constructors")
}
if len(newChain.constructors) != 4 {
t.Error("newChain should contain 4 constructors")
}
chained := newChain.Then(testApp)
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
chained.ServeHTTP(w, r)
if w.Body.String() != "t1\nt2\nt3\nt4\napp\n" {
t.Error("Extend does not add handlers in correctly")
}
}
func TestExtendRespectsImmutability(t *testing.T) {
chain := NewChain(tagMiddleware(""))
newChain := chain.Extend(NewChain(tagMiddleware("")))
if &chain.constructors[0] == &newChain.constructors[0] {
t.Error("Extend does not respect immutability")
}
}

View file

@ -1,2 +1,2 @@
// Package middleware provides a standard set of middleware implementations for pomerium.
// Package middleware provides a standard set of middleware for pomerium.
package middleware // import "github.com/pomerium/pomerium/internal/middleware"

View file

@ -79,12 +79,10 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
func (cs *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
domain := req.Host
if name == cs.csrfName() {
domain = req.Host
} else if cs.CookieDomain != "" {
if cs.CookieDomain != "" {
domain = cs.CookieDomain
} else {
domain = splitDomain(domain)
domain = ParentSubdomain(domain)
}
if h, _, err := net.SplitHostPort(domain); err == nil {
@ -105,19 +103,11 @@ func (cs *CookieStore) makeCookie(req *http.Request, name string, value string,
return c
}
func (cs *CookieStore) csrfName() string {
return fmt.Sprintf("%s_csrf", cs.Name)
}
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return cs.makeCookie(req, cs.Name, value, expiration, now)
}
func (cs *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return cs.makeCookie(req, cs.csrfName(), value, expiration, now)
}
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
if len(cookie.String()) <= MaxChunkSize {
http.SetCookie(w, cookie)
@ -134,7 +124,6 @@ func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
nc.Name = fmt.Sprintf("%s_%d", cookie.Name, i)
nc.Value = c
}
fmt.Println(i)
http.SetCookie(w, &nc)
}
}
@ -150,25 +139,6 @@ func chunk(s string, size int) []string {
return ss
}
// ClearCSRF clears the CSRF cookie from the request
func (cs *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, cs.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
}
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
func (cs *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(w, cs.makeCSRFCookie(req, val, cs.CookieExpire, time.Now()))
}
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
func (cs *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
c, err := req.Cookie(cs.csrfName())
if err != nil {
return nil, ErrEmptyCSRF // ErrNoCookie is confusing in this context
}
return c, nil
}
// ClearSession clears the session cookie from a request
func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
@ -235,7 +205,8 @@ func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *
return nil
}
func splitDomain(s string) string {
// ParentSubdomain returns the parent subdomain.
func ParentSubdomain(s string) string {
if strings.Count(s, ".") < 2 {
return ""
}

View file

@ -1,4 +1,4 @@
package sessions
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"crypto/rand"
@ -38,7 +38,7 @@ func (a mockCipher) Unmarshal(s string, i interface{}) error {
return nil
}
func TestNewCookieStore(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
@ -111,7 +111,7 @@ func TestNewCookieStore(t *testing.T) {
}
func TestCookieStore_makeCookie(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
@ -155,62 +155,13 @@ func TestCookieStore_makeCookie(t *testing.T) {
if diff := cmp.Diff(s.makeSessionCookie(r, tt.value, tt.expiration, now), tt.want); diff != "" {
t.Errorf("CookieStore.makeSessionCookie() = \n%s", diff)
}
got := s.makeCSRFCookie(r, tt.value, tt.expiration, now)
tt.wantCSRF.Name = "_pomerium_csrf"
if !reflect.DeepEqual(got, tt.wantCSRF) {
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.wantCSRF)
}
w := httptest.NewRecorder()
want := "new-csrf"
s.SetCSRF(w, r, want)
found := false
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name+"_csrf" && cookie.Value == want {
found = true
break
}
}
if !found {
t.Error("SetCSRF failed")
}
w = httptest.NewRecorder()
s.ClearCSRF(w, r)
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name+"_csrf" && cookie.Value == want {
t.Error("clear csrf failed")
break
}
}
w = httptest.NewRecorder()
want = "new-session"
s.setSessionCookie(w, r, want)
found = false
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name && cookie.Value == want {
found = true
break
}
}
if !found {
t.Error("SetCSRF failed")
}
w = httptest.NewRecorder()
s.ClearSession(w, r)
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name && cookie.Value == want {
t.Error("clear csrf failed")
break
}
}
})
}
}
func TestCookieStore_SaveSession(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
@ -265,38 +216,6 @@ func TestCookieStore_SaveSession(t *testing.T) {
}
}
func TestMockCSRFStore(t *testing.T) {
tests := []struct {
name string
mockCSRF *MockCSRFStore
newCSRFValue string
wantErr bool
}{
{"basic",
&MockCSRFStore{
ResponseCSRF: "ok",
Cookie: &http.Cookie{Name: "hi"}},
"newcsrf",
false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ms := tt.mockCSRF
ms.SetCSRF(nil, nil, tt.newCSRFValue)
ms.ClearCSRF(nil, nil)
got, err := ms.GetCSRF(nil)
if (err != nil) != tt.wantErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.mockCSRF.Cookie) {
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Cookie)
}
})
}
}
func TestMockSessionStore(t *testing.T) {
tests := []struct {
name string
@ -341,7 +260,7 @@ func TestMockSessionStore(t *testing.T) {
}
}
func Test_splitDomain(t *testing.T) {
func Test_ParentSubdomain(t *testing.T) {
t.Parallel()
tests := []struct {
s string
@ -354,8 +273,8 @@ func Test_splitDomain(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.s, func(t *testing.T) {
if got := splitDomain(tt.s); got != tt.want {
t.Errorf("splitDomain() = %v, want %v", got, tt.want)
if got := ParentSubdomain(tt.s); got != tt.want {
t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want)
}
})
}

View file

@ -0,0 +1,130 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"context"
"errors"
"net/http"
"strings"
)
// Context keys
var (
SessionCtxKey = &contextKey{"Session"}
ErrorCtxKey = &contextKey{"Error"}
)
// Library errors
var (
ErrExpired = errors.New("internal/sessions: session is expired")
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
ErrMalformed = errors.New("internal/sessions: session is malformed")
)
// RetrieveSession http middleware handler will verify a auth session from a http request.
//
// RetrieveSession will search for a auth session in a http request, in the order:
// 1. `pomerium_session` URI query parameter
// 2. `Authorization: BEARER` request header
// 3. Cookie `_pomerium` value
func RetrieveSession(s SessionStore) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return retrieve(s, TokenFromQuery, TokenFromHeader, TokenFromCookie)(next)
}
}
func retrieve(s SessionStore, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
hfn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, err := retrieveFromRequest(s, r, findTokenFns...)
ctx = NewContext(ctx, token, err)
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(hfn)
}
}
func retrieveFromRequest(s SessionStore, r *http.Request, findTokenFns ...func(r *http.Request) string) (*State, error) {
var tokenStr string
var err error
// Extract token string from the request by calling token find functions in
// the order they where provided. Further extraction stops if a function
// returns a non-empty string.
for _, fn := range findTokenFns {
tokenStr = fn(r)
if tokenStr != "" {
break
}
}
if tokenStr == "" {
return nil, ErrNoSessionFound
}
state, err := s.LoadSession(r)
if err != nil {
return nil, ErrMalformed
}
err = state.Valid()
if err != nil {
// a little unusual but we want to return the expired state too
return state, err
}
// Valid!
return state, nil
}
// NewContext sets context values for the user session state and error.
func NewContext(ctx context.Context, t *State, err error) context.Context {
ctx = context.WithValue(ctx, SessionCtxKey, t)
ctx = context.WithValue(ctx, ErrorCtxKey, err)
return ctx
}
// FromContext retrieves context values for the user session state and error.
func FromContext(ctx context.Context) (*State, error) {
state, _ := ctx.Value(SessionCtxKey).(*State)
err, _ := ctx.Value(ErrorCtxKey).(error)
return state, err
}
// TokenFromCookie tries to retrieve the token string from a cookie named
// "_pomerium".
func TokenFromCookie(r *http.Request) string {
cookie, err := r.Cookie("_pomerium")
if err != nil {
return ""
}
return cookie.Value
}
// TokenFromHeader tries to retrieve the token string from the
// "Authorization" request header: "Authorization: BEARER T".
func TokenFromHeader(r *http.Request) string {
// Get token from authorization header.
bearer := r.Header.Get("Authorization")
if len(bearer) > 7 && strings.EqualFold(bearer[0:6], "BEARER") {
return bearer[7:]
}
return ""
}
// TokenFromQuery tries to retrieve the token string from the "pomerium_session" URI
// query parameter.
// todo(bdd) : document setting session code as queryparam
func TokenFromQuery(r *http.Request) string {
// Get token from query param named "pomerium_session".
return r.URL.Query().Get("pomerium_session")
}
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation. This technique
// for defining context keys was copied from Go 1.7's new use of context in net/http.
type contextKey struct {
name string
}
func (k *contextKey) String() string {
return "SessionStore context value " + k.name
}

View file

@ -0,0 +1,133 @@
package sessions
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/google/go-cmp/cmp"
)
func TestNewContext(t *testing.T) {
tests := []struct {
name string
ctx context.Context
t *State
err error
want context.Context
}{
{"simple", context.Background(), &State{Email: "bdd@pomerium.io"}, nil, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctxOut := NewContext(tt.ctx, tt.t, tt.err)
stateOut, errOut := FromContext(ctxOut)
if diff := cmp.Diff(tt.t, stateOut); diff != "" {
t.Errorf("NewContext() = %s", diff)
}
if diff := cmp.Diff(tt.err, errOut); diff != "" {
t.Errorf("NewContext() = %s", diff)
}
})
}
}
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
// s SessionStore
state State
cookie bool
header bool
param bool
wantBody string
wantStatus int
}{
{"good cookie session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, true, false, false, http.StatusText(http.StatusOK), http.StatusOK},
{"expired cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth header session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, false, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth query param session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"no session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewCipherFromBase64(cryptutil.NewBase64Key())
if err != nil {
t.Fatal(err)
}
encSession, err := MarshalSession(&tt.state, cipher)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tt.name, "malformed") {
// add some garbage to the end of the string
encSession += cryptutil.NewBase64Key()
fmt.Println(encSession)
}
cs, err := NewCookieStore(&CookieStoreOptions{
Name: "_pomerium",
CookieCipher: cipher,
})
if err != nil {
t.Fatal(err)
}
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
if tt.cookie {
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: encSession})
} else if tt.header {
r.Header.Set("Authorization", "Bearer "+encSession)
} else if tt.param {
q := r.URL.Query()
q.Add("pomerium_session", encSession)
r.URL.RawQuery = q.Encode()
}
got := RetrieveSession(cs)(testAuthorizer((fnh)))
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
})
}
}

View file

@ -4,28 +4,6 @@ import (
"net/http"
)
// MockCSRFStore is a mock implementation of the CSRF store interface
type MockCSRFStore struct {
ResponseCSRF string
Cookie *http.Cookie
GetError error
}
// SetCSRF sets the ResponseCSRF string to a val
func (ms MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
ms.ResponseCSRF = val
}
// ClearCSRF clears the ResponseCSRF string
func (ms MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
ms.ResponseCSRF = ""
}
// GetCSRF returns the cookie and error
func (ms MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
return ms.Cookie, ms.GetError
}
// MockSessionStore is a mock implementation of the SessionStore interface
type MockSessionStore struct {
ResponseSession string

View file

@ -10,9 +10,6 @@ import (
"github.com/pomerium/pomerium/internal/cryptutil"
)
// ErrExpired is an error for a expired sessions.
var ErrExpired = fmt.Errorf("internal/sessions: expired session")
// State is our object that keeps track of a user's session state
type State struct {
AccessToken string `json:"access_token"`

View file

@ -12,7 +12,7 @@ import (
)
func TestStateSerialization(t *testing.T) {
secret := cryptutil.GenerateKey()
secret := cryptutil.NewKey()
c, err := cryptutil.NewCipher(secret)
if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err)
@ -123,7 +123,7 @@ func TestState_Impersonating(t *testing.T) {
}
func TestMarshalSession(t *testing.T) {
secret := cryptutil.GenerateKey()
secret := cryptutil.NewKey()
c, err := cryptutil.NewCipher(secret)
if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err)

View file

@ -8,16 +8,6 @@ import (
// ErrEmptySession is an error for an empty sessions.
var ErrEmptySession = errors.New("internal/sessions: empty session")
// ErrEmptyCSRF is an error for an empty sessions.
var ErrEmptyCSRF = errors.New("internal/sessions: empty csrf")
// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie
type CSRFStore interface {
SetCSRF(http.ResponseWriter, *http.Request, string)
GetCSRF(*http.Request) (*http.Cookie, error)
ClearCSRF(http.ResponseWriter, *http.Request)
}
// SessionStore has the functions for setting, getting, and clearing the Session cookie
type SessionStore interface {
ClearSession(http.ResponseWriter, *http.Request)

View file

@ -306,7 +306,7 @@ func New() *template.Template {
</svg>
<form method="POST" action="{{.SignoutURL}}">
<section>
<h2>Session</h2>
<h2>Current user</h2>
<p class="message">Your current session details.</p>
<fieldset>
<label>
@ -334,10 +334,22 @@ func New() *template.Template {
</fieldset>
</section>
<div class="flex">
<button class="button half" type="submit">Sign Out</button>
<a href="/.pomerium/refresh" class="button half">Refresh</a>
{{ .csrfField }}
<button class="button full" type="submit">Sign Out</button>
</div>
</form>
<section>
<h2>Refresh Identity</h2>
<p class="message">Pomerium will automatically refresh your user session. However, if your group memberships have recently changed and haven't taken effect yet, you can refresh your session manually.</p>
<form method="POST" action="/.pomerium/refresh">
<div class="flex">
{{ .csrfField }}
<button class="button full" type="submit">Refresh</button>
</div>
</form>
</section>
{{if .IsAdmin}}
<form method="POST" action="/.pomerium/impersonate">
<section>
@ -355,7 +367,7 @@ func New() *template.Template {
</fieldset>
</section>
<div class="flex">
<input name="csrf" type="hidden" value="{{.CSRF}}">
{{ .csrfField }}
<button class="button full" type="submit">Impersonate session</button>
</div>
</form>

View file

@ -1,9 +1,14 @@
package urlutil // import "github.com/pomerium/pomerium/internal/urlutil"
import (
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
)
// StripPort returns a host, without any port number.
@ -32,18 +37,73 @@ func ParseAndValidateURL(rawurl string) (*url.URL, error) {
if err != nil {
return nil, err
}
if u.Scheme == "" {
return nil, fmt.Errorf("%s url does contain a valid scheme. Did you mean https://%s?", rawurl, rawurl)
}
if u.Host == "" {
return nil, fmt.Errorf("%s url does contain a valid hostname", rawurl)
if err := ValidateURL(u); err != nil {
return nil, err
}
return u, nil
}
// ValidateURL wraps standard library's default url.Parse because
// it's much more lenient about what type of urls it accepts than pomerium.
func ValidateURL(u *url.URL) error {
if u == nil {
return fmt.Errorf("nil url")
}
if u.Scheme == "" {
return fmt.Errorf("%s url does contain a valid scheme. Did you mean https://%s?", u.String(), u.String())
}
if u.Host == "" {
return fmt.Errorf("%s url does contain a valid hostname", u.String())
}
return nil
}
func DeepCopy(u *url.URL) (*url.URL, error) {
if u == nil {
return nil, nil
}
return ParseAndValidateURL(u.String())
}
// testTimeNow can be used in tests to set a specific int64 time
var testTimeNow int64
// timestamp returns the current timestamp, in seconds.
//
// For testing purposes, the function that generates the timestamp can be
// overridden. If not set, it will return time.Now().UTC().Unix().
func timestamp() int64 {
if testTimeNow == 0 {
return time.Now().UTC().Unix()
}
return testTimeNow
}
// SignedRedirectURL takes a destination URL and adds redirect_uri to it's
// query params, along with a timestamp and an keyed signature.
func SignedRedirectURL(key string, destination, urlToSign *url.URL) *url.URL {
now := timestamp()
rawURL := urlToSign.String()
params, _ := url.ParseQuery(destination.RawQuery) // handled by incoming mux
params.Set("redirect_uri", rawURL)
params.Set("ts", fmt.Sprint(now))
params.Set("sig", hmacURL(key, rawURL, now))
destination.RawQuery = params.Encode()
return destination
}
// hmacURL takes a redirect url string and timestamp and returns the base64
// encoded HMAC result.
func hmacURL(key, data string, timestamp int64) string {
h := cryptutil.Hash(key, []byte(fmt.Sprint(data, timestamp)))
return base64.URLEncoding.EncodeToString(h)
}
// GetAbsoluteURL returns the current handler's absolute url.
// https://stackoverflow.com/a/23152483
func GetAbsoluteURL(r *http.Request) *url.URL {
u := r.URL
u.Scheme = "https"
u.Host = r.Host
return u
}

View file

@ -1,6 +1,7 @@
package urlutil
import (
"net/http"
"net/url"
"reflect"
"testing"
@ -35,7 +36,7 @@ func Test_StripPort(t *testing.T) {
}
func TestParseAndValidateURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
rawurl string
@ -63,7 +64,7 @@ func TestParseAndValidateURL(t *testing.T) {
}
func TestDeepCopy(t *testing.T) {
t.Parallel()
tests := []struct {
name string
u *url.URL
@ -87,3 +88,90 @@ func TestDeepCopy(t *testing.T) {
})
}
}
func TestValidateURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
u *url.URL
wantErr bool
}{
{"good", &url.URL{Scheme: "https", Host: "some.example"}, false},
{"nil", nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := ValidateURL(tt.u); (err != nil) != tt.wantErr {
t.Errorf("ValidateURL() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestSignedRedirectURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
mockedTime int64
key string
destination *url.URL
urlToSign *url.URL
want *url.URL
}{
{"good", 2, "hunter42", &url.URL{Host: "pomerium.io", Scheme: "https://"}, &url.URL{Host: "pomerium.io", Scheme: "https://", Path: "/ok"}, &url.URL{Host: "pomerium.io", Scheme: "https://", RawQuery: "redirect_uri=https%3A%2F%2F%3A%2F%2Fpomerium.io%2Fok&sig=7jdo1XFcmuhjBHnpfVhll5cXflYByeMnbp5kRz87CVQ%3D&ts=2"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testTimeNow = tt.mockedTime
got := SignedRedirectURL(tt.key, tt.destination, tt.urlToSign)
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("SignedRedirectURL() = diff %v", diff)
}
})
}
}
func Test_timestamp(t *testing.T) {
t.Parallel()
tests := []struct {
name string
dontWant int64
}{
{"if unset should never return", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testTimeNow = tt.dontWant
if got := timestamp(); got == tt.dontWant {
t.Errorf("timestamp() = %v, dontWant %v", got, tt.dontWant)
}
})
}
}
func parseURLHelper(s string) *url.URL {
u, _ := url.Parse(s)
return u
}
func TestGetAbsoluteURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
u *url.URL
want *url.URL
}{
{"add https", parseURLHelper("http://pomerium.io"), parseURLHelper("https://pomerium.io")},
{"missing scheme", parseURLHelper("https://pomerium.io"), parseURLHelper("https://pomerium.io")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := http.Request{URL: tt.u, Host: tt.u.Host}
got := GetAbsoluteURL(&r)
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("GetAbsoluteURL() = %v", diff)
}
})
}
}

View file

@ -1,15 +1,15 @@
package proxy // import "github.com/pomerium/pomerium/proxy"
import (
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/pomerium/csrf"
"github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
@ -18,34 +18,55 @@ import (
"github.com/pomerium/pomerium/internal/urlutil"
)
// StateParameter holds the redirect id along with the session id.
type StateParameter struct {
SessionID string `json:"session_id"`
RedirectURI string `json:"redirect_uri"`
}
// Handler returns the proxy service's ServeMux
func (p *Proxy) Handler() http.Handler {
// validation middleware chain
validate := middleware.NewChain()
validate = validate.Append(middleware.ValidateHost(func(host string) bool {
r := httputil.NewRouter().StrictSlash(true)
r.Use(middleware.ValidateHost(func(host string) bool {
_, ok := p.routeConfigs[host]
return ok
}))
mux := http.NewServeMux()
mux.HandleFunc("/robots.txt", p.RobotsTxt)
mux.HandleFunc("/.pomerium", p.UserDashboard)
mux.HandleFunc("/.pomerium/impersonate", p.Impersonate) // POST
mux.HandleFunc("/.pomerium/sign_out", p.SignOut)
// handlers with validation
mux.Handle("/.pomerium/callback", validate.ThenFunc(p.AuthenticateCallback))
mux.Handle("/.pomerium/refresh", validate.ThenFunc(p.ForceRefresh))
mux.Handle("/", validate.ThenFunc(p.Proxy))
return mux
r.Use(csrf.Protect(
p.cookieSecret,
csrf.Path("/"),
csrf.Domain(p.cookieDomain),
csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieName)),
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
))
r.HandleFunc("/robots.txt", p.RobotsTxt)
// requires authN not authZ
r.Use(sessions.RetrieveSession(p.sessionStore))
r.Use(p.VerifySession)
r.HandleFunc("/.pomerium/", p.UserDashboard).Methods(http.MethodGet)
r.HandleFunc("/.pomerium/impersonate", p.Impersonate).Methods(http.MethodPost)
r.HandleFunc("/.pomerium/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost)
r.HandleFunc("/.pomerium/refresh", p.ForceRefresh).Methods(http.MethodPost)
r.PathPrefix("/").HandlerFunc(p.Proxy)
return r
}
// VerifySession is the middleware used to enforce a valid authentication
// session state is attached to the users's request context.
func (p *Proxy) VerifySession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
state, err := sessions.FromContext(r.Context())
if err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to session state error")
p.authenticate(w, r)
return
}
if err := state.Valid(); err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to invalid session")
p.authenticate(w, r)
return
}
next.ServeHTTP(w, r)
})
}
// RobotsTxt sets the User-Agent header in the response to be "Disallow"
func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
}
@ -55,110 +76,18 @@ func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
// the local session state.
func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
switch r.Method {
case http.MethodPost:
if err := r.ParseForm(); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
uri, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri"))
if err == nil && uri.String() != "" {
redirectURL = uri
}
default:
uri, err := urlutil.ParseAndValidateURL(r.URL.Query().Get("redirect_uri"))
if err == nil && uri.String() != "" {
redirectURL = uri
}
if uri, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")); err == nil && uri.String() != "" {
redirectURL = uri
}
http.Redirect(w, r, p.GetSignOutURL(p.authenticateURL, redirectURL).String(), http.StatusFound)
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL)
http.Redirect(w, r, uri.String(), http.StatusFound)
}
// OAuthStart begins the authenticate flow, encrypting the redirect url
// Authenticate begins the authenticate flow, encrypting the redirect url
// in a request to the provider's sign in endpoint.
func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
state := &StateParameter{
SessionID: fmt.Sprintf("%x", cryptutil.GenerateKey()),
RedirectURI: r.URL.String(),
}
// Encrypt CSRF + redirect_uri and store in csrf session. Validated on callback.
csrfState, err := p.cipher.Marshal(state)
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
p.csrfStore.SetCSRF(w, r, csrfState)
paramState, err := p.cipher.Marshal(state)
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
// Sanity check. The encrypted payload of local and remote state should
// never match as each encryption round uses a cryptographic nonce.
// if paramState == csrfState {
// httputil.ErrorResponse(w, r, httputil.Error("encrypted state should not match", http.StatusBadRequest, nil))
// return
// }
signinURL := p.GetSignInURL(p.authenticateURL, p.GetRedirectURL(r.Host), paramState)
// Redirect the user to the authenticate service along with the encrypted
// state which contains a redirect uri back to the proxy and a nonce
http.Redirect(w, r, signinURL.String(), http.StatusFound)
}
// AuthenticateCallback checks the state parameter to make sure it matches the
// local csrf state then redirects the user back to the original intended route.
func (p *Proxy) AuthenticateCallback(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
// Encrypted CSRF passed from authenticate service
remoteStateEncrypted := r.Form.Get("state")
var remoteStatePlain StateParameter
if err := p.cipher.Unmarshal(remoteStateEncrypted, &remoteStatePlain); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
c, err := p.csrfStore.GetCSRF(r)
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
p.csrfStore.ClearCSRF(w, r)
localStateEncrypted := c.Value
var localStatePlain StateParameter
err = p.cipher.Unmarshal(localStateEncrypted, &localStatePlain)
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
// assert no nonce reuse
if remoteStateEncrypted == localStateEncrypted {
p.sessionStore.ClearSession(w, r)
httputil.ErrorResponse(w, r,
httputil.Error("local and remote state", http.StatusBadRequest,
fmt.Errorf("possible nonce-reuse / replay attack")))
return
}
// Decrypted remote and local state struct (inc. nonce) must match
if remoteStatePlain.SessionID != localStatePlain.SessionID {
p.sessionStore.ClearSession(w, r)
httputil.ErrorResponse(w, r, httputil.Error("CSRF mismatch", http.StatusBadRequest, nil))
return
}
// This is the redirect back to the original requested application
http.Redirect(w, r, remoteStatePlain.RedirectURI, http.StatusFound)
func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request) {
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
http.Redirect(w, r, uri.String(), http.StatusFound)
}
// shouldSkipAuthentication contains conditions for skipping authentication.
@ -189,17 +118,6 @@ func isCORSPreflight(r *http.Request) bool {
r.Header.Get("Origin") != ""
}
func (p *Proxy) loadExistingSession(r *http.Request) (*sessions.State, error) {
s, err := p.sessionStore.LoadSession(r)
if err != nil {
return nil, fmt.Errorf("proxy: invalid session: %w", err)
}
if err := s.Valid(); err != nil {
return nil, fmt.Errorf("proxy: invalid state: %w", err)
}
return s, nil
}
// Proxy authenticates a request, either proxying the request if it is authenticated,
// or starting the authenticate service for validation if not.
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
@ -214,11 +132,10 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
route.ServeHTTP(w, r)
return
}
s, err := p.loadExistingSession(r)
if err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
p.OAuthStart(w, r)
s, err := sessions.FromContext(r.Context())
if err != nil || s == nil {
log.Debug().Err(err).Msg("proxy: couldn't get session from context")
p.authenticate(w, r)
return
}
authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s)
@ -226,7 +143,7 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, err)
return
} else if !authorized {
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not authorized for this route", s.Email), http.StatusForbidden, nil))
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not authorized for this route", s.RequestEmail()), http.StatusForbidden, nil))
return
}
r.Header.Set(HeaderUserID, s.User)
@ -240,62 +157,41 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
// It also contains certain administrative actions like user impersonation.
// Nota bene: This endpoint does authentication, not authorization.
func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
session, err := p.loadExistingSession(r)
session, err := sessions.FromContext(r.Context())
if err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
p.OAuthStart(w, r)
httputil.ErrorResponse(w, r, err)
return
}
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/.pomerium/sign_out"}
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
// CSRF value used to mitigate replay attacks.
csrf := &StateParameter{SessionID: fmt.Sprintf("%x", cryptutil.GenerateKey())}
csrfCookie, err := p.cipher.Marshal(csrf)
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
p.csrfStore.SetCSRF(w, r, csrfCookie)
t := struct {
Email string
User string
Groups []string
RefreshDeadline string
SignoutURL string
IsAdmin bool
ImpersonateEmail string
ImpersonateGroup string
CSRF string
}{
Email: session.Email,
User: session.User,
Groups: session.Groups,
RefreshDeadline: time.Until(session.RefreshDeadline).Round(time.Second).String(),
SignoutURL: p.GetSignOutURL(p.authenticateURL, redirectURL).String(),
IsAdmin: isAdmin,
ImpersonateEmail: session.ImpersonateEmail,
ImpersonateGroup: strings.Join(session.ImpersonateGroups, ","),
CSRF: csrf.SessionID,
}
templates.New().ExecuteTemplate(w, "dashboard.html", t)
//todo(bdd): make sign out redirect a configuration option so that
// admins can set to whatever their corporate homepage is
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
signoutURL := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL)
templates.New().ExecuteTemplate(w, "dashboard.html", map[string]interface{}{
"Email": session.Email,
"User": session.User,
"Groups": session.Groups,
"RefreshDeadline": time.Until(session.RefreshDeadline).Round(time.Second).String(),
"SignoutURL": signoutURL.String(),
"IsAdmin": isAdmin,
"ImpersonateEmail": session.ImpersonateEmail,
"ImpersonateGroup": strings.Join(session.ImpersonateGroups, ","),
"csrfField": csrf.TemplateField(r),
})
}
// ForceRefresh redeems and extends an existing authenticated oidc session with
// the underlying identity provider. All session details including groups,
// timeouts, will be renewed.
func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) {
session, err := p.loadExistingSession(r)
session, err := sessions.FromContext(r.Context())
if err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
p.OAuthStart(w, r)
httputil.ErrorResponse(w, r, err)
return
}
iss, err := session.IssuedAt()
@ -324,49 +220,25 @@ func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) {
// to the user's current user sessions state if the user is currently an
// administrative user. Requests are redirected back to the user dashboard.
func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost {
if err := r.ParseForm(); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
session, err := p.loadExistingSession(r)
if err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
p.OAuthStart(w, r)
return
}
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
if err != nil || !isAdmin {
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not an administrator", session.Email), http.StatusForbidden, err))
return
}
// CSRF check -- did this request originate from our form?
c, err := p.csrfStore.GetCSRF(r)
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
p.csrfStore.ClearCSRF(w, r)
encryptedCSRF := c.Value
var decryptedCSRF StateParameter
if err = p.cipher.Unmarshal(encryptedCSRF, decryptedCSRF); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
if decryptedCSRF.SessionID != r.FormValue("csrf") {
httputil.ErrorResponse(w, r, httputil.Error("CSRF mismatch", http.StatusBadRequest, nil))
return
}
// OK to impersonation
session.ImpersonateEmail = r.FormValue("email")
session.ImpersonateGroups = strings.Split(r.FormValue("group"), ",")
if err := p.sessionStore.SaveSession(w, r, session); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
session, err := sessions.FromContext(r.Context())
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
if err != nil || !isAdmin {
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not an administrator", session.RequestEmail()), http.StatusForbidden, err))
return
}
// OK to impersonation
session.ImpersonateEmail = r.FormValue("email")
session.ImpersonateGroups = strings.Split(r.FormValue("group"), ",")
if err := p.sessionStore.SaveSession(w, r, session); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
http.Redirect(w, r, "/.pomerium", http.StatusFound)
}
@ -391,48 +263,3 @@ func (p *Proxy) policy(r *http.Request) (*config.Policy, bool) {
}
return nil, false
}
// GetRedirectURL returns the redirect url for a single reverse proxy host. HTTPS is set explicitly.
func (p *Proxy) GetRedirectURL(host string) *url.URL {
u := p.redirectURL
u.Scheme = "https"
u.Host = host
return u
}
// signRedirectURL takes a redirect url string and timestamp and returns the base64
// encoded HMAC result.
func (p *Proxy) signRedirectURL(rawRedirect string, timestamp time.Time) string {
data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix()))
h := cryptutil.Hash(p.SharedKey, data)
return base64.URLEncoding.EncodeToString(h)
}
// GetSignInURL with typical oauth parameters
func (p *Proxy) GetSignInURL(authenticateURL, redirectURL *url.URL, state string) *url.URL {
a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_in"})
now := time.Now()
rawRedirect := redirectURL.String()
params, _ := url.ParseQuery(a.RawQuery) // handled by ServeMux
params.Set("redirect_uri", rawRedirect)
params.Set("shared_secret", p.SharedKey)
params.Set("response_type", "code")
params.Add("state", state)
params.Set("ts", fmt.Sprint(now.Unix()))
params.Set("sig", p.signRedirectURL(rawRedirect, now))
a.RawQuery = params.Encode()
return a
}
// GetSignOutURL creates and returns the sign out URL, given a redirectURL
func (p *Proxy) GetSignOutURL(authenticateURL, redirectURL *url.URL) *url.URL {
a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_out"})
now := time.Now()
rawRedirect := redirectURL.String()
params, _ := url.ParseQuery(a.RawQuery) // handled by ServeMux
params.Add("redirect_uri", rawRedirect)
params.Set("ts", fmt.Sprint(now.Unix()))
params.Set("sig", p.signRedirectURL(rawRedirect, now))
a.RawQuery = params.Encode()
return a
}

View file

@ -7,45 +7,17 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"time"
"github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/proxy/clients"
)
type mockCipher struct{}
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Marshal(s interface{}) (string, error) {
if s == "error" {
return "", errors.New("error")
}
return "ok", nil
}
func (a mockCipher) Unmarshal(s string, i interface{}) error {
if s == "unmarshal error" || s == "error" {
return errors.New("error")
}
return nil
}
func TestProxy_RobotsTxt(t *testing.T) {
proxy := Proxy{}
req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil)
@ -60,94 +32,6 @@ func TestProxy_RobotsTxt(t *testing.T) {
}
}
func TestProxy_GetRedirectURL(t *testing.T) {
tests := []struct {
name string
host string
want *url.URL
}{
{"google", "google.com", &url.URL{Scheme: "https", Host: "google.com", Path: "/.pomerium/callback"}},
{"pomerium", "pomerium.io", &url.URL{Scheme: "https", Host: "pomerium.io", Path: "/.pomerium/callback"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Proxy{redirectURL: &url.URL{Path: "/.pomerium/callback"}}
if got := p.GetRedirectURL(tt.host); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Proxy.GetRedirectURL() = %v, want %v", got, tt.want)
}
})
}
}
func TestProxy_signRedirectURL(t *testing.T) {
tests := []struct {
name string
rawRedirect string
timestamp time.Time
want string
}{
{"pomerium", "https://pomerium.io/.pomerium/callback", fixedDate, "wq3rAjRGN96RXS8TAzH-uxQTD0XgY_8ZYEKMiOLD5P4="},
{"google", "https://google.com/.pomerium/callback", fixedDate, "7EYHZObq167CuyuPm5CqOtkU4zg5dFeUCs7W7QOrgNQ="},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Proxy{}
if got := p.signRedirectURL(tt.rawRedirect, tt.timestamp); got != tt.want {
t.Errorf("Proxy.signRedirectURL() = %v, want %v", got, tt.want)
}
})
}
}
func TestProxy_GetSignOutURL(t *testing.T) {
tests := []struct {
name string
authenticate string
redirect string
wantPrefix string
}{
{"good", "https://auth.corp.pomerium.io", "https://hello.corp.pomerium.io", "https://auth.corp.pomerium.io/sign_out?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
authenticateURL, _ := url.Parse(tt.authenticate)
redirectURL, _ := url.Parse(tt.redirect)
p := &Proxy{}
// signature is ignored as it is tested above. Avoids testing time.Now
if got := p.GetSignOutURL(authenticateURL, redirectURL); !strings.HasPrefix(got.String(), tt.wantPrefix) {
t.Errorf("Proxy.GetSignOutURL() = %v, wantPrefix %v", got.String(), tt.wantPrefix)
}
})
}
}
func TestProxy_GetSignInURL(t *testing.T) {
tests := []struct {
name string
authenticate string
redirect string
state string
wantPrefix string
}{
{"good", "https://auth.corp.pomerium.io", "https://hello.corp.pomerium.io", "example_state", "https://auth.corp.pomerium.io/sign_in?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io&response_type=code&shared_secret=shared-secret"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Proxy{SharedKey: "shared-secret"}
authenticateURL, _ := url.Parse(tt.authenticate)
redirectURL, _ := url.Parse(tt.redirect)
if got := p.GetSignInURL(authenticateURL, redirectURL, tt.state); !strings.HasPrefix(got.String(), tt.wantPrefix) {
t.Errorf("Proxy.GetSignOutURL() = %v, wantPrefix %v", got.String(), tt.wantPrefix)
}
})
}
}
func TestProxy_Signout(t *testing.T) {
opts := testOptions(t)
err := ValidateOptions(opts)
@ -171,7 +55,7 @@ func TestProxy_Signout(t *testing.T) {
}
}
func TestProxy_OAuthStart(t *testing.T) {
func TestProxy_authenticate(t *testing.T) {
proxy, err := New(testOptions(t))
if err != nil {
t.Fatal(err)
@ -179,18 +63,19 @@ func TestProxy_OAuthStart(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/oauth-start", nil)
rr := httptest.NewRecorder()
proxy.OAuthStart(rr, req)
proxy.authenticate(rr, req)
// expect oauth redirect
if status := rr.Code; status != http.StatusFound {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
}
// expected url
expected := `<a href="https://authenticate.example/sign_in`
expected := `<a href="https://authenticate.example/.pomerium/sign_in`
body := rr.Body.String()
if !strings.HasPrefix(body, expected) {
t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
}
}
func TestProxy_Handler(t *testing.T) {
proxy, err := New(testOptions(t))
if err != nil {
@ -226,7 +111,6 @@ func TestProxy_router(t *testing.T) {
{"good corp", "https://corp.example.com", policies, nil, true},
{"good with slash", "https://corp.example.com/", policies, nil, true},
{"good with path", "https://corp.example.com/123", policies, nil, true},
// {"multiple", "https://corp.example.com/", map[string]string{"corp.unrelated.com": "unrelated.com", "corp.example.com": "example.com"}, nil, true},
{"no policies", "https://notcorp.example.com/123", []config.Policy{}, nil, false},
{"bad corp", "https://notcorp.example.com/123", policies, nil, false},
{"bad sub-sub", "https://notcorp.corp.example.com/123", policies, nil, false},
@ -254,11 +138,15 @@ func TestProxy_Proxy(t *testing.T) {
goodSession := &sessions.State{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(20 * time.Second),
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprintln(w, "RVSI FILIVS CAISAR")
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
@ -285,25 +173,24 @@ func TestProxy_Proxy(t *testing.T) {
authorizer clients.Authorizer
wantStatus int
}{
{"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
{"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(20 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
{"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
// same request as above, but with cors_allow_preflight=false in the policy
{"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
{"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
// cors allowed, but the request is missing proper headers
{"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
{"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
// redirect to start auth process
{"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
{"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
{"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError},
{"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
{"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError},
// authenticate errors
{"session error, redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: errors.New("weird"), Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
{"session expired,redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: sessions.ErrExpired}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
{"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
{"session expired,redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: sessions.ErrExpired}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
{"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
{"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound},
{"no http found (no session),redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
{"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
{"no http found (no session),redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
{"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
}
for _, tt := range tests {
@ -319,18 +206,17 @@ func TestProxy_Proxy(t *testing.T) {
p.cipher = &cryptutil.MockCipher{MarshalResponse: "foo"}
p.sessionStore = tt.session
p.AuthorizeClient = tt.authorizer
r := httptest.NewRequest(tt.method, tt.host, nil)
r.Header = tt.header
r.Header.Set("Accept", "application/json")
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, nil)
r = r.WithContext(ctx)
w := httptest.NewRecorder()
p.Proxy(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", w.Body.String())
t.Errorf("\n%+v", opts)
t.Errorf("\n%+v", ts.URL)
t.Errorf("handler returned wrong status code: got %v want %v \n body %s", status, tt.wantStatus, w.Body.String())
}
@ -342,6 +228,7 @@ func TestProxy_UserDashboard(t *testing.T) {
opts := testOptions(t)
tests := []struct {
name string
ctxError error
options config.Options
method string
cipher cryptutil.Cipher
@ -351,11 +238,10 @@ func TestProxy_UserDashboard(t *testing.T) {
wantAdminForm bool
wantStatus int
}{
{"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusOK},
{"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthorize{}, false, http.StatusFound},
{"can't save csrf", opts, http.MethodGet, &cryptutil.MockCipher{MarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
{"want admin form good admin authorization", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
{"is admin but authorization fails", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
{"good", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusOK},
{"session context error", errors.New("error"), opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
{"want admin form good admin authorization", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
{"is admin but authorization fails", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
}
for _, tt := range tests {
@ -369,6 +255,11 @@ func TestProxy_UserDashboard(t *testing.T) {
p.AuthorizeClient = tt.authorizer
r := httptest.NewRequest(tt.method, "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
@ -395,6 +286,7 @@ func TestProxy_ForceRefresh(t *testing.T) {
tests := []struct {
name string
ctxError error
options config.Options
method string
cipher cryptutil.Cipher
@ -402,12 +294,12 @@ func TestProxy_ForceRefresh(t *testing.T) {
authorizer clients.Authorizer
wantStatus int
}{
{"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
{"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthorize{}, http.StatusFound},
{"bad id token", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
{"issue date too soon", timeSinceError, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusBadRequest},
{"refresh failure", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
{"can't save refreshed session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
{"good", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
{"cannot load session", errors.New("error"), opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
{"bad id token", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
{"issue date too soon", nil, timeSinceError, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusBadRequest},
{"refresh failure", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
{"can't save refreshed session", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -420,6 +312,12 @@ func TestProxy_ForceRefresh(t *testing.T) {
p.AuthorizeClient = tt.authorizer
r := httptest.NewRequest(tt.method, "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
p.ForceRefresh(w, r)
if status := w.Code; status != tt.wantStatus {
@ -431,32 +329,29 @@ func TestProxy_ForceRefresh(t *testing.T) {
}
func TestProxy_Impersonate(t *testing.T) {
t.Parallel()
opts := testOptions(t)
tests := []struct {
name string
malformed bool
options config.Options
ctxError error
method string
email string
groups string
csrf string
cipher cryptutil.Cipher
sessionStore sessions.SessionStore
csrfStore sessions.CSRFStore
authorizer clients.Authorizer
wantStatus int
}{
{"good", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"session load error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
// {"non admin users rejected", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
{"non admin users rejected on error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
{"csrf from store retrieve failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}, GetError: errors.New("err")}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"can't decrypt csrf value", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{UnmarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"decrypted csrf mismatch", false, opts, http.MethodPost, "user@blah.com", "", "CSRF!", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusBadRequest},
{"save session failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"malformed", true, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"groups", false, opts, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
{"save session failure", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -466,17 +361,18 @@ func TestProxy_Impersonate(t *testing.T) {
}
p.cipher = tt.cipher
p.sessionStore = tt.sessionStore
p.csrfStore = tt.csrfStore
p.AuthorizeClient = tt.authorizer
postForm := url.Values{}
postForm.Add("email", tt.email)
postForm.Add("group", tt.groups)
postForm.Set("csrf", tt.csrf)
uri := &url.URL{Path: "/"}
if tt.malformed {
uri.RawQuery = "email=%zzzzz"
}
r := httptest.NewRequest(tt.method, uri.String(), bytes.NewBufferString(postForm.Encode()))
state, _ := tt.sessionStore.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
w := httptest.NewRecorder()
@ -489,50 +385,8 @@ func TestProxy_Impersonate(t *testing.T) {
}
}
func TestProxy_OAuthCallback(t *testing.T) {
tests := []struct {
name string
csrf sessions.MockCSRFStore
session sessions.MockSessionStore
params map[string]string
wantCode int
}{
{"good", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusFound},
{"state err", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "error"}, http.StatusInternalServerError},
{"csrf err", sessions.MockCSRFStore{GetError: errors.New("error")}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
{"unmarshal err", sessions.MockCSRFStore{Cookie: &http.Cookie{Name: "something_csrf", Value: "unmarshal error"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
{"malformed", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proxy, err := New(testOptions(t))
if err != nil {
t.Fatal(err)
}
proxy.sessionStore = &tt.session
proxy.csrfStore = tt.csrf
proxy.cipher = mockCipher{}
// proxy.Csrf
req := httptest.NewRequest(http.MethodPost, "/.pomerium/callback", nil)
q := req.URL.Query()
for k, v := range tt.params {
q.Add(k, v)
}
req.URL.RawQuery = q.Encode()
if tt.name == "malformed" {
req.URL.RawQuery = "email=%zzzzz"
}
w := httptest.NewRecorder()
proxy.AuthenticateCallback(w, req)
if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
}
})
}
}
func TestProxy_SignOut(t *testing.T) {
t.Parallel()
tests := []struct {
name string
verb string
@ -542,7 +396,6 @@ func TestProxy_SignOut(t *testing.T) {
{"good post", http.MethodPost, "https://test.example", http.StatusFound},
{"good get", http.MethodGet, "https://test.example", http.StatusFound},
{"good empty default", http.MethodGet, "", http.StatusFound},
{"malformed", http.MethodPost, "", http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -554,9 +407,6 @@ func TestProxy_SignOut(t *testing.T) {
postForm := url.Values{}
postForm.Add("redirect_uri", tt.redirectURL)
uri := &url.URL{Path: "/"}
if tt.name == "malformed" {
uri.RawQuery = "redirect_uri=%zzzzz"
}
query, _ := url.ParseQuery(uri.RawQuery)
if tt.verb == http.MethodGet {
@ -576,3 +426,56 @@ func TestProxy_SignOut(t *testing.T) {
})
}
}
func uriParseHelper(s string) *url.URL {
uri, _ := url.Parse(s)
return uri
}
func TestProxy_VerifySession(t *testing.T) {
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprintln(w, "RVSI FILIVS CAISAR")
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
session sessions.SessionStore
ctxError error
provider identity.Authenticator
wantStatus int
}{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := Proxy{
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
sessionStore: tt.session,
}
r := httptest.NewRequest("GET", "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
got := a.VerifySession(fn)
got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
})
}
}

View file

@ -2,6 +2,7 @@ package proxy // import "github.com/pomerium/pomerium/proxy"
import (
"crypto/tls"
"encoding/base64"
"fmt"
"html/template"
stdlog "log"
@ -12,11 +13,11 @@ import (
"github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/cryptutil"
pom_httputil "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/templates"
"github.com/pomerium/pomerium/internal/tripper"
"github.com/pomerium/pomerium/internal/urlutil"
@ -32,6 +33,11 @@ const (
HeaderEmail = "x-pomerium-authenticated-user-email"
// HeaderGroups is the header key containing the user's groups.
HeaderGroups = "x-pomerium-authenticated-user-groups"
// signinURL is the path to authenticate's sign in endpoint
signinURL = "/.pomerium/sign_in"
// signoutURL is the path to authenticate's sign out endpoint
signoutURL = "/.pomerium/sign_out"
)
// ValidateOptions checks that proper configuration settings are set to create
@ -40,23 +46,21 @@ func ValidateOptions(o config.Options) error {
if _, err := cryptutil.NewCipherFromBase64(o.SharedKey); err != nil {
return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %v", err)
}
if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil {
return fmt.Errorf("proxy: invalid 'COOKIE_SECRET': %v", err)
}
if o.AuthenticateURL == nil {
return fmt.Errorf("proxy: missing 'AUTHENTICATE_SERVICE_URL'")
}
if _, err := urlutil.ParseAndValidateURL(o.AuthenticateURL.String()); err != nil {
if err := urlutil.ValidateURL(o.AuthenticateURL); err != nil {
return fmt.Errorf("proxy: invalid 'AUTHENTICATE_SERVICE_URL': %v", err)
}
if o.AuthorizeURL == nil {
return fmt.Errorf("proxy: missing 'AUTHORIZE_SERVICE_URL'")
}
if _, err := urlutil.ParseAndValidateURL(o.AuthorizeURL.String()); err != nil {
if err := urlutil.ValidateURL(o.AuthorizeURL); err != nil {
return fmt.Errorf("proxy: invalid 'AUTHORIZE_SERVICE_URL': %v", err)
}
if len(o.SigningKey) != 0 {
if _, err := cryptutil.NewES256Signer(o.SigningKey, "localhost"); err != nil {
if _, err := cryptutil.NewES256Signer(o.SigningKey, ""); err != nil {
return fmt.Errorf("proxy: invalid 'SIGNING_KEY': %v", err)
}
}
@ -66,17 +70,19 @@ func ValidateOptions(o config.Options) error {
// Proxy stores all the information associated with proxying a request.
type Proxy struct {
// SharedKey used to mutually authenticate service communication
SharedKey string
authenticateURL *url.URL
authorizeURL *url.URL
SharedKey string
authenticateURL *url.URL
authenticateSigninURL *url.URL
authenticateSignoutURL *url.URL
authorizeURL *url.URL
AuthorizeClient clients.Authorizer
cipher cryptutil.Cipher
cookieName string
csrfStore sessions.CSRFStore
cookieDomain string
cookieSecret []byte
defaultUpstreamTimeout time.Duration
redirectURL *url.URL
refreshCooldown time.Duration
routeConfigs map[string]*routeConfig
sessionStore sessions.SessionStore
@ -95,10 +101,17 @@ func New(opts config.Options) (*Proxy, error) {
if err := ValidateOptions(opts); err != nil {
return nil, err
}
decodedCookieSecret, err := base64.StdEncoding.DecodeString(opts.CookieSecret)
if err != nil {
return nil, err
}
cipher, err := cryptutil.NewCipherFromBase64(opts.CookieSecret)
if err != nil {
return nil, err
}
if opts.CookieDomain == "" {
opts.CookieDomain = sessions.ParentSubdomain(opts.AuthenticateURL.String())
}
cookieStore, err := sessions.NewCookieStore(
&sessions.CookieStoreOptions{
@ -116,13 +129,12 @@ func New(opts config.Options) (*Proxy, error) {
p := &Proxy{
SharedKey: opts.SharedKey,
routeConfigs: make(map[string]*routeConfig),
routeConfigs: make(map[string]*routeConfig),
cipher: cipher,
cookieSecret: decodedCookieSecret,
cookieDomain: opts.CookieDomain,
cookieName: opts.CookieName,
csrfStore: cookieStore,
defaultUpstreamTimeout: opts.DefaultUpstreamTimeout,
redirectURL: &url.URL{Path: "/.pomerium/callback"},
refreshCooldown: opts.RefreshCooldown,
sessionStore: cookieStore,
signingKey: opts.SigningKey,
@ -130,6 +142,9 @@ func New(opts config.Options) (*Proxy, error) {
}
// DeepCopy urls to avoid accidental mutation, err checked in validate func
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL})
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
if err := p.UpdatePolicies(&opts); err != nil {
@ -172,24 +187,24 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error {
if policy.TLSSkipVerify {
tlsClientConfig.InsecureSkipVerify = true
isCustomClientConfig = true
log.Warn().Str("to", policy.Source.String()).Msg("proxy: tls skip verify")
log.Warn().Str("policy", policy.String()).Msg("proxy: tls skip verify")
}
if policy.RootCAs != nil {
tlsClientConfig.RootCAs = policy.RootCAs
isCustomClientConfig = true
log.Debug().Str("to", policy.Source.String()).Msg("proxy: custom root ca")
log.Debug().Str("policy", policy.String()).Msg("proxy: custom root ca")
}
if policy.ClientCertificate != nil {
tlsClientConfig.Certificates = []tls.Certificate{*policy.ClientCertificate}
isCustomClientConfig = true
log.Debug().Str("to", policy.Source.String()).Msg("proxy: client certs enabled")
log.Debug().Str("policy", policy.String()).Msg("proxy: client certs enabled")
}
if policy.TLSServerName != "" {
tlsClientConfig.ServerName = policy.TLSServerName
isCustomClientConfig = true
log.Debug().Str("to", policy.Source.String()).Msgf("proxy: tls hostname override to: %s", policy.TLSServerName)
log.Debug().Str("policy", policy.String()).Msgf("proxy: tls hostname override to: %s", policy.TLSServerName)
}
// We avoid setting a custom client config unless we have to as
@ -212,19 +227,6 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error {
return nil
}
// UpstreamProxy stores information for proxying the request to the upstream.
type UpstreamProxy struct {
name string
handler http.Handler
}
// ServeHTTP handles the second (reverse-proxying) leg of pomerium's request flow
func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), fmt.Sprintf("%s%s", r.Host, r.URL.Path))
defer span.End()
u.handler.ServeHTTP(w, r.WithContext(ctx))
}
// NewReverseProxy returns a new ReverseProxy that routes URLs to the scheme, host, and
// base path provided in target. NewReverseProxy rewrites the Host header.
func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
@ -242,22 +244,17 @@ func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
return proxy
}
// newReverseProxyHandler applies handler specific options to a given route.
func (p *Proxy) newReverseProxyHandler(rp *httputil.ReverseProxy, route *config.Policy) (handler http.Handler, err error) {
handler = &UpstreamProxy{
name: route.Destination.Host,
handler: rp,
}
c := middleware.NewChain()
c = c.Append(middleware.StripPomeriumCookie(p.cookieName))
// each route has a custom set of middleware applied to the reverse proxy
func (p *Proxy) newReverseProxyHandler(rp http.Handler, route *config.Policy) (http.Handler, error) {
r := pom_httputil.NewRouter()
r.Use(middleware.StripPomeriumCookie(p.cookieName))
// if signing key is set, add signer to middleware
if len(p.signingKey) != 0 {
signer, err := cryptutil.NewES256Signer(p.signingKey, route.Source.Host)
if err != nil {
return nil, err
}
c = c.Append(middleware.SignRequest(signer, HeaderUserID, HeaderEmail, HeaderGroups, HeaderJWT))
r.Use(middleware.SignRequest(signer, HeaderUserID, HeaderEmail, HeaderGroups, HeaderJWT))
}
// websockets cannot use the non-hijackable timeout-handler
if !route.AllowWebsockets {
@ -265,11 +262,16 @@ func (p *Proxy) newReverseProxyHandler(rp *httputil.ReverseProxy, route *config.
if route.UpstreamTimeout != 0 {
timeout = route.UpstreamTimeout
}
timeoutMsg := fmt.Sprintf("%s failed to respond within the %s timeout period", route.Destination.Host, timeout)
handler = http.TimeoutHandler(handler, timeout, timeoutMsg)
timeoutMsg := fmt.Sprintf("%s timed out in %s", route.Destination.Host, timeout)
rp = http.TimeoutHandler(rp, timeout, timeoutMsg)
}
return c.Then(handler), nil
// todo(bdd) : fix cors
// if route.CORSAllowPreflight {
// r.Use(cors.Default().Handler)
// }
r.Host(route.Destination.Host)
r.PathPrefix("/").Handler(rp)
return r, nil
}
// UpdateOptions updates internal structures based on config.Options

View file

@ -12,8 +12,6 @@ import (
"github.com/pomerium/pomerium/internal/config"
)
var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
func newTestOptions(t *testing.T) *config.Options {
opts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
if err != nil {
@ -285,10 +283,11 @@ func Test_UpdateOptions(t *testing.T) {
goodClientCertPolicies := testOptions(t)
goodClientCertPolicies.Policies = []config.Policy{
{To: "http://foo.example", From: "http://bar.example",
TLSClientKey: "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=", TLSClientCert: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=",
},
TLSClientKey: "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=", TLSClientCert: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="},
}
goodClientCertPolicies.Validate()
customServerName := testOptions(t)
customServerName.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", TLSServerName: "test"}}
tests := []struct {
name string
originalOptions config.Options
@ -301,14 +300,13 @@ func Test_UpdateOptions(t *testing.T) {
{"good no change", good, good, "", "https://corp.example.example", false, true},
{"changed", good, newPolicies, "", "https://bar.example", false, true},
{"changed and missing", good, newPolicies, "", "https://corp.example.example", false, false},
// todo(bdd): not sure what intent of this test is?
{"bad signing key", good, newPolicies, "^bad base 64", "https://corp.example.example", true, false},
{"bad change bad policy url", good, badNewPolicy, "", "https://bar.example", true, false},
// todo: stand up a test server using self signed certificates
{"disable tls verification", good, disableTLSPolicies, "", "https://bar.example", false, true},
{"custom root ca", good, customCAPolicies, "", "https://bar.example", false, true},
{"bad custom root ca base64", good, badCustomCAPolicies, "", "https://bar.example", true, false},
{"good client certs", good, goodClientCertPolicies, "", "https://bar.example", false, true},
{"custom server name", customServerName, customServerName, "", "https://bar.example", false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {