diff --git a/.gitignore b/.gitignore index 5215b7287..3bb0c717c 100644 --- a/.gitignore +++ b/.gitignore @@ -76,4 +76,5 @@ yarn.lock node_modules i18n/* docs/.vuepress/dist/ -.firebase/ \ No newline at end of file +.firebase/ +.changes.md \ No newline at end of file diff --git a/3RD-PARTY b/3RD-PARTY index a3a7e0289..4be7e7180 100644 --- a/3RD-PARTY +++ b/3RD-PARTY @@ -87,31 +87,6 @@ https://github.com/bitly/oauth2_proxy/blob/master/LICENSE OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -alice -SPDX-License-Identifier: MIT -https://github.com/justinas/alice/blob/master/LICENSE - - The MIT License (MIT) - - Copyright (c) 2014 Justinas Stankevicius - - Permission is hereby granted, free of charge, to any person obtaining a copy of - this software and associated documentation files (the "Software"), to deal in - the Software without restriction, including without limitation the rights to - use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of - the Software, and to permit persons to whom the Software is furnished to do so, - subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS - FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER - IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN - CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - goji SPDX-License-Identifier: MIT https://github.com/zenazn/goji/blob/master/LICENSE diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 0971b0b38..88df09c2f 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -15,6 +15,8 @@ import ( "github.com/pomerium/pomerium/internal/urlutil" ) +const callbackPath = "/oauth2/callback" + // ValidateOptions checks that configuration are complete and valid. // Returns on first error found. func ValidateOptions(o config.Options) error { @@ -24,11 +26,8 @@ func ValidateOptions(o config.Options) error { if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil { return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %v", err) } - if o.AuthenticateURL == nil { - return errors.New("authenticate: 'AUTHENTICATE_SERVICE_URL' is required") - } - if _, err := urlutil.ParseAndValidateURL(o.AuthenticateURL.String()); err != nil { - return fmt.Errorf("authenticate: couldn't parse 'AUTHENTICATE_SERVICE_URL': %v", err) + if err := urlutil.ValidateURL(o.AuthenticateURL); err != nil { + return fmt.Errorf("authenticate: invalid 'AUTHENTICATE_SERVICE_URL': %v", err) } if o.ClientID == "" { return errors.New("authenticate: 'IDP_CLIENT_ID' is required") @@ -44,8 +43,10 @@ type Authenticate struct { SharedKey string RedirectURL *url.URL + cookieName string + cookieDomain string + cookieSecret []byte templates *template.Template - csrfStore sessions.CSRFStore sessionStore sessions.SessionStore cipher cryptutil.Cipher provider identity.Authenticator @@ -61,6 +62,9 @@ func New(opts config.Options) (*Authenticate, error) { if err != nil { return nil, err } + if opts.CookieDomain == "" { + opts.CookieDomain = sessions.ParentSubdomain(opts.AuthenticateURL.String()) + } cookieStore, err := sessions.NewCookieStore( &sessions.CookieStoreOptions{ Name: opts.CookieName, @@ -74,7 +78,7 @@ func New(opts config.Options) (*Authenticate, error) { return nil, err } redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL) - redirectURL.Path = "/oauth2/callback" + redirectURL.Path = callbackPath provider, err := identity.New( opts.Provider, &identity.Provider{ @@ -94,9 +98,11 @@ func New(opts config.Options) (*Authenticate, error) { SharedKey: opts.SharedKey, RedirectURL: redirectURL, templates: templates.New(), - csrfStore: cookieStore, sessionStore: cookieStore, cipher: cipher, provider: provider, + cookieSecret: decodedCookieSecret, + cookieName: opts.CookieName, + cookieDomain: opts.CookieDomain, }, nil } diff --git a/authenticate/handlers.go b/authenticate/handlers.go index f5a9b94e0..8e8e1a039 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -10,7 +10,8 @@ import ( "strings" "time" - "github.com/pomerium/pomerium/internal/cryptutil" + "github.com/pomerium/csrf" + "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/middleware" @@ -31,24 +32,68 @@ var CSPHeaders = map[string]string{ // Handler returns the authenticate service's HTTP multiplexer, and routes. func (a *Authenticate) Handler() http.Handler { - // validation middleware chain - c := middleware.NewChain() - c = c.Append(middleware.SetHeaders(CSPHeaders)) - mux := http.NewServeMux() - mux.Handle("/robots.txt", c.ThenFunc(a.RobotsTxt)) + r := httputil.NewRouter() + r.Use(middleware.SetHeaders(CSPHeaders)) + r.Use(csrf.Protect( + a.cookieSecret, + csrf.Path("/"), + csrf.Domain(a.cookieDomain), + csrf.UnsafePaths([]string{"/oauth2/callback"}), // enforce CSRF on "safe" handler + csrf.FormValueName("state"), // rfc6749 section-10.12 + csrf.CookieName(fmt.Sprintf("%s_csrf", a.cookieName)), + csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)), + )) + + r.HandleFunc("/robots.txt", a.RobotsTxt).Methods(http.MethodGet) // Identity Provider (IdP) endpoints - mux.Handle("/oauth2", c.ThenFunc(a.OAuthStart)) - mux.Handle("/oauth2/callback", c.ThenFunc(a.OAuthCallback)) + r.HandleFunc("/oauth2/callback", a.OAuthCallback).Methods(http.MethodGet) + r.HandleFunc("/api/v1/token", a.ExchangeToken) + // Proxy service endpoints - validationMiddlewares := c.Append( - middleware.ValidateSignature(a.SharedKey), - middleware.ValidateRedirectURI(a.RedirectURL), - ) - mux.Handle("/sign_in", validationMiddlewares.ThenFunc(a.SignIn)) - mux.Handle("/sign_out", validationMiddlewares.ThenFunc(a.SignOut)) // POST - // Direct user access endpoints - mux.Handle("/api/v1/token", c.ThenFunc(a.ExchangeToken)) - return mux + v := r.PathPrefix("/.pomerium").Subrouter() + v.Use(middleware.ValidateSignature(a.SharedKey)) + v.Use(middleware.ValidateRedirectURI(a.RedirectURL)) + v.Use(sessions.RetrieveSession(a.sessionStore)) + v.Use(a.VerifySession) + + v.HandleFunc("/sign_in", a.SignIn) + v.HandleFunc("/sign_out", a.SignOut) + return r +} + +// VerifySession is the middleware used to enforce a valid authentication +// session state is attached to the users's request context. +func (a *Authenticate) VerifySession(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + state, err := sessions.FromContext(r.Context()) + if errors.Is(err, sessions.ErrExpired) { + if err := a.refresh(w, r, state); err != nil { + log.FromRequest(r).Debug().Str("cause", err.Error()).Msg("authenticate: couldn't refresh session") + a.sessionStore.ClearSession(w, r) + a.redirectToIdentityProvider(w, r) + return + } + + } else if err != nil { + log.FromRequest(r).Err(err).Msg("authenticate: unexpected session state") + a.sessionStore.ClearSession(w, r) + a.redirectToIdentityProvider(w, r) + return + } + next.ServeHTTP(w, r) + }) +} + +func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) error { + newSession, err := a.provider.Refresh(r.Context(), s) + if err != nil { + return fmt.Errorf("authenticate: refresh failed: %w", err) + } + if err := a.sessionStore.SaveSession(w, r, newSession); err != nil { + return fmt.Errorf("authenticate: refresh save failed: %w", err) + } + return nil + } // RobotsTxt handles the /robots.txt route. @@ -59,87 +104,22 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "User-agent: *\nDisallow: /") } -func (a *Authenticate) loadExisting(w http.ResponseWriter, r *http.Request) (*sessions.State, error) { - session, err := a.sessionStore.LoadSession(r) - if err != nil { - return nil, err - } - err = session.Valid() - if err == nil { - return session, nil - } else if !errors.Is(err, sessions.ErrExpired) { - return nil, fmt.Errorf("authenticate: non-refreshable error: %w", err) - } else { - return a.refresh(w, r, session) - } -} - -func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (*sessions.State, error) { - newSession, err := a.provider.Refresh(r.Context(), s) - if err != nil { - return nil, fmt.Errorf("authenticate: refresh failed: %w", err) - } - if err := a.sessionStore.SaveSession(w, r, newSession); err != nil { - return nil, fmt.Errorf("authenticate: refresh save failed: %w", err) - } - return newSession, nil - -} - // SignIn handles to authenticating a user. func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) { - session, err := a.loadExisting(w, r) - if err != nil { - log.FromRequest(r).Debug().Err(err).Msg("authenticate: need new session") - a.sessionStore.ClearSession(w, r) - a.OAuthStart(w, r) - return - } - if err := r.ParseForm(); err != nil { - httputil.ErrorResponse(w, r, err) - return - } - state := r.Form.Get("state") - if state == "" { - httputil.ErrorResponse(w, r, httputil.Error("sign in state empty", http.StatusBadRequest, nil)) - return - } - - redirectURL, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri")) + redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")) if err != nil { httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err)) return } - // encrypt session state as json blob - encrypted, err := sessions.MarshalSession(session, a.cipher) - if err != nil { - httputil.ErrorResponse(w, r, httputil.Error("couldn't marshal session", http.StatusInternalServerError, err)) - return - } - http.Redirect(w, r, getAuthCodeRedirectURL(redirectURL, state, encrypted), http.StatusFound) -} - -func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string { - // ParseQuery err handled by go's mux stack - params, _ := url.ParseQuery(redirectURL.RawQuery) - params.Set("code", authCode) - params.Set("state", state) - redirectURL.RawQuery = params.Encode() - return redirectURL.String() + http.Redirect(w, r, redirectURL.String(), http.StatusFound) } // SignOut signs the user out and attempts to revoke the user's identity session // Handles both GET and POST. func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { - if err := r.ParseForm(); err != nil { - httputil.ErrorResponse(w, r, err) - return - } - redirectURI := r.Form.Get("redirect_uri") - session, err := a.sessionStore.LoadSession(r) + session, err := sessions.FromContext(r.Context()) if err != nil { - log.Error().Err(err).Msg("authenticate: no session to signout, redirect and clear") - http.Redirect(w, r, redirectURI, http.StatusFound) + httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err)) return } a.sessionStore.ClearSession(w, r) @@ -148,46 +128,30 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { httputil.ErrorResponse(w, r, httputil.Error("could not revoke user session", http.StatusBadRequest, err)) return } - http.Redirect(w, r, redirectURI, http.StatusFound) + redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")) + if err != nil { + httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err)) + return + } + http.Redirect(w, r, redirectURL.String(), http.StatusFound) } -// OAuthStart starts the authenticate process by redirecting to the identity provider. +// redirectToIdentityProvider starts the authenticate process by redirecting the +// user to their respective identity provider. +// // https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest // https://tools.ietf.org/html/rfc6749#section-4.2.1 -func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) { - authRedirectURL := a.RedirectURL.ResolveReference(r.URL) - - // Nonce is the opaque, cryptographically binding value used to maintain - // state between the request and the callback. - // OIDC : 3.1.2.1. Authentication Request - nonce := fmt.Sprintf("%x", cryptutil.GenerateKey()) - a.csrfStore.SetCSRF(w, r, nonce) - // Redirection URI to which the response will be sent. This URI MUST exactly - // match one of the Redirection URI values for the Client pre-registered at - // at your identity provider - proxyRedirectURL, err := urlutil.ParseAndValidateURL(authRedirectURL.Query().Get("redirect_uri")) - if err != nil || !middleware.SameDomain(proxyRedirectURL, a.RedirectURL) { - httputil.ErrorResponse(w, r, httputil.Error("proxy url not from the root domain", http.StatusBadRequest, err)) - return - } - - // get the signature and timestamp values then compare hmac - proxyRedirectSig := authRedirectURL.Query().Get("sig") - ts := authRedirectURL.Query().Get("ts") - if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, a.SharedKey) { - httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, nil)) - return - } - // State is the opaque value used to maintain state between the request and - // the callback; contains both the nonce and redirect URI - state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String()))) - - // build the provider sign in url - signInURL := a.provider.GetSignInURL(state) - http.Redirect(w, r, signInURL, http.StatusFound) +func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http.Request) { + redirectURL := a.RedirectURL.ResolveReference(r.URL) + nonce := csrf.Token(r) + state := fmt.Sprintf("%v:%v", nonce, redirectURL.String()) + encodedState := base64.URLEncoding.EncodeToString([]byte(state)) + http.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound) } // OAuthCallback handles the callback from the identity provider. +// +// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowSteps // https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) { redirect, err := a.getOAuthCallback(w, r) @@ -195,57 +159,49 @@ func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) { httputil.ErrorResponse(w, r, fmt.Errorf("oauth callback : %w", err)) return } - // redirect back to the proxy-service via sign_in http.Redirect(w, r, redirect.String(), http.StatusFound) } func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) { - if err := r.ParseForm(); err != nil { - return nil, httputil.Error("invalid signature", http.StatusBadRequest, err) + // Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6 + // + // first, check if the identity provider returned an error + if idpError := r.FormValue("error"); idpError != "" { + return nil, httputil.Error(idpError, http.StatusBadRequest, fmt.Errorf("identity provider: %v", idpError)) } - // OIDC : 3.1.2.6. Authentication Error Response - // https://openid.net/specs/openid-connect-core-1_0-final.html#AuthError - if idpError := r.Form.Get("error"); idpError != "" { - return nil, httputil.Error("provider returned an error", http.StatusBadRequest, fmt.Errorf("provider error: %v", idpError)) - } - code := r.Form.Get("code") + // fail if no session redemption code is returned + code := r.FormValue("code") if code == "" { - return nil, httputil.Error("provider didn't reply with code", http.StatusBadRequest, nil) + return nil, httputil.Error("identity provider returned empty code", http.StatusBadRequest, nil) } - // validate the returned code with the identity provider + // Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5 + // + // Exchange the supplied Authorization Code for a valid user session. session, err := a.provider.Authenticate(r.Context(), code) if err != nil { return nil, fmt.Errorf("error redeeming authenticate code: %w", err) } - - // OIDC : 3.1.2.5. Successful Authentication Response - // Opaque value used to maintain state between the request and the callback. - bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state")) + // state includes a csrf nonce (validated by middleware) and redirect uri + bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state")) if err != nil { - return nil, fmt.Errorf("failed decoding state: %w", err) + return nil, httputil.Error("malformed state", http.StatusBadRequest, err) } - s := strings.SplitN(string(bytes), ":", 2) - if len(s) != 2 { - return nil, fmt.Errorf("invalid state size: %d", len(s)) + // split state into its it's components (nonce:redirect_uri) + statePayload := strings.SplitN(string(bytes), ":", 2) + if len(statePayload) != 2 { + return nil, fmt.Errorf("state malformed, size: %d", len(statePayload)) } - // state contains the csrf nonce and redirect uri - nonce := s[0] - redirect := s[1] - c, err := a.csrfStore.GetCSRF(r) - defer a.csrfStore.ClearCSRF(w, r) - if err != nil || c.Value != nonce { - return nil, fmt.Errorf("csrf failure: %w", err) - } - redirectURL, err := urlutil.ParseAndValidateURL(redirect) + // parse redirect_uri; ignore csrf nonce (validity asserted by middleware) + redirectURL, err := urlutil.ParseAndValidateURL(statePayload[1]) if err != nil { - return nil, httputil.Error(fmt.Sprintf("invalid redirect uri %s", redirect), http.StatusBadRequest, err) - } - // sanity check, we are redirecting back to the same subdomain right? - if !middleware.SameDomain(redirectURL, a.RedirectURL) { - return nil, httputil.Error(fmt.Sprintf("invalid redirect domain %v, %v", redirectURL, a.RedirectURL), http.StatusBadRequest, nil) + return nil, httputil.Error("invalid redirect uri", http.StatusBadRequest, err) } + // todo(bdd): if we want to be _extra_ sure, we can validate that the + // redirectURL hmac is valid. But the nonce should cover the integrity... + + // OK. Looks good so let's persist our user session if err := a.sessionStore.SaveSession(w, r, session); err != nil { return nil, fmt.Errorf("failed saving new session: %w", err) } @@ -256,11 +212,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) // and exchanges that token for a pomerium session. The provided token's // audience ('aud') attribute must match Pomerium's client_id. func (a *Authenticate) ExchangeToken(w http.ResponseWriter, r *http.Request) { - if err := r.ParseForm(); err != nil { - httputil.ErrorResponse(w, r, err) - return - } - code := r.Form.Get("id_token") + code := r.FormValue("id_token") if code == "" { httputil.ErrorResponse(w, r, httputil.Error("missing id token", http.StatusBadRequest, nil)) return diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index a273e1aa0..55c17430f 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -21,6 +21,7 @@ func testAuthenticate() *Authenticate { var auth Authenticate auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback") auth.SharedKey = "IzY7MOZwzfOkmELXgozHDKTxoT3nOYhwkcmUVINsRww=" + auth.cookieSecret = []byte(auth.SharedKey) auth.templates = templates.New() return &auth } @@ -51,6 +52,7 @@ func TestAuthenticate_Handler(t *testing.T) { t.Error("handler cannot be nil") } req := httptest.NewRequest("GET", "/robots.txt", nil) + req.Header.Set("Accept", "application/json") rr := httptest.NewRecorder() h.ServeHTTP(rr, req) @@ -63,6 +65,7 @@ func TestAuthenticate_Handler(t *testing.T) { } func TestAuthenticate_SignIn(t *testing.T) { + t.Parallel() tests := []struct { name string state string @@ -76,36 +79,35 @@ func TestAuthenticate_SignIn(t *testing.T) { {"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound}, {"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockCipher{}, http.StatusFound}, {"session expired good refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound}, - {"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusBadRequest}, // mocking hmac is meh - {"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusBadRequest}, + {"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusFound}, // mocking hmac is meh + {"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound}, // {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, - {"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, - {"malformed form", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, - {"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, + {"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound}, + {"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound}, {"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, // actually caught by go's handler, but we should keep the test. - {"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, - {"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusInternalServerError}, + {"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, + {"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusFound}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &Authenticate{ sessionStore: tt.session, provider: tt.provider, - RedirectURL: uriParse("https://some.example"), - csrfStore: &sessions.MockCSRFStore{}, + RedirectURL: uriParseHelper("https://some.example"), SharedKey: "secret", cipher: tt.cipher, } uri := &url.URL{Host: "corp.some.example", Scheme: "https", Path: "/"} - if tt.name == "malformed form" { - uri.RawQuery = "example=%zzzzz" - } else { - uri.RawQuery = fmt.Sprintf("%s&redirect_uri=%s", tt.state, tt.redirectURI) - } + uri.RawQuery = fmt.Sprintf("%s&redirect_uri=%s", tt.state, tt.redirectURI) r := httptest.NewRequest(http.MethodGet, uri.String(), nil) r.Header.Set("Accept", "application/json") + state, _ := tt.session.LoadSession(r) + ctx := r.Context() + ctx = sessions.NewContext(ctx, state, nil) + r = r.WithContext(ctx) + w := httptest.NewRecorder() a.SignIn(w, r) @@ -117,61 +119,18 @@ func TestAuthenticate_SignIn(t *testing.T) { } } -type mockCipher struct{} - -func (a mockCipher) Encrypt(s []byte) ([]byte, error) { - if string(s) == "error" { - return []byte(""), errors.New("error encrypting") - } - return []byte("OK"), nil -} - -func (a mockCipher) Decrypt(s []byte) ([]byte, error) { - if string(s) == "error" { - return []byte(""), errors.New("error encrypting") - } - return []byte("OK"), nil -} -func (a mockCipher) Marshal(s interface{}) (string, error) { return "ok", nil } -func (a mockCipher) Unmarshal(s string, i interface{}) error { - if s == "unmarshal error" || s == "error" { - return errors.New("error") - } - return nil -} - -func Test_getAuthCodeRedirectURL(t *testing.T) { - tests := []struct { - name string - redirectURL *url.URL - state string - authCode string - want string - }{ - {"https", uriParse("https://www.pomerium.io"), "state", "auth-code", "https://www.pomerium.io?code=auth-code&state=state"}, - {"http", uriParse("http://www.pomerium.io"), "state", "auth-code", "http://www.pomerium.io?code=auth-code&state=state"}, - {"no subdomain", uriParse("http://pomerium.io"), "state", "auth-code", "http://pomerium.io?code=auth-code&state=state"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := getAuthCodeRedirectURL(tt.redirectURL, tt.state, tt.authCode); got != tt.want { - t.Errorf("getAuthCodeRedirectURL() = %v, want %v", got, tt.want) - } - }) - } -} - -func uriParse(s string) *url.URL { +func uriParseHelper(s string) *url.URL { uri, _ := url.Parse(s) return uri } func TestAuthenticate_SignOut(t *testing.T) { - + t.Parallel() tests := []struct { name string method string + ctxError error redirectURL string sig string ts string @@ -181,17 +140,16 @@ func TestAuthenticate_SignOut(t *testing.T) { wantCode int wantBody string }{ - {"good post", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""}, - {"failed revoke", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"}, - {"malformed form", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusInternalServerError, ""}, - {"load session error", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""}, + {"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""}, + {"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"}, + {"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, ""}, + {"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &Authenticate{ sessionStore: tt.sessionStore, provider: tt.provider, - cipher: mockCipher{}, templates: templates.New(), } u, _ := url.Parse("/sign_out") @@ -200,10 +158,11 @@ func TestAuthenticate_SignOut(t *testing.T) { params.Add("ts", tt.ts) params.Add("redirect_uri", tt.redirectURL) u.RawQuery = params.Encode() - if tt.name == "malformed form" { - u.RawQuery = "example=%zzzzz" - } r := httptest.NewRequest(tt.method, u.String(), nil) + state, _ := tt.sessionStore.LoadSession(r) + ctx := r.Context() + ctx = sessions.NewContext(ctx, state, tt.ctxError) + r = r.WithContext(ctx) w := httptest.NewRecorder() a.SignOut(w, r) @@ -217,64 +176,8 @@ func TestAuthenticate_SignOut(t *testing.T) { } } -func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) string { - data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix())) - h := cryptutil.Hash(secret, data) - return base64.URLEncoding.EncodeToString(h) -} - -func TestAuthenticate_OAuthStart(t *testing.T) { - tests := []struct { - name string - method string - redirectURLSetting string - - redirectURL string - sig string - ts string - - provider identity.Authenticator - csrfStore sessions.MockCSRFStore - // sessionStore sessions.SessionStore - wantCode int - }{ - {"good", http.MethodGet, "https://corp.pomerium.io/", "https://corp.pomerium.io/", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusFound}, - {"bad timestamp", http.MethodGet, "https://corp.pomerium.io/", "https://corp.pomerium.io/", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Add(10 * time.Hour).Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest}, - {"missing redirect", http.MethodGet, "https://corp.pomerium.io/", "", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest}, - {"malformed redirect", http.MethodGet, "https://corp.pomerium.io/", "https://pomerium.com%zzzzz", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest}, - {"different domains", http.MethodGet, "https://corp.notpomerium.io/", "https://corp.pomerium.io/", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - a := &Authenticate{ - RedirectURL: uriParse(tt.redirectURLSetting), - csrfStore: tt.csrfStore, - provider: tt.provider, - SharedKey: "secret", - cipher: mockCipher{}, - } - u, _ := url.Parse("/oauth_start") - params, _ := url.ParseQuery(u.RawQuery) - params.Add("sig", tt.sig) - params.Add("ts", tt.ts) - params.Add("redirect_uri", tt.redirectURL) - - u.RawQuery = params.Encode() - - r := httptest.NewRequest(tt.method, u.String(), nil) - r.Header.Set("Accept", "application/json") - - w := httptest.NewRecorder() - - a.OAuthStart(w, r) - if status := w.Code; status != tt.wantCode { - t.Errorf("handler returned wrong status code: got %v want %v\n%v", status, tt.wantCode, w.Body.String()) - } - }) - } -} - func TestAuthenticate_OAuthCallback(t *testing.T) { + t.Parallel() tests := []struct { name string method string @@ -286,24 +189,20 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { authenticateURL string session sessions.SessionStore provider identity.MockProvider - csrfStore sessions.MockCSRFStore want string wantCode int }{ - {"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusFound}, - {"get csrf error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", GetError: errors.New("error"), Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError}, - {"csrf nonce error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError}, - {"failed authenticate", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, - {"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, - {"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, - {"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, - {"invalid state string", http.MethodGet, "", "code", "nonce:https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, - {"malformed state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError}, - {"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, - {"malformed form", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest}, - {"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest}, - {"different domains", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://some.example.notpomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusFound}, + {"failed authenticate", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError}, + {"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusInternalServerError}, + {"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest}, + {"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest}, + {"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest}, + {"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"bad base64 state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")) + "%", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"too many state delimeters", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io:wait")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"too few state delimeters", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusInternalServerError}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -311,7 +210,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { a := &Authenticate{ RedirectURL: authURL, sessionStore: tt.session, - csrfStore: tt.csrfStore, provider: tt.provider, } u, _ := url.Parse("/oauthGet") @@ -322,9 +220,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { u.RawQuery = params.Encode() - if tt.name == "malformed form" { - u.RawQuery = "example=%zzzzz" - } r := httptest.NewRequest(tt.method, u.String(), nil) r.Header.Set("Accept", "application/json") w := httptest.NewRecorder() @@ -339,6 +234,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { } func TestAuthenticate_ExchangeToken(t *testing.T) { + t.Parallel() tests := []struct { name string method string @@ -384,3 +280,55 @@ func TestAuthenticate_ExchangeToken(t *testing.T) { }) } } + +func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + fmt.Fprintln(w, "RVSI FILIVS CAISAR") + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + session sessions.SessionStore + ctxError error + provider identity.Authenticator + + wantStatus int + }{ + {"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK}, + {"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound}, + {"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusOK}, + {"expired,refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound}, + {"expired,save error", &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusFound}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + a := Authenticate{ + SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", + cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="), + RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"), + sessionStore: tt.session, + provider: tt.provider, + } + r := httptest.NewRequest("GET", "/", nil) + state, _ := tt.session.LoadSession(r) + ctx := r.Context() + ctx = sessions.NewContext(ctx, state, tt.ctxError) + r = r.WithContext(ctx) + + r.Header.Set("Accept", "application/json") + + w := httptest.NewRecorder() + + got := a.VerifySession(fn) + got.ServeHTTP(w, r) + if status := w.Code; status != tt.wantStatus { + t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) + + } + }) + } +} diff --git a/authorize/authorize.go b/authorize/authorize.go index e1c03a3c2..30233fc12 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -14,7 +14,7 @@ import ( func ValidateOptions(o config.Options) error { decoded, err := base64.StdEncoding.DecodeString(o.SharedKey) if err != nil { - return fmt.Errorf("authorize: `SHARED_SECRET` setting is invalid base64: %v", err) + return fmt.Errorf("authorize: `SHARED_SECRET` malformed base64: %v", err) } if len(decoded) != 32 { return fmt.Errorf("authorize: `SHARED_SECRET` want 32 but got %d bytes", len(decoded)) diff --git a/cmd/pomerium/main.go b/cmd/pomerium/main.go index 540e12b2e..a7999c063 100644 --- a/cmd/pomerium/main.go +++ b/cmd/pomerium/main.go @@ -8,6 +8,7 @@ import ( "time" "github.com/fsnotify/fsnotify" + "github.com/gorilla/mux" "github.com/spf13/viper" "google.golang.org/grpc" @@ -44,9 +45,9 @@ func main() { setupTracing(opt) setupHTTPRedirectServer(opt) - mux := http.NewServeMux() + r := newGlobalRouter(opt) grpcServer := setupGRPCServer(opt) - _, err = newAuthenticateService(*opt, mux) + _, err = newAuthenticateService(*opt, r.Host(urlutil.StripPort(opt.AuthenticateURL.Host)).Subrouter()) if err != nil { log.Fatal().Err(err).Msg("cmd/pomerium: authenticate") } @@ -56,7 +57,7 @@ func main() { log.Fatal().Err(err).Msg("cmd/pomerium: authorize") } - proxy, err := newProxyService(*opt, mux) + proxy, err := newProxyService(*opt, r) if err != nil { log.Fatal().Err(err).Msg("cmd/pomerium: proxy") } @@ -70,8 +71,7 @@ func main() { log.Info().Str("file", e.Name).Msg("cmd/pomerium: config file changed") opt = config.HandleConfigUpdate(*configFile, opt, []config.OptionsUpdater{authz, proxy}) }) - - srv, err := httputil.NewTLSServer(configToServerOptions(opt), mainHandler(opt, mux), grpcServer) + srv, err := httputil.NewTLSServer(configToServerOptions(opt), r, grpcServer) if err != nil { log.Fatal().Err(err).Msg("cmd/pomerium: couldn't start pomerium") } @@ -80,7 +80,7 @@ func main() { os.Exit(0) } -func newAuthenticateService(opt config.Options, mux *http.ServeMux) (*authenticate.Authenticate, error) { +func newAuthenticateService(opt config.Options, r *mux.Router) (*authenticate.Authenticate, error) { if !config.IsAuthenticate(opt.Services) { return nil, nil } @@ -88,7 +88,7 @@ func newAuthenticateService(opt config.Options, mux *http.ServeMux) (*authentica if err != nil { return nil, err } - mux.Handle(urlutil.StripPort(opt.AuthenticateURL.Host)+"/", service.Handler()) + r.PathPrefix("/").Handler(service.Handler()) return service, nil } @@ -104,7 +104,7 @@ func newAuthorizeService(opt config.Options, rpc *grpc.Server) (*authorize.Autho return service, nil } -func newProxyService(opt config.Options, mux *http.ServeMux) (*proxy.Proxy, error) { +func newProxyService(opt config.Options, r *mux.Router) (*proxy.Proxy, error) { if !config.IsProxy(opt.Services) { return nil, nil } @@ -112,15 +112,15 @@ func newProxyService(opt config.Options, mux *http.ServeMux) (*proxy.Proxy, erro if err != nil { return nil, err } - mux.Handle("/", service.Handler()) + r.PathPrefix("/").Handler(service.Handler()) return service, nil } -func mainHandler(o *config.Options, mux http.Handler) http.Handler { - c := middleware.NewChain() - c = c.Append(metrics.HTTPMetricsHandler(o.Services)) - c = c.Append(log.NewHandler(log.Logger)) - c = c.Append(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) { +func newGlobalRouter(o *config.Options) *mux.Router { + mux := httputil.NewRouter() + mux.Use(metrics.HTTPMetricsHandler(o.Services)) + mux.Use(log.NewHandler(log.Logger)) + mux.Use(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) { log.FromRequest(r).Debug(). Dur("duration", duration). Int("size", size). @@ -133,15 +133,15 @@ func mainHandler(o *config.Options, mux http.Handler) http.Handler { Msg("http-request") })) if len(o.Headers) != 0 { - c = c.Append(middleware.SetHeaders(o.Headers)) + mux.Use(middleware.SetHeaders(o.Headers)) } - c = c.Append(log.ForwardedAddrHandler("fwd_ip")) - c = c.Append(log.RemoteAddrHandler("ip")) - c = c.Append(log.UserAgentHandler("user_agent")) - c = c.Append(log.RefererHandler("referer")) - c = c.Append(log.RequestIDHandler("req_id", "Request-Id")) - c = c.Append(middleware.Healthcheck("/ping", version.UserAgent())) - return c.Then(mux) + mux.Use(log.ForwardedAddrHandler("fwd_ip")) + mux.Use(log.RemoteAddrHandler("ip")) + mux.Use(log.UserAgentHandler("user_agent")) + mux.Use(log.RefererHandler("referer")) + mux.Use(log.RequestIDHandler("req_id", "Request-Id")) + mux.Use(middleware.Healthcheck("/ping", version.UserAgent())) + return mux } func configToServerOptions(opt *config.Options) *httputil.ServerOptions { diff --git a/cmd/pomerium/main_test.go b/cmd/pomerium/main_test.go index 1d24b1cf5..52ed82511 100644 --- a/cmd/pomerium/main_test.go +++ b/cmd/pomerium/main_test.go @@ -21,7 +21,7 @@ import ( ) func Test_newAuthenticateService(t *testing.T) { - mux := http.NewServeMux() + mux := httputil.NewRouter() tests := []struct { name string @@ -127,7 +127,7 @@ func Test_newProxyeService(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mux := http.NewServeMux() + mux := httputil.NewRouter() testOpts, err := config.NewOptions("https://authenticate.example", "https://authorize.example") if err != nil { t.Fatal(err) @@ -161,7 +161,7 @@ func Test_newProxyeService(t *testing.T) { } } -func Test_mainHandler(t *testing.T) { +func Test_newGlobalRouter(t *testing.T) { o := config.Options{ Services: "all", Headers: map[string]string{ @@ -172,7 +172,6 @@ func Test_mainHandler(t *testing.T) { "Content-Security-Policy": "default-src 'none'; style-src 'self' 'sha256-pSTVzZsFAqd2U3QYu+BoBDtuJWaPM/+qMy/dBRrhb5Y='; img-src 'self';", "Referrer-Policy": "Same-origin", }} - mux := http.NewServeMux() req := httptest.NewRequest(http.MethodGet, "/404", nil) rr := httptest.NewRecorder() h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -181,8 +180,9 @@ func Test_mainHandler(t *testing.T) { io.WriteString(w, `OK`) }) - mux.Handle("/404", h) - out := mainHandler(&o, mux) + out := newGlobalRouter(&o) + out.Handle("/404", h) + out.ServeHTTP(rr, req) expected := fmt.Sprintf("OK") body := rr.Body.String() diff --git a/docs/docs/CHANGELOG.md b/docs/docs/CHANGELOG.md index 53e4c1e33..87644a4a9 100644 --- a/docs/docs/CHANGELOG.md +++ b/docs/docs/CHANGELOG.md @@ -5,6 +5,23 @@ ### New - Add ability to override HTTPS backend's TLS Server Name. [GH-297](https://github.com/pomerium/pomerium/pull/297) +- Add ability to set pomerium's encrypted session in a auth bearer token, or query param. + +### Security + +- Under certain circumstances, where debug logging was enabled, pomerium's shared secret could be leaked to http access logs as a query param. + +### Fixed + +- Fixed an issue where CSRF would fail if multiple tabs were open. [GH-306](https://github.com/pomerium/pomerium/issues/306) + +### Changed + +- Authenticate service no longer uses gRPC. + +### Removed + +- Removed `AUTHENTICATE_INTERNAL_URL`/`authenticate_internal_url` which is no longer used. ## v0.3.0 diff --git a/docs/docs/reference/reference.md b/docs/docs/reference/reference.md index e958407a0..792e6c551 100644 --- a/docs/docs/reference/reference.md +++ b/docs/docs/reference/reference.md @@ -228,8 +228,8 @@ Each unit work is called a Span in a trace. Spans include metadata about the wor | Config Key | Description | Required | | :--------------- | :---------------------------------------------------------------- | -------- | -| tracing_provider | The name of the tracing provider. (e.g. jaeger) | ✅ | -| tracing_debug | Will disable [sampling](https://opencensus.io/tracing/sampling/). | ❌ | +| tracing_provider | The name of the tracing provider. (e.g. jaeger) | ✅ | +| tracing_debug | Will disable [sampling](https://opencensus.io/tracing/sampling/). | ❌ | ### Jaeger @@ -243,8 +243,8 @@ Each unit work is called a Span in a trace. Spans include metadata about the wor | Config Key | Description | Required | | :-------------------------------- | :------------------------------------------ | -------- | -| tracing_jaeger_collector_endpoint | Url to the Jaeger HTTP Thrift collector. | ✅ | -| tracing_jaeger_agent_endpoint | Send spans to jaeger-agent at this address. | ✅ | +| tracing_jaeger_collector_endpoint | Url to the Jaeger HTTP Thrift collector. | ✅ | +| tracing_jaeger_agent_endpoint | Send spans to jaeger-agent at this address. | ✅ | #### Example @@ -478,11 +478,11 @@ Authenticate Service URL is the externally accessible URL for the authenticate s - Config File Key: `authorize_service_url` - Type: `URL` - Required -- Example: `https://access.corp.example.com` or `https://pomerium-authorize-service.default.svc.cluster.local` +- Example: `https://authorize.corp.example.com` or `https://pomerium-authorize-service.default.svc.cluster.local` Authorize Service URL is the location of the internally accessible authorize service. NOTE: Unlike authenticate, authorize has no publicly accessible http handlers so this setting is purely for gRPC communication. -If your load balancer does not support gRPC pass-through you'll need to set this value to an internally routable location (`https://pomerium-authorize-service.default.svc.cluster.local`) instead of an externally routable one (`https://access.corp.example.com`). +If your load balancer does not support gRPC pass-through you'll need to set this value to an internally routable location (`https://pomerium-authorize-service.default.svc.cluster.local`) instead of an externally routable one (`https://authorize.corp.example.com`). ## Override Certificate Name diff --git a/go.mod b/go.mod index 0dcdc8412..efd3cc42a 100644 --- a/go.mod +++ b/go.mod @@ -9,10 +9,12 @@ require ( github.com/fsnotify/fsnotify v1.4.7 github.com/golang/mock v1.3.1 github.com/golang/protobuf v1.3.1 - github.com/google/go-cmp v0.3.0 + github.com/google/go-cmp v0.3.1 + github.com/gorilla/mux v1.6.2 github.com/magiconair/properties v1.8.1 // indirect github.com/mitchellh/hashstructure v1.0.0 github.com/pelletier/go-toml v1.4.0 // indirect + github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30 github.com/pomerium/go-oidc v2.0.0+incompatible github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/prometheus/client_golang v0.9.3 diff --git a/go.sum b/go.sum index e06dfd01c..a7e0bfde8 100644 --- a/go.sum +++ b/go.sum @@ -65,11 +65,16 @@ github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= +github.com/gorilla/mux v1.6.2 h1:Pgr17XVTNXAk3q/r4CpKzC5xBM/qW1uVLV+IhRZpIIk= github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0 h1:Iju5GlWwrvL6UBg4zJJt3btmonfrMlCDdsejg4CZE7c= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= @@ -115,9 +120,12 @@ github.com/pelletier/go-toml v1.4.0 h1:u3Z1r+oOXJIkxqw34zVhyPgjBsm6X2wn21NWs/HfS github.com/pelletier/go-toml v1.4.0/go.mod h1:PN7xzY2wHTK0K9p34ErDQMlFxa51Fk0OUruD3k1mMwo= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30 h1:jggCv6hZvcxjGa3gqkYY2EUuOkITI9Znugz/f3QJfRQ= +github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0= github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o= github.com/pomerium/go-oidc v2.0.0+incompatible/go.mod h1:DRsGVw6MOgxbfq4Y57jKOE8lbEfayxeiY0A8/4vxjBM= github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU= diff --git a/internal/config/options.go b/internal/config/options.go index cfb4c5e96..e2d8e0624 100644 --- a/internal/config/options.go +++ b/internal/config/options.go @@ -217,7 +217,7 @@ func (o *Options) Validate() error { // shared key must be set for all modes other than "all" if o.SharedKey == "" { if o.Services == "all" { - o.SharedKey = cryptutil.GenerateRandomString(32) + o.SharedKey = cryptutil.NewBase64Key() } else { return errors.New("shared-key cannot be empty") } diff --git a/internal/config/policy.go b/internal/config/policy.go index fb81e3000..672800bf5 100644 --- a/internal/config/policy.go +++ b/internal/config/policy.go @@ -116,3 +116,9 @@ func (p *Policy) Validate() error { return nil } +func (p *Policy) String() string { + if p.Source == nil || p.Destination == nil { + return fmt.Sprintf("%s → %s", p.From, p.To) + } + return fmt.Sprintf("%s → %s", p.Source.String(), p.Destination.String()) +} diff --git a/internal/config/policy_test.go b/internal/config/policy_test.go index e59370fd1..026bf03ea 100644 --- a/internal/config/policy_test.go +++ b/internal/config/policy_test.go @@ -1,4 +1,4 @@ -package config // import "github.com/pomerium/pomerium/internal/config" +package config import ( "testing" @@ -44,3 +44,28 @@ func Test_Validate(t *testing.T) { }) } } + +func TestPolicy_String(t *testing.T) { + t.Parallel() + tests := []struct { + name string + From string + To string + want string + }{ + {"good", "https://pomerium.io", "https://localhost", "https://pomerium.io → https://localhost"}, + {"failed to validate", "https://pomerium.io", "localhost", "https://pomerium.io → localhost"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Policy{ + From: tt.From, + To: tt.To, + } + p.Validate() + if got := p.String(); got != tt.want { + t.Errorf("Policy.String() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/cryptutil/encrypt.go b/internal/cryptutil/encrypt.go index 82d79647f..f49cb16e8 100644 --- a/internal/cryptutil/encrypt.go +++ b/internal/cryptutil/encrypt.go @@ -13,20 +13,27 @@ import ( "golang.org/x/crypto/chacha20poly1305" ) +// DefaultKeySize is the default key size in bytes. const DefaultKeySize = 32 -// GenerateKey generates a random 32-byte key. +// NewKey generates a random 32-byte key. // // Panics if source of randomness fails. -func GenerateKey() []byte { +func NewKey() []byte { return randomBytes(DefaultKeySize) } -// GenerateRandomString returns base64 encoded securely generated random string -// of a given set of bytes. +// NewBase64Key generates a random base64 encoded 32-byte key. // // Panics if source of randomness fails. -func GenerateRandomString(c int) string { +func NewBase64Key() string { + return NewRandomStringN(DefaultKeySize) +} + +// NewRandomStringN returns base64 encoded random string of a given num of bytes. +// +// Panics if source of randomness fails. +func NewRandomStringN(c int) string { return base64.StdEncoding.EncodeToString(randomBytes(c)) } diff --git a/internal/cryptutil/encrypt_test.go b/internal/cryptutil/encrypt_test.go index eddd45a37..dc5671f51 100644 --- a/internal/cryptutil/encrypt_test.go +++ b/internal/cryptutil/encrypt_test.go @@ -1,4 +1,4 @@ -package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil" +package cryptutil import ( "crypto/rand" @@ -13,7 +13,7 @@ import ( func TestEncodeAndDecodeAccessToken(t *testing.T) { plaintext := []byte("my plain text value") - key := GenerateKey() + key := NewKey() c, err := NewCipher(key) if err != nil { t.Fatalf("unexpected err: %v", err) @@ -47,7 +47,7 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) { } func TestMarshalAndUnmarshalStruct(t *testing.T) { - key := GenerateKey() + key := NewKey() c, err := NewCipher(key) if err != nil { @@ -102,7 +102,7 @@ func TestMarshalAndUnmarshalStruct(t *testing.T) { } func TestCipherDataRace(t *testing.T) { - cipher, err := NewCipher(GenerateKey()) + cipher, err := NewCipher(NewKey()) if err != nil { t.Fatalf("unexpected generating cipher err: %v", err) } @@ -183,21 +183,21 @@ func TestGenerateRandomString(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - o := GenerateRandomString(tt.c) + o := NewRandomStringN(tt.c) b, err := base64.StdEncoding.DecodeString(o) if err != nil { t.Error(err) } got := len(b) if got != tt.want { - t.Errorf("GenerateRandomString() = %d, want %d", got, tt.want) + t.Errorf("NewRandomStringN() = %d, want %d", got, tt.want) } }) } } func TestXChaCha20Cipher_Marshal(t *testing.T) { - + t.Parallel() tests := []struct { name string s interface{} @@ -225,7 +225,7 @@ func TestXChaCha20Cipher_Marshal(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c, err := NewCipher(GenerateKey()) + c, err := NewCipher(NewKey()) if err != nil { t.Fatalf("unexpected err: %v", err) } @@ -239,15 +239,15 @@ func TestXChaCha20Cipher_Marshal(t *testing.T) { } func TestNewCipher(t *testing.T) { - + t.Parallel() tests := []struct { name string secret []byte wantErr bool }{ - {"simple 32 byte key", GenerateKey(), false}, + {"simple 32 byte key", NewKey(), false}, {"key too short", []byte("what is entropy"), true}, - {"key too long", []byte(GenerateRandomString(33)), true}, + {"key too long", []byte(NewRandomStringN(33)), true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -261,16 +261,16 @@ func TestNewCipher(t *testing.T) { } func TestNewCipherFromBase64(t *testing.T) { - + t.Parallel() tests := []struct { name string s string wantErr bool }{ - {"simple 32 byte key", base64.StdEncoding.EncodeToString(GenerateKey()), false}, + {"simple 32 byte key", base64.StdEncoding.EncodeToString(NewKey()), false}, {"key too short", base64.StdEncoding.EncodeToString([]byte("what is entropy")), true}, - {"key too long", GenerateRandomString(33), true}, - {"bad base 64", string(GenerateKey()), true}, + {"key too long", NewRandomStringN(33), true}, + {"bad base 64", string(NewKey()), true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -282,3 +282,26 @@ func TestNewCipherFromBase64(t *testing.T) { }) } } + +func TestNewBase64Key(t *testing.T) { + t.Parallel() + tests := []struct { + name string + want int + }{ + {"simple", 32}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + o := NewBase64Key() + b, err := base64.StdEncoding.DecodeString(o) + if err != nil { + t.Error(err) + } + got := len(b) + if got != tt.want { + t.Errorf("NewBase64Key() = %d, want %d", got, tt.want) + } + }) + } +} diff --git a/internal/httputil/router.go b/internal/httputil/router.go new file mode 100644 index 000000000..787d570f2 --- /dev/null +++ b/internal/httputil/router.go @@ -0,0 +1,19 @@ +package httputil // import "github.com/pomerium/pomerium/internal/httputil" + +import ( + "net/http" + + "github.com/gorilla/mux" + "github.com/pomerium/csrf" +) + +// NewRouter returns a new router instance. +func NewRouter() *mux.Router { + return mux.NewRouter() +} + +// CSRFFailureHandler sets a HTTP 403 Forbidden status and writes the +// CSRF failure reason to the response. +func CSRFFailureHandler(w http.ResponseWriter, r *http.Request) { + ErrorResponse(w, r, Error("CSRF Failure", http.StatusForbidden, csrf.FailureReason(r))) +} diff --git a/internal/httputil/router_test.go b/internal/httputil/router_test.go new file mode 100644 index 000000000..96fa0e4ef --- /dev/null +++ b/internal/httputil/router_test.go @@ -0,0 +1,37 @@ +package httputil // import "github.com/pomerium/pomerium/internal/httputil" + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestCSRFFailureHandler(t *testing.T) { + + tests := []struct { + name string + + wantBody string + wantStatus int + }{ + {"basic csrf failure", "{\"error\":\"CSRF Failure\"}\n", http.StatusForbidden}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + CSRFFailureHandler(w, r) + gotBody := w.Body.String() + gotStatus := w.Result().StatusCode + if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" { + t.Errorf("RetrieveSession() = %s", diff) + } + if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" { + t.Errorf("RetrieveSession() = %s", diff) + } + }) + } +} diff --git a/internal/middleware/chain.go b/internal/middleware/chain.go deleted file mode 100644 index 58fa2af0c..000000000 --- a/internal/middleware/chain.go +++ /dev/null @@ -1,109 +0,0 @@ -package middleware // import "github.com/pomerium/pomerium/internal/middleware" - -import "net/http" - -// Constructor is a type alias for func(http.Handler) http.Handler -type Constructor func(http.Handler) http.Handler - -// Chain acts as a list of http.Handler constructors. -// Chain is effectively immutable: -// once created, it will always hold -// the same set of constructors in the same order. -type Chain struct { - constructors []Constructor -} - -// NewChain creates a new chain, -// memorizing the given list of middleware constructors. -// New serves no other function, -// constructors are only called upon a call to Then(). -func NewChain(constructors ...Constructor) Chain { - return Chain{append([]Constructor(nil), constructors...)} -} - -// Then chains the middleware and returns the final http.Handler. -// NewChain(m1, m2, m3).Then(h) -// is equivalent to: -// m1(m2(m3(h))) -// When the request comes in, it will be passed to m1, then m2, then m3 -// and finally, the given handler -// (assuming every middleware calls the following one). -// -// A chain can be safely reused by calling Then() several times. -// stdStack := middleware.NewChain(ratelimitHandler, csrfHandler) -// indexPipe = stdStack.Then(indexHandler) -// authPipe = stdStack.Then(authHandler) -// Note that constructors are called on every call to Then() -// and thus several instances of the same middleware will be created -// when a chain is reused in this way. -// For proper middleware, this should cause no problems. -// -// Then() treats nil as http.DefaultServeMux. -func (c Chain) Then(h http.Handler) http.Handler { - if h == nil { - h = http.DefaultServeMux - } - - for i := range c.constructors { - h = c.constructors[len(c.constructors)-1-i](h) - } - - return h -} - -// ThenFunc works identically to Then, but takes -// a HandlerFunc instead of a Handler. -// -// The following two statements are equivalent: -// c.Then(http.HandlerFunc(fn)) -// c.ThenFunc(fn) -// -// ThenFunc provides all the guarantees of Then. -func (c Chain) ThenFunc(fn http.HandlerFunc) http.Handler { - if fn == nil { - return c.Then(nil) - } - return c.Then(fn) -} - -// Append extends a chain, adding the specified constructors -// as the last ones in the request flow. -// -// Append returns a new chain, leaving the original one untouched. -// -// stdChain := middleware.NewChain(m1, m2) -// extChain := stdChain.Append(m3, m4) -// // requests in stdChain go m1 -> m2 -// // requests in extChain go m1 -> m2 -> m3 -> m4 -func (c Chain) Append(constructors ...Constructor) Chain { - newCons := make([]Constructor, 0, len(c.constructors)+len(constructors)) - newCons = append(newCons, c.constructors...) - newCons = append(newCons, constructors...) - - return Chain{newCons} -} - -// Extend extends a chain by adding the specified chain -// as the last one in the request flow. -// -// Extend returns a new chain, leaving the original one untouched. -// -// stdChain := middleware.NewChain(m1, m2) -// ext1Chain := middleware.NewChain(m3, m4) -// ext2Chain := stdChain.Extend(ext1Chain) -// // requests in stdChain go m1 -> m2 -// // requests in ext1Chain go m3 -> m4 -// // requests in ext2Chain go m1 -> m2 -> m3 -> m4 -// -// Another example: -// aHtmlAfterNosurf := middleware.NewChain(m2) -// aHtml := middleware.NewChain(m1, func(h http.Handler) http.Handler { -// csrf := nosurf.NewChain(h) -// csrf.SetFailureHandler(aHtmlAfterNosurf.ThenFunc(csrfFail)) -// return csrf -// }).Extend(aHtmlAfterNosurf) -// // requests to aHtml hitting nosurfs success handler go m1 -> nosurf -> m2 -> target-handler -// // requests to aHtml hitting nosurfs failure handler go m1 -> nosurf -> m2 -> csrfFail -func (c Chain) Extend(chain Chain) Chain { - return c.Append(chain.constructors...) -} diff --git a/internal/middleware/chain_test.go b/internal/middleware/chain_test.go deleted file mode 100644 index 422272923..000000000 --- a/internal/middleware/chain_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package middleware // import "github.com/pomerium/pomerium/internal/middleware" - -import ( - "net/http" - "net/http/httptest" - "reflect" - "testing" -) - -// A constructor for middleware -// that writes its own "tag" into the RW and does nothing else. -// Useful in checking if a chain is behaving in the right order. -func tagMiddleware(tag string) Constructor { - return func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(tag)) - h.ServeHTTP(w, r) - }) - } -} - -// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer), -// but the best we can do. -func funcsEqual(f1, f2 interface{}) bool { - val1 := reflect.ValueOf(f1) - val2 := reflect.ValueOf(f2) - return val1.Pointer() == val2.Pointer() -} - -var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("app\n")) -}) - -func TestNew(t *testing.T) { - c1 := func(h http.Handler) http.Handler { - return nil - } - - c2 := func(h http.Handler) http.Handler { - return http.StripPrefix("potato", nil) - } - - slice := []Constructor{c1, c2} - - chain := NewChain(slice...) - for k := range slice { - if !funcsEqual(chain.constructors[k], slice[k]) { - t.Error("New does not add constructors correctly") - } - } -} - -func TestThenWorksWithNoMiddleware(t *testing.T) { - if !funcsEqual(NewChain().Then(testApp), testApp) { - t.Error("Then does not work with no middleware") - } -} - -func TestThenTreatsNilAsDefaultServeMux(t *testing.T) { - if NewChain().Then(nil) != http.DefaultServeMux { - t.Error("Then does not treat nil as DefaultServeMux") - } -} - -func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) { - if NewChain().ThenFunc(nil) != http.DefaultServeMux { - t.Error("ThenFunc does not treat nil as DefaultServeMux") - } -} - -func TestThenFuncConstructsHandlerFunc(t *testing.T) { - fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - chained := NewChain().ThenFunc(fn) - rec := httptest.NewRecorder() - - chained.ServeHTTP(rec, (*http.Request)(nil)) - - if reflect.TypeOf(chained) != reflect.TypeOf(http.HandlerFunc(nil)) { - t.Error("ThenFunc does not construct HandlerFunc") - } -} - -func TestThenOrdersHandlersCorrectly(t *testing.T) { - t1 := tagMiddleware("t1\n") - t2 := tagMiddleware("t2\n") - t3 := tagMiddleware("t3\n") - - chained := NewChain(t1, t2, t3).Then(testApp) - - w := httptest.NewRecorder() - r, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatal(err) - } - - chained.ServeHTTP(w, r) - - if w.Body.String() != "t1\nt2\nt3\napp\n" { - t.Error("Then does not order handlers correctly") - } -} - -func TestAppendAddsHandlersCorrectly(t *testing.T) { - chain := NewChain(tagMiddleware("t1\n"), tagMiddleware("t2\n")) - newChain := chain.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n")) - - if len(chain.constructors) != 2 { - t.Error("chain should have 2 constructors") - } - if len(newChain.constructors) != 4 { - t.Error("newChain should have 4 constructors") - } - - chained := newChain.Then(testApp) - - w := httptest.NewRecorder() - r, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatal(err) - } - - chained.ServeHTTP(w, r) - - if w.Body.String() != "t1\nt2\nt3\nt4\napp\n" { - t.Error("Append does not add handlers correctly") - } -} - -func TestAppendRespectsImmutability(t *testing.T) { - chain := NewChain(tagMiddleware("")) - newChain := chain.Append(tagMiddleware("")) - - if &chain.constructors[0] == &newChain.constructors[0] { - t.Error("Apppend does not respect immutability") - } -} - -func TestExtendAddsHandlersCorrectly(t *testing.T) { - chain1 := NewChain(tagMiddleware("t1\n"), tagMiddleware("t2\n")) - chain2 := NewChain(tagMiddleware("t3\n"), tagMiddleware("t4\n")) - newChain := chain1.Extend(chain2) - - if len(chain1.constructors) != 2 { - t.Error("chain1 should contain 2 constructors") - } - if len(chain2.constructors) != 2 { - t.Error("chain2 should contain 2 constructors") - } - if len(newChain.constructors) != 4 { - t.Error("newChain should contain 4 constructors") - } - - chained := newChain.Then(testApp) - - w := httptest.NewRecorder() - r, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatal(err) - } - - chained.ServeHTTP(w, r) - - if w.Body.String() != "t1\nt2\nt3\nt4\napp\n" { - t.Error("Extend does not add handlers in correctly") - } -} - -func TestExtendRespectsImmutability(t *testing.T) { - chain := NewChain(tagMiddleware("")) - newChain := chain.Extend(NewChain(tagMiddleware(""))) - - if &chain.constructors[0] == &newChain.constructors[0] { - t.Error("Extend does not respect immutability") - } -} diff --git a/internal/middleware/doc.go b/internal/middleware/doc.go index 591668ca1..5c7693fed 100644 --- a/internal/middleware/doc.go +++ b/internal/middleware/doc.go @@ -1,2 +1,2 @@ -// Package middleware provides a standard set of middleware implementations for pomerium. +// Package middleware provides a standard set of middleware for pomerium. package middleware // import "github.com/pomerium/pomerium/internal/middleware" diff --git a/internal/sessions/cookie_store.go b/internal/sessions/cookie_store.go index 92c0452f2..96f2a2846 100644 --- a/internal/sessions/cookie_store.go +++ b/internal/sessions/cookie_store.go @@ -79,12 +79,10 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) { func (cs *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { domain := req.Host - if name == cs.csrfName() { - domain = req.Host - } else if cs.CookieDomain != "" { + if cs.CookieDomain != "" { domain = cs.CookieDomain } else { - domain = splitDomain(domain) + domain = ParentSubdomain(domain) } if h, _, err := net.SplitHostPort(domain); err == nil { @@ -105,19 +103,11 @@ func (cs *CookieStore) makeCookie(req *http.Request, name string, value string, return c } -func (cs *CookieStore) csrfName() string { - return fmt.Sprintf("%s_csrf", cs.Name) -} - // makeSessionCookie constructs a session cookie given the request, an expiration time and the current time. func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { return cs.makeCookie(req, cs.Name, value, expiration, now) } -func (cs *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { - return cs.makeCookie(req, cs.csrfName(), value, expiration, now) -} - func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) { if len(cookie.String()) <= MaxChunkSize { http.SetCookie(w, cookie) @@ -134,7 +124,6 @@ func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) { nc.Name = fmt.Sprintf("%s_%d", cookie.Name, i) nc.Value = c } - fmt.Println(i) http.SetCookie(w, &nc) } } @@ -150,25 +139,6 @@ func chunk(s string, size int) []string { return ss } -// ClearCSRF clears the CSRF cookie from the request -func (cs *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) { - http.SetCookie(w, cs.makeCSRFCookie(req, "", time.Hour*-1, time.Now())) -} - -// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request -func (cs *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) { - http.SetCookie(w, cs.makeCSRFCookie(req, val, cs.CookieExpire, time.Now())) -} - -// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request -func (cs *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) { - c, err := req.Cookie(cs.csrfName()) - if err != nil { - return nil, ErrEmptyCSRF // ErrNoCookie is confusing in this context - } - return c, nil -} - // ClearSession clears the session cookie from a request func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) { http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now())) @@ -235,7 +205,8 @@ func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s * return nil } -func splitDomain(s string) string { +// ParentSubdomain returns the parent subdomain. +func ParentSubdomain(s string) string { if strings.Count(s, ".") < 2 { return "" } diff --git a/internal/sessions/cookie_store_test.go b/internal/sessions/cookie_store_test.go index 64e61b228..58c0a9f10 100644 --- a/internal/sessions/cookie_store_test.go +++ b/internal/sessions/cookie_store_test.go @@ -1,4 +1,4 @@ -package sessions +package sessions // import "github.com/pomerium/pomerium/internal/sessions" import ( "crypto/rand" @@ -38,7 +38,7 @@ func (a mockCipher) Unmarshal(s string, i interface{}) error { return nil } func TestNewCookieStore(t *testing.T) { - cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey()) + cipher, err := cryptutil.NewCipher(cryptutil.NewKey()) if err != nil { t.Fatal(err) } @@ -111,7 +111,7 @@ func TestNewCookieStore(t *testing.T) { } func TestCookieStore_makeCookie(t *testing.T) { - cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey()) + cipher, err := cryptutil.NewCipher(cryptutil.NewKey()) if err != nil { t.Fatal(err) } @@ -155,62 +155,13 @@ func TestCookieStore_makeCookie(t *testing.T) { if diff := cmp.Diff(s.makeSessionCookie(r, tt.value, tt.expiration, now), tt.want); diff != "" { t.Errorf("CookieStore.makeSessionCookie() = \n%s", diff) } - got := s.makeCSRFCookie(r, tt.value, tt.expiration, now) - tt.wantCSRF.Name = "_pomerium_csrf" - if !reflect.DeepEqual(got, tt.wantCSRF) { - t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.wantCSRF) - } - w := httptest.NewRecorder() - want := "new-csrf" - s.SetCSRF(w, r, want) - found := false - for _, cookie := range w.Result().Cookies() { - if cookie.Name == s.Name+"_csrf" && cookie.Value == want { - found = true - break - } - } - if !found { - t.Error("SetCSRF failed") - } - - w = httptest.NewRecorder() - s.ClearCSRF(w, r) - for _, cookie := range w.Result().Cookies() { - if cookie.Name == s.Name+"_csrf" && cookie.Value == want { - t.Error("clear csrf failed") - break - - } - } - w = httptest.NewRecorder() - want = "new-session" - s.setSessionCookie(w, r, want) - found = false - for _, cookie := range w.Result().Cookies() { - if cookie.Name == s.Name && cookie.Value == want { - found = true - break - } - } - if !found { - t.Error("SetCSRF failed") - } - w = httptest.NewRecorder() - s.ClearSession(w, r) - for _, cookie := range w.Result().Cookies() { - if cookie.Name == s.Name && cookie.Value == want { - t.Error("clear csrf failed") - break - } - } }) } } func TestCookieStore_SaveSession(t *testing.T) { - cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey()) + cipher, err := cryptutil.NewCipher(cryptutil.NewKey()) if err != nil { t.Fatal(err) } @@ -265,38 +216,6 @@ func TestCookieStore_SaveSession(t *testing.T) { } } -func TestMockCSRFStore(t *testing.T) { - tests := []struct { - name string - mockCSRF *MockCSRFStore - newCSRFValue string - wantErr bool - }{ - {"basic", - &MockCSRFStore{ - ResponseCSRF: "ok", - Cookie: &http.Cookie{Name: "hi"}}, - "newcsrf", - false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ms := tt.mockCSRF - ms.SetCSRF(nil, nil, tt.newCSRFValue) - ms.ClearCSRF(nil, nil) - got, err := ms.GetCSRF(nil) - if (err != nil) != tt.wantErr { - t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.mockCSRF.Cookie) { - t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Cookie) - } - - }) - } -} - func TestMockSessionStore(t *testing.T) { tests := []struct { name string @@ -341,7 +260,7 @@ func TestMockSessionStore(t *testing.T) { } } -func Test_splitDomain(t *testing.T) { +func Test_ParentSubdomain(t *testing.T) { t.Parallel() tests := []struct { s string @@ -354,8 +273,8 @@ func Test_splitDomain(t *testing.T) { } for _, tt := range tests { t.Run(tt.s, func(t *testing.T) { - if got := splitDomain(tt.s); got != tt.want { - t.Errorf("splitDomain() = %v, want %v", got, tt.want) + if got := ParentSubdomain(tt.s); got != tt.want { + t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want) } }) } diff --git a/internal/sessions/middleware.go b/internal/sessions/middleware.go new file mode 100644 index 000000000..d26b399ff --- /dev/null +++ b/internal/sessions/middleware.go @@ -0,0 +1,130 @@ +package sessions // import "github.com/pomerium/pomerium/internal/sessions" + +import ( + "context" + "errors" + "net/http" + "strings" +) + +// Context keys +var ( + SessionCtxKey = &contextKey{"Session"} + ErrorCtxKey = &contextKey{"Error"} +) + +// Library errors +var ( + ErrExpired = errors.New("internal/sessions: session is expired") + ErrNoSessionFound = errors.New("internal/sessions: session is not found") + ErrMalformed = errors.New("internal/sessions: session is malformed") +) + +// RetrieveSession http middleware handler will verify a auth session from a http request. +// +// RetrieveSession will search for a auth session in a http request, in the order: +// 1. `pomerium_session` URI query parameter +// 2. `Authorization: BEARER` request header +// 3. Cookie `_pomerium` value +func RetrieveSession(s SessionStore) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return retrieve(s, TokenFromQuery, TokenFromHeader, TokenFromCookie)(next) + } +} + +func retrieve(s SessionStore, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + hfn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + token, err := retrieveFromRequest(s, r, findTokenFns...) + ctx = NewContext(ctx, token, err) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(hfn) + } +} + +func retrieveFromRequest(s SessionStore, r *http.Request, findTokenFns ...func(r *http.Request) string) (*State, error) { + var tokenStr string + var err error + + // Extract token string from the request by calling token find functions in + // the order they where provided. Further extraction stops if a function + // returns a non-empty string. + for _, fn := range findTokenFns { + tokenStr = fn(r) + if tokenStr != "" { + break + } + } + if tokenStr == "" { + return nil, ErrNoSessionFound + } + + state, err := s.LoadSession(r) + if err != nil { + return nil, ErrMalformed + } + err = state.Valid() + if err != nil { + // a little unusual but we want to return the expired state too + return state, err + } + + // Valid! + return state, nil +} + +// NewContext sets context values for the user session state and error. +func NewContext(ctx context.Context, t *State, err error) context.Context { + ctx = context.WithValue(ctx, SessionCtxKey, t) + ctx = context.WithValue(ctx, ErrorCtxKey, err) + return ctx +} + +// FromContext retrieves context values for the user session state and error. +func FromContext(ctx context.Context) (*State, error) { + state, _ := ctx.Value(SessionCtxKey).(*State) + err, _ := ctx.Value(ErrorCtxKey).(error) + return state, err +} + +// TokenFromCookie tries to retrieve the token string from a cookie named +// "_pomerium". +func TokenFromCookie(r *http.Request) string { + cookie, err := r.Cookie("_pomerium") + if err != nil { + return "" + } + return cookie.Value +} + +// TokenFromHeader tries to retrieve the token string from the +// "Authorization" request header: "Authorization: BEARER T". +func TokenFromHeader(r *http.Request) string { + // Get token from authorization header. + bearer := r.Header.Get("Authorization") + if len(bearer) > 7 && strings.EqualFold(bearer[0:6], "BEARER") { + return bearer[7:] + } + return "" +} + +// TokenFromQuery tries to retrieve the token string from the "pomerium_session" URI +// query parameter. +// todo(bdd) : document setting session code as queryparam +func TokenFromQuery(r *http.Request) string { + // Get token from query param named "pomerium_session". + return r.URL.Query().Get("pomerium_session") +} + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. This technique +// for defining context keys was copied from Go 1.7's new use of context in net/http. +type contextKey struct { + name string +} + +func (k *contextKey) String() string { + return "SessionStore context value " + k.name +} diff --git a/internal/sessions/middleware_test.go b/internal/sessions/middleware_test.go new file mode 100644 index 000000000..40be51824 --- /dev/null +++ b/internal/sessions/middleware_test.go @@ -0,0 +1,133 @@ +package sessions + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/pomerium/pomerium/internal/cryptutil" + + "github.com/google/go-cmp/cmp" +) + +func TestNewContext(t *testing.T) { + tests := []struct { + name string + ctx context.Context + t *State + err error + want context.Context + }{ + {"simple", context.Background(), &State{Email: "bdd@pomerium.io"}, nil, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxOut := NewContext(tt.ctx, tt.t, tt.err) + stateOut, errOut := FromContext(ctxOut) + if diff := cmp.Diff(tt.t, stateOut); diff != "" { + t.Errorf("NewContext() = %s", diff) + } + if diff := cmp.Diff(tt.err, errOut); diff != "" { + t.Errorf("NewContext() = %s", diff) + } + }) + } +} + +func testAuthorizer(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := FromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} + +func TestVerifier(t *testing.T) { + fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + fmt.Fprint(w, http.StatusText(http.StatusOK)) + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + // s SessionStore + state State + + cookie bool + header bool + param bool + + wantBody string + wantStatus int + }{ + {"good cookie session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, true, false, false, http.StatusText(http.StatusOK), http.StatusOK}, + {"expired cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is expired\n", http.StatusUnauthorized}, + {"malformed cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized}, + {"good auth header session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, false, http.StatusText(http.StatusOK), http.StatusOK}, + {"expired auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is expired\n", http.StatusUnauthorized}, + {"malformed auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized}, + {"good auth query param session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK}, + {"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is expired\n", http.StatusUnauthorized}, + {"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized}, + {"no session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cipher, err := cryptutil.NewCipherFromBase64(cryptutil.NewBase64Key()) + if err != nil { + t.Fatal(err) + } + encSession, err := MarshalSession(&tt.state, cipher) + if err != nil { + t.Fatal(err) + } + if strings.Contains(tt.name, "malformed") { + // add some garbage to the end of the string + encSession += cryptutil.NewBase64Key() + fmt.Println(encSession) + } + + cs, err := NewCookieStore(&CookieStoreOptions{ + Name: "_pomerium", + CookieCipher: cipher, + }) + if err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + if tt.cookie { + r.AddCookie(&http.Cookie{Name: "_pomerium", Value: encSession}) + } else if tt.header { + r.Header.Set("Authorization", "Bearer "+encSession) + } else if tt.param { + q := r.URL.Query() + q.Add("pomerium_session", encSession) + r.URL.RawQuery = q.Encode() + } + + got := RetrieveSession(cs)(testAuthorizer((fnh))) + got.ServeHTTP(w, r) + + gotBody := w.Body.String() + gotStatus := w.Result().StatusCode + + if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" { + t.Errorf("RetrieveSession() = %v", diff) + } + if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" { + t.Errorf("RetrieveSession() = %v", diff) + } + }) + } +} diff --git a/internal/sessions/mock_store.go b/internal/sessions/mock_store.go index bbed23b1f..228676ea8 100644 --- a/internal/sessions/mock_store.go +++ b/internal/sessions/mock_store.go @@ -4,28 +4,6 @@ import ( "net/http" ) -// MockCSRFStore is a mock implementation of the CSRF store interface -type MockCSRFStore struct { - ResponseCSRF string - Cookie *http.Cookie - GetError error -} - -// SetCSRF sets the ResponseCSRF string to a val -func (ms MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) { - ms.ResponseCSRF = val -} - -// ClearCSRF clears the ResponseCSRF string -func (ms MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) { - ms.ResponseCSRF = "" -} - -// GetCSRF returns the cookie and error -func (ms MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) { - return ms.Cookie, ms.GetError -} - // MockSessionStore is a mock implementation of the SessionStore interface type MockSessionStore struct { ResponseSession string diff --git a/internal/sessions/state.go b/internal/sessions/state.go index bc17c4c36..59df7e4e2 100644 --- a/internal/sessions/state.go +++ b/internal/sessions/state.go @@ -10,9 +10,6 @@ import ( "github.com/pomerium/pomerium/internal/cryptutil" ) -// ErrExpired is an error for a expired sessions. -var ErrExpired = fmt.Errorf("internal/sessions: expired session") - // State is our object that keeps track of a user's session state type State struct { AccessToken string `json:"access_token"` diff --git a/internal/sessions/state_test.go b/internal/sessions/state_test.go index eaca3ff46..b92e09c95 100644 --- a/internal/sessions/state_test.go +++ b/internal/sessions/state_test.go @@ -12,7 +12,7 @@ import ( ) func TestStateSerialization(t *testing.T) { - secret := cryptutil.GenerateKey() + secret := cryptutil.NewKey() c, err := cryptutil.NewCipher(secret) if err != nil { t.Fatalf("expected to be able to create cipher: %v", err) @@ -123,7 +123,7 @@ func TestState_Impersonating(t *testing.T) { } func TestMarshalSession(t *testing.T) { - secret := cryptutil.GenerateKey() + secret := cryptutil.NewKey() c, err := cryptutil.NewCipher(secret) if err != nil { t.Fatalf("expected to be able to create cipher: %v", err) diff --git a/internal/sessions/store.go b/internal/sessions/store.go index 9ba2bad8f..77e08af32 100644 --- a/internal/sessions/store.go +++ b/internal/sessions/store.go @@ -8,16 +8,6 @@ import ( // ErrEmptySession is an error for an empty sessions. var ErrEmptySession = errors.New("internal/sessions: empty session") -// ErrEmptyCSRF is an error for an empty sessions. -var ErrEmptyCSRF = errors.New("internal/sessions: empty csrf") - -// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie -type CSRFStore interface { - SetCSRF(http.ResponseWriter, *http.Request, string) - GetCSRF(*http.Request) (*http.Cookie, error) - ClearCSRF(http.ResponseWriter, *http.Request) -} - // SessionStore has the functions for setting, getting, and clearing the Session cookie type SessionStore interface { ClearSession(http.ResponseWriter, *http.Request) diff --git a/internal/templates/templates.go b/internal/templates/templates.go index f5989fb21..3df4a225f 100644 --- a/internal/templates/templates.go +++ b/internal/templates/templates.go @@ -306,7 +306,7 @@ func New() *template.Template {
-

Session

+

Current user

Your current session details.

- - Refresh + {{ .csrfField }} +
+
+

Refresh Identity

+

Pomerium will automatically refresh your user session. However, if your group memberships have recently changed and haven't taken effect yet, you can refresh your session manually.

+ +
+
+ {{ .csrfField }} + +
+
+
+ {{if .IsAdmin}}
@@ -355,7 +367,7 @@ func New() *template.Template {
- + {{ .csrfField }}
diff --git a/internal/urlutil/url.go b/internal/urlutil/url.go index 06db9b72e..571b43637 100644 --- a/internal/urlutil/url.go +++ b/internal/urlutil/url.go @@ -1,9 +1,14 @@ package urlutil // import "github.com/pomerium/pomerium/internal/urlutil" import ( + "encoding/base64" "fmt" + "net/http" "net/url" "strings" + "time" + + "github.com/pomerium/pomerium/internal/cryptutil" ) // StripPort returns a host, without any port number. @@ -32,18 +37,73 @@ func ParseAndValidateURL(rawurl string) (*url.URL, error) { if err != nil { return nil, err } - if u.Scheme == "" { - return nil, fmt.Errorf("%s url does contain a valid scheme. Did you mean https://%s?", rawurl, rawurl) - } - if u.Host == "" { - return nil, fmt.Errorf("%s url does contain a valid hostname", rawurl) + if err := ValidateURL(u); err != nil { + return nil, err } return u, nil } +// ValidateURL wraps standard library's default url.Parse because +// it's much more lenient about what type of urls it accepts than pomerium. +func ValidateURL(u *url.URL) error { + if u == nil { + return fmt.Errorf("nil url") + } + if u.Scheme == "" { + return fmt.Errorf("%s url does contain a valid scheme. Did you mean https://%s?", u.String(), u.String()) + } + if u.Host == "" { + return fmt.Errorf("%s url does contain a valid hostname", u.String()) + } + return nil +} + func DeepCopy(u *url.URL) (*url.URL, error) { if u == nil { return nil, nil } return ParseAndValidateURL(u.String()) } + +// testTimeNow can be used in tests to set a specific int64 time +var testTimeNow int64 + +// timestamp returns the current timestamp, in seconds. +// +// For testing purposes, the function that generates the timestamp can be +// overridden. If not set, it will return time.Now().UTC().Unix(). +func timestamp() int64 { + if testTimeNow == 0 { + return time.Now().UTC().Unix() + } + return testTimeNow +} + +// SignedRedirectURL takes a destination URL and adds redirect_uri to it's +// query params, along with a timestamp and an keyed signature. +func SignedRedirectURL(key string, destination, urlToSign *url.URL) *url.URL { + now := timestamp() + rawURL := urlToSign.String() + params, _ := url.ParseQuery(destination.RawQuery) // handled by incoming mux + params.Set("redirect_uri", rawURL) + params.Set("ts", fmt.Sprint(now)) + params.Set("sig", hmacURL(key, rawURL, now)) + destination.RawQuery = params.Encode() + return destination +} + +// hmacURL takes a redirect url string and timestamp and returns the base64 +// encoded HMAC result. +func hmacURL(key, data string, timestamp int64) string { + h := cryptutil.Hash(key, []byte(fmt.Sprint(data, timestamp))) + return base64.URLEncoding.EncodeToString(h) +} + +// GetAbsoluteURL returns the current handler's absolute url. +// https://stackoverflow.com/a/23152483 +func GetAbsoluteURL(r *http.Request) *url.URL { + u := r.URL + u.Scheme = "https" + u.Host = r.Host + return u +} diff --git a/internal/urlutil/url_test.go b/internal/urlutil/url_test.go index 3c9706b51..95f77d8a8 100644 --- a/internal/urlutil/url_test.go +++ b/internal/urlutil/url_test.go @@ -1,6 +1,7 @@ package urlutil import ( + "net/http" "net/url" "reflect" "testing" @@ -35,7 +36,7 @@ func Test_StripPort(t *testing.T) { } func TestParseAndValidateURL(t *testing.T) { - + t.Parallel() tests := []struct { name string rawurl string @@ -63,7 +64,7 @@ func TestParseAndValidateURL(t *testing.T) { } func TestDeepCopy(t *testing.T) { - + t.Parallel() tests := []struct { name string u *url.URL @@ -87,3 +88,90 @@ func TestDeepCopy(t *testing.T) { }) } } + +func TestValidateURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + u *url.URL + wantErr bool + }{ + {"good", &url.URL{Scheme: "https", Host: "some.example"}, false}, + {"nil", nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := ValidateURL(tt.u); (err != nil) != tt.wantErr { + t.Errorf("ValidateURL() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestSignedRedirectURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + mockedTime int64 + key string + destination *url.URL + urlToSign *url.URL + want *url.URL + }{ + {"good", 2, "hunter42", &url.URL{Host: "pomerium.io", Scheme: "https://"}, &url.URL{Host: "pomerium.io", Scheme: "https://", Path: "/ok"}, &url.URL{Host: "pomerium.io", Scheme: "https://", RawQuery: "redirect_uri=https%3A%2F%2F%3A%2F%2Fpomerium.io%2Fok&sig=7jdo1XFcmuhjBHnpfVhll5cXflYByeMnbp5kRz87CVQ%3D&ts=2"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testTimeNow = tt.mockedTime + got := SignedRedirectURL(tt.key, tt.destination, tt.urlToSign) + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("SignedRedirectURL() = diff %v", diff) + } + }) + } +} + +func Test_timestamp(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + dontWant int64 + }{ + {"if unset should never return", 0}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testTimeNow = tt.dontWant + if got := timestamp(); got == tt.dontWant { + t.Errorf("timestamp() = %v, dontWant %v", got, tt.dontWant) + } + }) + } +} + +func parseURLHelper(s string) *url.URL { + u, _ := url.Parse(s) + return u +} + +func TestGetAbsoluteURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + u *url.URL + want *url.URL + }{ + {"add https", parseURLHelper("http://pomerium.io"), parseURLHelper("https://pomerium.io")}, + {"missing scheme", parseURLHelper("https://pomerium.io"), parseURLHelper("https://pomerium.io")}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := http.Request{URL: tt.u, Host: tt.u.Host} + got := GetAbsoluteURL(&r) + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("GetAbsoluteURL() = %v", diff) + } + }) + } +} diff --git a/proxy/handlers.go b/proxy/handlers.go index 1108dfd97..c7d4df5ec 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -1,15 +1,15 @@ package proxy // import "github.com/pomerium/pomerium/proxy" import ( - "encoding/base64" "fmt" "net/http" "net/url" "strings" "time" + "github.com/pomerium/csrf" + "github.com/pomerium/pomerium/internal/config" - "github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/middleware" @@ -18,34 +18,55 @@ import ( "github.com/pomerium/pomerium/internal/urlutil" ) -// StateParameter holds the redirect id along with the session id. -type StateParameter struct { - SessionID string `json:"session_id"` - RedirectURI string `json:"redirect_uri"` -} - // Handler returns the proxy service's ServeMux func (p *Proxy) Handler() http.Handler { - // validation middleware chain - validate := middleware.NewChain() - validate = validate.Append(middleware.ValidateHost(func(host string) bool { + r := httputil.NewRouter().StrictSlash(true) + r.Use(middleware.ValidateHost(func(host string) bool { _, ok := p.routeConfigs[host] return ok })) - mux := http.NewServeMux() - mux.HandleFunc("/robots.txt", p.RobotsTxt) - mux.HandleFunc("/.pomerium", p.UserDashboard) - mux.HandleFunc("/.pomerium/impersonate", p.Impersonate) // POST - mux.HandleFunc("/.pomerium/sign_out", p.SignOut) - // handlers with validation - mux.Handle("/.pomerium/callback", validate.ThenFunc(p.AuthenticateCallback)) - mux.Handle("/.pomerium/refresh", validate.ThenFunc(p.ForceRefresh)) - mux.Handle("/", validate.ThenFunc(p.Proxy)) - return mux + r.Use(csrf.Protect( + p.cookieSecret, + csrf.Path("/"), + csrf.Domain(p.cookieDomain), + csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieName)), + csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)), + )) + r.HandleFunc("/robots.txt", p.RobotsTxt) + // requires authN not authZ + r.Use(sessions.RetrieveSession(p.sessionStore)) + r.Use(p.VerifySession) + r.HandleFunc("/.pomerium/", p.UserDashboard).Methods(http.MethodGet) + r.HandleFunc("/.pomerium/impersonate", p.Impersonate).Methods(http.MethodPost) + r.HandleFunc("/.pomerium/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost) + r.HandleFunc("/.pomerium/refresh", p.ForceRefresh).Methods(http.MethodPost) + r.PathPrefix("/").HandlerFunc(p.Proxy) + return r +} + +// VerifySession is the middleware used to enforce a valid authentication +// session state is attached to the users's request context. +func (p *Proxy) VerifySession(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + state, err := sessions.FromContext(r.Context()) + if err != nil { + log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to session state error") + p.authenticate(w, r) + return + } + if err := state.Valid(); err != nil { + log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to invalid session") + p.authenticate(w, r) + return + } + next.ServeHTTP(w, r) + }) } // RobotsTxt sets the User-Agent header in the response to be "Disallow" func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") w.WriteHeader(http.StatusOK) fmt.Fprintf(w, "User-agent: *\nDisallow: /") } @@ -55,110 +76,18 @@ func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) { // the local session state. func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) { redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"} - switch r.Method { - case http.MethodPost: - if err := r.ParseForm(); err != nil { - httputil.ErrorResponse(w, r, err) - return - } - uri, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri")) - if err == nil && uri.String() != "" { - redirectURL = uri - } - default: - uri, err := urlutil.ParseAndValidateURL(r.URL.Query().Get("redirect_uri")) - if err == nil && uri.String() != "" { - redirectURL = uri - } + if uri, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")); err == nil && uri.String() != "" { + redirectURL = uri } - http.Redirect(w, r, p.GetSignOutURL(p.authenticateURL, redirectURL).String(), http.StatusFound) + uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL) + http.Redirect(w, r, uri.String(), http.StatusFound) } -// OAuthStart begins the authenticate flow, encrypting the redirect url +// Authenticate begins the authenticate flow, encrypting the redirect url // in a request to the provider's sign in endpoint. -func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) { - state := &StateParameter{ - SessionID: fmt.Sprintf("%x", cryptutil.GenerateKey()), - RedirectURI: r.URL.String(), - } - - // Encrypt CSRF + redirect_uri and store in csrf session. Validated on callback. - csrfState, err := p.cipher.Marshal(state) - if err != nil { - httputil.ErrorResponse(w, r, err) - return - } - p.csrfStore.SetCSRF(w, r, csrfState) - - paramState, err := p.cipher.Marshal(state) - if err != nil { - httputil.ErrorResponse(w, r, err) - return - } - - // Sanity check. The encrypted payload of local and remote state should - // never match as each encryption round uses a cryptographic nonce. - // if paramState == csrfState { - // httputil.ErrorResponse(w, r, httputil.Error("encrypted state should not match", http.StatusBadRequest, nil)) - // return - // } - - signinURL := p.GetSignInURL(p.authenticateURL, p.GetRedirectURL(r.Host), paramState) - - // Redirect the user to the authenticate service along with the encrypted - // state which contains a redirect uri back to the proxy and a nonce - http.Redirect(w, r, signinURL.String(), http.StatusFound) -} - -// AuthenticateCallback checks the state parameter to make sure it matches the -// local csrf state then redirects the user back to the original intended route. -func (p *Proxy) AuthenticateCallback(w http.ResponseWriter, r *http.Request) { - if err := r.ParseForm(); err != nil { - httputil.ErrorResponse(w, r, err) - return - } - - // Encrypted CSRF passed from authenticate service - remoteStateEncrypted := r.Form.Get("state") - var remoteStatePlain StateParameter - if err := p.cipher.Unmarshal(remoteStateEncrypted, &remoteStatePlain); err != nil { - httputil.ErrorResponse(w, r, err) - return - } - - c, err := p.csrfStore.GetCSRF(r) - if err != nil { - httputil.ErrorResponse(w, r, err) - return - } - p.csrfStore.ClearCSRF(w, r) - - localStateEncrypted := c.Value - var localStatePlain StateParameter - err = p.cipher.Unmarshal(localStateEncrypted, &localStatePlain) - if err != nil { - httputil.ErrorResponse(w, r, err) - return - } - - // assert no nonce reuse - if remoteStateEncrypted == localStateEncrypted { - p.sessionStore.ClearSession(w, r) - httputil.ErrorResponse(w, r, - httputil.Error("local and remote state", http.StatusBadRequest, - fmt.Errorf("possible nonce-reuse / replay attack"))) - return - } - - // Decrypted remote and local state struct (inc. nonce) must match - if remoteStatePlain.SessionID != localStatePlain.SessionID { - p.sessionStore.ClearSession(w, r) - httputil.ErrorResponse(w, r, httputil.Error("CSRF mismatch", http.StatusBadRequest, nil)) - return - } - - // This is the redirect back to the original requested application - http.Redirect(w, r, remoteStatePlain.RedirectURI, http.StatusFound) +func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request) { + uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r)) + http.Redirect(w, r, uri.String(), http.StatusFound) } // shouldSkipAuthentication contains conditions for skipping authentication. @@ -189,17 +118,6 @@ func isCORSPreflight(r *http.Request) bool { r.Header.Get("Origin") != "" } -func (p *Proxy) loadExistingSession(r *http.Request) (*sessions.State, error) { - s, err := p.sessionStore.LoadSession(r) - if err != nil { - return nil, fmt.Errorf("proxy: invalid session: %w", err) - } - if err := s.Valid(); err != nil { - return nil, fmt.Errorf("proxy: invalid state: %w", err) - } - return s, nil -} - // Proxy authenticates a request, either proxying the request if it is authenticated, // or starting the authenticate service for validation if not. func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { @@ -214,11 +132,10 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { route.ServeHTTP(w, r) return } - - s, err := p.loadExistingSession(r) - if err != nil { - log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting") - p.OAuthStart(w, r) + s, err := sessions.FromContext(r.Context()) + if err != nil || s == nil { + log.Debug().Err(err).Msg("proxy: couldn't get session from context") + p.authenticate(w, r) return } authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s) @@ -226,7 +143,7 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { httputil.ErrorResponse(w, r, err) return } else if !authorized { - httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not authorized for this route", s.Email), http.StatusForbidden, nil)) + httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not authorized for this route", s.RequestEmail()), http.StatusForbidden, nil)) return } r.Header.Set(HeaderUserID, s.User) @@ -240,62 +157,41 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { // It also contains certain administrative actions like user impersonation. // Nota bene: This endpoint does authentication, not authorization. func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) { - session, err := p.loadExistingSession(r) + session, err := sessions.FromContext(r.Context()) if err != nil { - log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting") - p.OAuthStart(w, r) + httputil.ErrorResponse(w, r, err) return } - redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/.pomerium/sign_out"} isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session) if err != nil { httputil.ErrorResponse(w, r, err) return } - - // CSRF value used to mitigate replay attacks. - csrf := &StateParameter{SessionID: fmt.Sprintf("%x", cryptutil.GenerateKey())} - csrfCookie, err := p.cipher.Marshal(csrf) - if err != nil { - httputil.ErrorResponse(w, r, err) - return - } - p.csrfStore.SetCSRF(w, r, csrfCookie) - - t := struct { - Email string - User string - Groups []string - RefreshDeadline string - SignoutURL string - - IsAdmin bool - ImpersonateEmail string - ImpersonateGroup string - CSRF string - }{ - Email: session.Email, - User: session.User, - Groups: session.Groups, - RefreshDeadline: time.Until(session.RefreshDeadline).Round(time.Second).String(), - SignoutURL: p.GetSignOutURL(p.authenticateURL, redirectURL).String(), - IsAdmin: isAdmin, - ImpersonateEmail: session.ImpersonateEmail, - ImpersonateGroup: strings.Join(session.ImpersonateGroups, ","), - CSRF: csrf.SessionID, - } - templates.New().ExecuteTemplate(w, "dashboard.html", t) + //todo(bdd): make sign out redirect a configuration option so that + // admins can set to whatever their corporate homepage is + redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"} + signoutURL := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL) + templates.New().ExecuteTemplate(w, "dashboard.html", map[string]interface{}{ + "Email": session.Email, + "User": session.User, + "Groups": session.Groups, + "RefreshDeadline": time.Until(session.RefreshDeadline).Round(time.Second).String(), + "SignoutURL": signoutURL.String(), + "IsAdmin": isAdmin, + "ImpersonateEmail": session.ImpersonateEmail, + "ImpersonateGroup": strings.Join(session.ImpersonateGroups, ","), + "csrfField": csrf.TemplateField(r), + }) } // ForceRefresh redeems and extends an existing authenticated oidc session with // the underlying identity provider. All session details including groups, // timeouts, will be renewed. func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) { - session, err := p.loadExistingSession(r) + session, err := sessions.FromContext(r.Context()) if err != nil { - log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting") - p.OAuthStart(w, r) + httputil.ErrorResponse(w, r, err) return } iss, err := session.IssuedAt() @@ -324,49 +220,25 @@ func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) { // to the user's current user sessions state if the user is currently an // administrative user. Requests are redirected back to the user dashboard. func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodPost { - if err := r.ParseForm(); err != nil { - httputil.ErrorResponse(w, r, err) - return - } - session, err := p.loadExistingSession(r) - if err != nil { - log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting") - p.OAuthStart(w, r) - return - } - isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session) - if err != nil || !isAdmin { - httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not an administrator", session.Email), http.StatusForbidden, err)) - return - } - // CSRF check -- did this request originate from our form? - c, err := p.csrfStore.GetCSRF(r) - if err != nil { - httputil.ErrorResponse(w, r, err) - return - } - p.csrfStore.ClearCSRF(w, r) - encryptedCSRF := c.Value - var decryptedCSRF StateParameter - if err = p.cipher.Unmarshal(encryptedCSRF, decryptedCSRF); err != nil { - httputil.ErrorResponse(w, r, err) - return - } - if decryptedCSRF.SessionID != r.FormValue("csrf") { - httputil.ErrorResponse(w, r, httputil.Error("CSRF mismatch", http.StatusBadRequest, nil)) - return - } - - // OK to impersonation - session.ImpersonateEmail = r.FormValue("email") - session.ImpersonateGroups = strings.Split(r.FormValue("group"), ",") - - if err := p.sessionStore.SaveSession(w, r, session); err != nil { - httputil.ErrorResponse(w, r, err) - return - } + session, err := sessions.FromContext(r.Context()) + if err != nil { + httputil.ErrorResponse(w, r, err) + return } + isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session) + if err != nil || !isAdmin { + httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not an administrator", session.RequestEmail()), http.StatusForbidden, err)) + return + } + // OK to impersonation + session.ImpersonateEmail = r.FormValue("email") + session.ImpersonateGroups = strings.Split(r.FormValue("group"), ",") + + if err := p.sessionStore.SaveSession(w, r, session); err != nil { + httputil.ErrorResponse(w, r, err) + return + } + http.Redirect(w, r, "/.pomerium", http.StatusFound) } @@ -391,48 +263,3 @@ func (p *Proxy) policy(r *http.Request) (*config.Policy, bool) { } return nil, false } - -// GetRedirectURL returns the redirect url for a single reverse proxy host. HTTPS is set explicitly. -func (p *Proxy) GetRedirectURL(host string) *url.URL { - u := p.redirectURL - u.Scheme = "https" - u.Host = host - return u -} - -// signRedirectURL takes a redirect url string and timestamp and returns the base64 -// encoded HMAC result. -func (p *Proxy) signRedirectURL(rawRedirect string, timestamp time.Time) string { - data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix())) - h := cryptutil.Hash(p.SharedKey, data) - return base64.URLEncoding.EncodeToString(h) -} - -// GetSignInURL with typical oauth parameters -func (p *Proxy) GetSignInURL(authenticateURL, redirectURL *url.URL, state string) *url.URL { - a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_in"}) - now := time.Now() - rawRedirect := redirectURL.String() - params, _ := url.ParseQuery(a.RawQuery) // handled by ServeMux - params.Set("redirect_uri", rawRedirect) - params.Set("shared_secret", p.SharedKey) - params.Set("response_type", "code") - params.Add("state", state) - params.Set("ts", fmt.Sprint(now.Unix())) - params.Set("sig", p.signRedirectURL(rawRedirect, now)) - a.RawQuery = params.Encode() - return a -} - -// GetSignOutURL creates and returns the sign out URL, given a redirectURL -func (p *Proxy) GetSignOutURL(authenticateURL, redirectURL *url.URL) *url.URL { - a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_out"}) - now := time.Now() - rawRedirect := redirectURL.String() - params, _ := url.ParseQuery(a.RawQuery) // handled by ServeMux - params.Add("redirect_uri", rawRedirect) - params.Set("ts", fmt.Sprint(now.Unix())) - params.Set("sig", p.signRedirectURL(rawRedirect, now)) - a.RawQuery = params.Encode() - return a -} diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index 7bc4d1154..8d59b8cbc 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -7,45 +7,17 @@ import ( "net/http" "net/http/httptest" "net/url" - "reflect" "strings" "testing" "time" "github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/cryptutil" + "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/proxy/clients" ) -type mockCipher struct{} - -func (a mockCipher) Encrypt(s []byte) ([]byte, error) { - if string(s) == "error" { - return []byte(""), errors.New("error encrypting") - } - return []byte("OK"), nil -} - -func (a mockCipher) Decrypt(s []byte) ([]byte, error) { - if string(s) == "error" { - return []byte(""), errors.New("error encrypting") - } - return []byte("OK"), nil -} -func (a mockCipher) Marshal(s interface{}) (string, error) { - if s == "error" { - return "", errors.New("error") - } - return "ok", nil -} -func (a mockCipher) Unmarshal(s string, i interface{}) error { - if s == "unmarshal error" || s == "error" { - return errors.New("error") - } - return nil -} - func TestProxy_RobotsTxt(t *testing.T) { proxy := Proxy{} req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil) @@ -60,94 +32,6 @@ func TestProxy_RobotsTxt(t *testing.T) { } } -func TestProxy_GetRedirectURL(t *testing.T) { - tests := []struct { - name string - host string - want *url.URL - }{ - {"google", "google.com", &url.URL{Scheme: "https", Host: "google.com", Path: "/.pomerium/callback"}}, - {"pomerium", "pomerium.io", &url.URL{Scheme: "https", Host: "pomerium.io", Path: "/.pomerium/callback"}}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p := &Proxy{redirectURL: &url.URL{Path: "/.pomerium/callback"}} - if got := p.GetRedirectURL(tt.host); !reflect.DeepEqual(got, tt.want) { - t.Errorf("Proxy.GetRedirectURL() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestProxy_signRedirectURL(t *testing.T) { - tests := []struct { - name string - rawRedirect string - timestamp time.Time - want string - }{ - {"pomerium", "https://pomerium.io/.pomerium/callback", fixedDate, "wq3rAjRGN96RXS8TAzH-uxQTD0XgY_8ZYEKMiOLD5P4="}, - {"google", "https://google.com/.pomerium/callback", fixedDate, "7EYHZObq167CuyuPm5CqOtkU4zg5dFeUCs7W7QOrgNQ="}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p := &Proxy{} - if got := p.signRedirectURL(tt.rawRedirect, tt.timestamp); got != tt.want { - t.Errorf("Proxy.signRedirectURL() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestProxy_GetSignOutURL(t *testing.T) { - tests := []struct { - name string - authenticate string - redirect string - wantPrefix string - }{ - {"good", "https://auth.corp.pomerium.io", "https://hello.corp.pomerium.io", "https://auth.corp.pomerium.io/sign_out?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - authenticateURL, _ := url.Parse(tt.authenticate) - redirectURL, _ := url.Parse(tt.redirect) - - p := &Proxy{} - // signature is ignored as it is tested above. Avoids testing time.Now - if got := p.GetSignOutURL(authenticateURL, redirectURL); !strings.HasPrefix(got.String(), tt.wantPrefix) { - t.Errorf("Proxy.GetSignOutURL() = %v, wantPrefix %v", got.String(), tt.wantPrefix) - } - }) - } -} - -func TestProxy_GetSignInURL(t *testing.T) { - - tests := []struct { - name string - authenticate string - redirect string - state string - - wantPrefix string - }{ - {"good", "https://auth.corp.pomerium.io", "https://hello.corp.pomerium.io", "example_state", "https://auth.corp.pomerium.io/sign_in?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io&response_type=code&shared_secret=shared-secret"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p := &Proxy{SharedKey: "shared-secret"} - authenticateURL, _ := url.Parse(tt.authenticate) - redirectURL, _ := url.Parse(tt.redirect) - - if got := p.GetSignInURL(authenticateURL, redirectURL, tt.state); !strings.HasPrefix(got.String(), tt.wantPrefix) { - t.Errorf("Proxy.GetSignOutURL() = %v, wantPrefix %v", got.String(), tt.wantPrefix) - } - - }) - } -} - func TestProxy_Signout(t *testing.T) { opts := testOptions(t) err := ValidateOptions(opts) @@ -171,7 +55,7 @@ func TestProxy_Signout(t *testing.T) { } } -func TestProxy_OAuthStart(t *testing.T) { +func TestProxy_authenticate(t *testing.T) { proxy, err := New(testOptions(t)) if err != nil { t.Fatal(err) @@ -179,18 +63,19 @@ func TestProxy_OAuthStart(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/oauth-start", nil) rr := httptest.NewRecorder() - proxy.OAuthStart(rr, req) + proxy.authenticate(rr, req) // expect oauth redirect if status := rr.Code; status != http.StatusFound { t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound) } // expected url - expected := `