From 805f0198d2e763a8074ca3b6df14bd432fd941e4 Mon Sep 17 00:00:00 2001 From: Bobby DeSimone Date: Thu, 14 Feb 2019 00:01:50 -0800 Subject: [PATCH] authenticate: add tests, fix signout (#45) - authenticate: a bug where sign out failed to revoke the remote session - docs: add code coverage to readme - authenticate: Rename shorthand receiver variable name - authenticate: consolidate sign in --- README.md | 2 +- authenticate/handlers.go | 252 +++---- authenticate/handlers_test.go | 871 ++++++++++++++++++++++ authenticate/providers/mock_provider.go | 41 + internal/templates/templates.go | 3 - proxy/authenticator/authenticator_test.go | 21 + proxy/authenticator/grpc_test.go | 4 +- proxy/handlers.go | 24 +- proxy/handlers_test.go | 6 - 9 files changed, 1061 insertions(+), 163 deletions(-) create mode 100644 authenticate/providers/mock_provider.go diff --git a/README.md b/README.md index 98dc015b4..4ada87078 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ # Pomerium -[![Travis CI](https://travis-ci.org/pomerium/pomerium.svg?branch=master)](https://travis-ci.org/pomerium/pomerium) [![Go Report Card](https://goreportcard.com/badge/github.com/pomerium/pomerium)](https://goreportcard.com/report/github.com/pomerium/pomerium) [![GoDoc](https://godoc.org/github.com/pomerium/pomerium?status.svg)][godocs] [![LICENSE](https://img.shields.io/github/license/pomerium/pomerium.svg)](https://github.com/pomerium/pomerium/blob/master/LICENSE) +[![Travis CI](https://travis-ci.org/pomerium/pomerium.svg?branch=master)](https://travis-ci.org/pomerium/pomerium) [![Go Report Card](https://goreportcard.com/badge/github.com/pomerium/pomerium)](https://goreportcard.com/report/github.com/pomerium/pomerium) [![GoDoc](https://godoc.org/github.com/pomerium/pomerium?status.svg)][godocs] [![LICENSE](https://img.shields.io/github/license/pomerium/pomerium.svg)](https://github.com/pomerium/pomerium/blob/master/LICENSE)[![codecov](https://codecov.io/gh/pomerium/pomerium/branch/master/graph/badge.svg)](https://codecov.io/gh/pomerium/pomerium) Pomerium is a tool for managing secure access to internal applications and resources. diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 9644c1970..9147f350b 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -29,7 +29,7 @@ var securityHeaders = map[string]string{ } // Handler returns the Http.Handlers for authenticate, callback, and refresh -func (p *Authenticate) Handler() http.Handler { +func (a *Authenticate) Handler() http.Handler { // set up our standard middlewares stdMiddleware := middleware.NewChain() stdMiddleware = stdMiddleware.Append(middleware.Healthcheck("/ping", version.UserAgent())) @@ -51,100 +51,101 @@ func (p *Authenticate) Handler() http.Handler { stdMiddleware = stdMiddleware.Append(middleware.RefererHandler("referer")) stdMiddleware = stdMiddleware.Append(middleware.RequestIDHandler("req_id", "Request-Id")) validateSignatureMiddleware := stdMiddleware.Append( - middleware.ValidateSignature(p.SharedKey), - middleware.ValidateRedirectURI(p.ProxyRootDomains)) + middleware.ValidateSignature(a.SharedKey), + middleware.ValidateRedirectURI(a.ProxyRootDomains)) mux := http.NewServeMux() - mux.Handle("/robots.txt", stdMiddleware.ThenFunc(p.RobotsTxt)) + mux.Handle("/robots.txt", stdMiddleware.ThenFunc(a.RobotsTxt)) // Identity Provider (IdP) callback endpoints and callbacks - mux.Handle("/start", stdMiddleware.ThenFunc(p.OAuthStart)) - mux.Handle("/oauth2/callback", stdMiddleware.ThenFunc(p.OAuthCallback)) + mux.Handle("/start", stdMiddleware.ThenFunc(a.OAuthStart)) + mux.Handle("/oauth2/callback", stdMiddleware.ThenFunc(a.OAuthCallback)) // authenticate-server endpoints - mux.Handle("/sign_in", validateSignatureMiddleware.ThenFunc(p.SignIn)) - mux.Handle("/sign_out", validateSignatureMiddleware.ThenFunc(p.SignOut)) // GET POST + mux.Handle("/sign_in", validateSignatureMiddleware.ThenFunc(a.SignIn)) + mux.Handle("/sign_out", validateSignatureMiddleware.ThenFunc(a.SignOut)) // GET POST return mux } // RobotsTxt handles the /robots.txt route. -func (p *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) { +func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) fmt.Fprintf(w, "User-agent: *\nDisallow: /") } -func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*sessions.SessionState, error) { - session, err := p.sessionStore.LoadSession(r) +func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*sessions.SessionState, error) { + session, err := a.sessionStore.LoadSession(r) if err != nil { log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to load session") - p.sessionStore.ClearSession(w, r) + a.sessionStore.ClearSession(w, r) return nil, err } // if long-lived lifetime has expired, clear session if session.LifetimePeriodExpired() { log.FromRequest(r).Warn().Msg("authenticate: lifetime expired") - p.sessionStore.ClearSession(w, r) + a.sessionStore.ClearSession(w, r) return nil, sessions.ErrLifetimeExpired } // check if session refresh period is up if session.RefreshPeriodExpired() { - newToken, err := p.provider.Refresh(session.RefreshToken) + newToken, err := a.provider.Refresh(session.RefreshToken) if err != nil { log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to refresh session") - p.sessionStore.ClearSession(w, r) + a.sessionStore.ClearSession(w, r) return nil, err } session.AccessToken = newToken.AccessToken session.RefreshDeadline = newToken.Expiry - err = p.sessionStore.SaveSession(w, r, session) + err = a.sessionStore.SaveSession(w, r, session) if err != nil { // We refreshed the session successfully, but failed to save it. // This could be from failing to encode the session properly. // But, we clear the session cookie and reject the request log.FromRequest(r).Error().Err(err).Msg("could not save refreshed session") - p.sessionStore.ClearSession(w, r) + a.sessionStore.ClearSession(w, r) return nil, err } } else { // The session has not exceeded it's lifetime or requires refresh - ok, err := p.provider.Validate(session.IDToken) + ok, err := a.provider.Validate(session.IDToken) if !ok || err != nil { log.FromRequest(r).Error().Err(err).Msg("invalid session state") - p.sessionStore.ClearSession(w, r) + a.sessionStore.ClearSession(w, r) return nil, httputil.ErrUserNotAuthorized } - err = p.sessionStore.SaveSession(w, r, session) + err = a.sessionStore.SaveSession(w, r, session) if err != nil { log.FromRequest(r).Error().Err(err).Msg("failed to save valid session") - p.sessionStore.ClearSession(w, r) + a.sessionStore.ClearSession(w, r) return nil, err } } // authenticate really should not be in the business of authorization // todo(bdd) : remove when authorization module added - if !p.Validator(session.Email) { + if !a.Validator(session.Email) { log.FromRequest(r).Error().Msg("invalid email user") return nil, httputil.ErrUserNotAuthorized } - log.Info().Msg("authenticate") return session, nil } // SignIn handles the /sign_in endpoint. It attempts to authenticate the user, // and if the user is not authenticated, it renders a sign in page. -func (p *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) { - session, err := p.authenticate(w, r) +func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) { + session, err := a.authenticate(w, r) switch err { case nil: - // User is authenticated, redirect back to proxy - p.ProxyOAuthRedirect(w, r, session) + // session good, redirect back to proxy + log.FromRequest(r).Info().Msg("authenticate.SignIn : authenticated") + a.ProxyCallback(w, r, session) case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession: - log.Info().Err(err).Msg("authenticate.SignIn : expected failure") + // session invalid, authenticate + log.FromRequest(r).Info().Err(err).Msg("authenticate.SignIn : expected failure") if err != http.ErrNoCookie { - p.sessionStore.ClearSession(w, r) + a.sessionStore.ClearSession(w, r) } - p.OAuthStart(w, r) + a.OAuthStart(w, r) default: log.Error().Err(err).Msg("authenticate: unexpected sign in error") @@ -152,10 +153,10 @@ func (p *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) { } } -// ProxyOAuthRedirect redirects the user back to proxy's redirection endpoint. -// This workflow corresponds to Section 3.1.2 of the OAuth2 RFC. -// See https://tools.ietf.org/html/rfc6749#section-3.1.2 for more specific information. -func (p *Authenticate) ProxyOAuthRedirect(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) { +// ProxyCallback redirects the user back to proxy service along with an encrypted payload, as +// url params, of the user's session state. +// See RFC6749 3.1.2 https://tools.ietf.org/html/rfc6749#section-3.1.2 +func (a *Authenticate) ProxyCallback(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) { err := r.ParseForm() if err != nil { httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) @@ -180,7 +181,7 @@ func (p *Authenticate) ProxyOAuthRedirect(w http.ResponseWriter, r *http.Request return } // encrypt session state as json blob - encrypted, err := sessions.MarshalSession(session, p.cipher) + encrypted, err := sessions.MarshalSession(session, a.cipher) if err != nil { httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) return @@ -193,111 +194,89 @@ func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string params, _ := url.ParseQuery(u.RawQuery) params.Set("code", authCode) params.Set("state", state) - u.RawQuery = params.Encode() - if u.Scheme == "" { u.Scheme = "https" } - return u.String() } -// SignOut signs the user out. -func (p *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { +// SignOut signs the user out by trying to revoke the users remote identity provider session +// then removes the associated local session state. +// Handles both GET and POST of form. +func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + if err != nil { + httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) + return + } + // pretty safe to say that no matter what heppanes here, we want to revoke the local session redirectURI := r.Form.Get("redirect_uri") + session, err := a.sessionStore.LoadSession(r) + if err != nil { + log.Error().Err(err).Msg("authenticate: signout failed to load session") + httputil.ErrorResponse(w, r, "No session found to log out", http.StatusBadRequest) + return + } if r.Method == "GET" { - p.SignOutPage(w, r, "") + signature := r.Form.Get("sig") + timestamp := r.Form.Get("ts") + destinationURL, err := url.Parse(redirectURI) + if err != nil { + log.Error().Err(err).Msg("authenticate: malformed destination url") + httputil.ErrorResponse(w, r, "Malformed destination URL", http.StatusBadRequest) + return + } + t := struct { + Redirect string + Signature string + Timestamp string + Destination string + Email string + Version string + }{ + Redirect: redirectURI, + Signature: signature, + Timestamp: timestamp, + Destination: destinationURL.Host, + Email: session.Email, + Version: version.FullVersion(), + } + a.templates.ExecuteTemplate(w, "sign_out.html", t) + w.WriteHeader(http.StatusOK) return } - - session, err := p.sessionStore.LoadSession(r) - switch err { - case nil: - break - case http.ErrNoCookie: // if there's no cookie in the session we can just redirect - log.Error().Err(err).Msg("authenticate.SignOut : no cookie") - http.Redirect(w, r, redirectURI, http.StatusFound) - return - default: - // a different error, clear the session cookie and redirect - log.Error().Err(err).Msg("authenticate.SignOut : error loading cookie session") - p.sessionStore.ClearSession(w, r) - http.Redirect(w, r, redirectURI, http.StatusFound) - return - } - - err = p.provider.Revoke(session.AccessToken) + a.sessionStore.ClearSession(w, r) + err = a.provider.Revoke(session.AccessToken) if err != nil { - log.Error().Err(err).Msg("authenticate.SignOut : error revoking session") - p.SignOutPage(w, r, "An error occurred during sign out. Please try again.") + log.Error().Err(err).Msg("authenticate: failed to revoke user session") + httputil.ErrorResponse(w, r, fmt.Sprintf("could not revoke session: %s ", err.Error()), http.StatusBadRequest) return } - p.sessionStore.ClearSession(w, r) http.Redirect(w, r, redirectURI, http.StatusFound) } -// SignOutPage renders a sign out page with a message -func (p *Authenticate) SignOutPage(w http.ResponseWriter, r *http.Request, message string) { - // validateRedirectURI middleware already ensures that this is a valid URL - redirectURI := r.Form.Get("redirect_uri") - session, err := p.sessionStore.LoadSession(r) - if err != nil { - http.Redirect(w, r, redirectURI, http.StatusFound) - return - } - - signature := r.Form.Get("sig") - timestamp := r.Form.Get("ts") - destinationURL, err := url.Parse(redirectURI) - - // An error message indicates that an internal server error occurred - if message != "" || err != nil { - log.Error().Err(err).Msg("authenticate.SignOutPage") - w.WriteHeader(http.StatusInternalServerError) - } - - t := struct { - Redirect string - Signature string - Timestamp string - Message string - Destination string - Email string - Version string - }{ - Redirect: redirectURI, - Signature: signature, - Timestamp: timestamp, - Message: message, - Destination: destinationURL.Host, - Email: session.Email, - Version: version.FullVersion(), - } - p.templates.ExecuteTemplate(w, "sign_out.html", t) -} - // OAuthStart starts the authenticate process by redirecting to the provider. It provides a // `redirectURI`, allowing the provider to redirect back to the sso proxy after authenticate. -func (p *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) { +func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) { authRedirectURL, err := url.Parse(r.URL.Query().Get("redirect_uri")) if err != nil { httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest) return } - authRedirectURL = p.RedirectURL.ResolveReference(r.URL) + authRedirectURL = a.RedirectURL.ResolveReference(r.URL) nonce := fmt.Sprintf("%x", cryptutil.GenerateKey()) - p.csrfStore.SetCSRF(w, r, nonce) + a.csrfStore.SetCSRF(w, r, nonce) // verify redirect uri is from the root domain - if !middleware.ValidRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) { + if !middleware.ValidRedirectURI(authRedirectURL.String(), a.ProxyRootDomains) { httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest) return } // verify proxy url is from the root domain proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri")) - if err != nil || !middleware.ValidRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) { + if err != nil || !middleware.ValidRedirectURI(proxyRedirectURL.String(), a.ProxyRootDomains) { httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest) return } @@ -305,7 +284,7 @@ func (p *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) { // get the signature and timestamp values then compare hmac proxyRedirectSig := authRedirectURL.Query().Get("sig") ts := authRedirectURL.Query().Get("ts") - if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, p.SharedKey) { + if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, a.SharedKey) { httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest) return } @@ -313,16 +292,35 @@ func (p *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) { // concat base64'd nonce and authenticate url to make state state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String()))) // build the provider sign in url - signInURL := p.provider.GetSignInURL(state) + signInURL := a.provider.GetSignInURL(state) http.Redirect(w, r, signInURL, http.StatusFound) } +// OAuthCallback handles the callback from the identity provider. Displays an error page if there +// was an error. If successful, redirects back to the proxy-service via the redirect-url. +func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) { + redirect, err := a.getOAuthCallback(w, r) + switch h := err.(type) { + case nil: + break + case httputil.HTTPError: + log.Error().Err(err).Msg("authenticate: oauth callback error") + httputil.ErrorResponse(w, r, h.Message, h.Code) + return + default: + log.Error().Err(err).Msg("authenticate: unexpected oauth callback error") + httputil.ErrorResponse(w, r, "Internal Error", http.StatusInternalServerError) + return + } + // redirect back to the proxy-service + http.Redirect(w, r, redirect, http.StatusFound) +} + // getOAuthCallback completes the oauth cycle from an identity provider's callback -func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) { +func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) { err := r.ParseForm() if err != nil { - log.FromRequest(r).Error().Err(err).Msg("authenticate: bad form on oauth callback") return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()} } errorString := r.Form.Get("error") @@ -336,7 +334,7 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Missing Code"} } - session, err := p.provider.Authenticate(code) + session, err := a.provider.Authenticate(code) if err != nil { log.FromRequest(r).Error().Err(err).Msg("authenticate: error redeeming authenticate code") return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()} @@ -353,50 +351,30 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) } nonce := s[0] redirect := s[1] - c, err := p.csrfStore.GetCSRF(r) + c, err := a.csrfStore.GetCSRF(r) if err != nil { log.FromRequest(r).Error().Err(err).Msg("authenticate: bad csrf") return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Missing CSRF token"} } - p.csrfStore.ClearCSRF(w, r) + a.csrfStore.ClearCSRF(w, r) if c.Value != nonce { log.FromRequest(r).Error().Err(err).Msg("authenticate: csrf mismatch") return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "CSRF failed"} } - if !middleware.ValidRedirectURI(redirect, p.ProxyRootDomains) { + if !middleware.ValidRedirectURI(redirect, a.ProxyRootDomains) { return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Redirect URI"} } // Set cookie, or deny: validates the session email and group - if !p.Validator(session.Email) { + if !a.Validator(session.Email) { log.FromRequest(r).Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied") return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "You don't have access"} } - err = p.sessionStore.SaveSession(w, r, session) + err = a.sessionStore.SaveSession(w, r, session) if err != nil { log.Error().Err(err).Msg("internal error") return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Internal Error"} } return redirect, nil } - -// OAuthCallback handles the callback from the identity provider. Displays an error page if there -// was an error. If successful, redirects back to the proxy-service via the redirect-url. -func (p *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) { - redirect, err := p.getOAuthCallback(w, r) - switch h := err.(type) { - case nil: - break - case httputil.HTTPError: - log.Error().Err(err).Msg("authenticate: oauth callback error") - httputil.ErrorResponse(w, r, h.Message, h.Code) - return - default: - log.Error().Err(err).Msg("authenticate: unexpected oauth callback error") - httputil.ErrorResponse(w, r, "Internal Error", http.StatusInternalServerError) - return - } - // redirect back to the proxy-service - http.Redirect(w, r, redirect, http.StatusFound) -} diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index a99c90d59..be4b69a29 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -1,15 +1,27 @@ package authenticate import ( + "encoding/base64" + "errors" "fmt" "net/http" "net/http/httptest" "net/url" + "strings" "testing" + "time" + "github.com/pomerium/pomerium/authenticate/providers" + "github.com/pomerium/pomerium/internal/cryptutil" + "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/templates" + "golang.org/x/oauth2" ) +// mocks for validator func +func trueValidator(s string) bool { return true } +func falseValidator(s string) bool { return false } + func testAuthenticate() *Authenticate { var auth Authenticate auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback") @@ -37,3 +49,862 @@ func TestAuthenticate_RobotsTxt(t *testing.T) { t.Errorf("handler returned wrong body: got %v want %v", rr.Body.String(), expected) } } + +func TestAuthenticate_Handler(t *testing.T) { + auth := testAuthenticate() + + h := auth.Handler() + if h == nil { + t.Error("handler cannot be nil") + } + req := httptest.NewRequest("GET", "/robots.txt", nil) + + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + expected := fmt.Sprintf("User-agent: *\nDisallow: /") + + body := rr.Body.String() + if body != expected { + t.Errorf("handler returned unexpected body: got %v want %v", body, expected) + } +} + +func TestAuthenticate_authenticate(t *testing.T) { + // sessions.MockSessionStore{Session: expiredLifetime} + goodSession := sessions.MockSessionStore{ + Session: &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }} + expiredSession := sessions.MockSessionStore{ + Session: &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + LifetimeDeadline: time.Now().Add(10 * -time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }} + expiredRefresPeriod := sessions.MockSessionStore{ + Session: &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * -time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }} + + tests := []struct { + name string + session sessions.SessionStore + provider providers.MockProvider + validator func(string) bool + want *sessions.SessionState + wantErr bool + }{ + {"good", goodSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, false}, + {"good but fails validation", goodSession, providers.MockProvider{ValidateResponse: true}, falseValidator, nil, true}, + {"can't load session", sessions.MockSessionStore{LoadError: errors.New("error")}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true}, + {"validation fails", goodSession, providers.MockProvider{ValidateResponse: false}, trueValidator, nil, true}, + {"session fails after good validation", sessions.MockSessionStore{ + SaveError: errors.New("error"), + Session: &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true}, + {"lifetime expired", expiredSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true}, + {"refresh expired", + expiredRefresPeriod, + providers.MockProvider{ + ValidateResponse: true, + RefreshResponse: &oauth2.Token{ + AccessToken: "new token", + Expiry: time.Now(), + }, + }, + trueValidator, nil, false}, + {"refresh expired refresh error", + expiredRefresPeriod, + providers.MockProvider{ + ValidateResponse: true, + RefreshError: errors.New("error"), + }, + trueValidator, nil, true}, + {"refresh expired failed save", + sessions.MockSessionStore{ + SaveError: errors.New("error"), + Session: &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * -time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }}, + providers.MockProvider{ + ValidateResponse: true, + RefreshResponse: &oauth2.Token{ + AccessToken: "new token", + Expiry: time.Now(), + }, + }, + trueValidator, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Authenticate{ + sessionStore: tt.session, + provider: tt.provider, + Validator: tt.validator, + } + r := httptest.NewRequest("GET", "/auth", nil) + w := httptest.NewRecorder() + + _, err := p.authenticate(w, r) + if (err != nil) != tt.wantErr { + t.Errorf("Authenticate.authenticate() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestAuthenticate_SignIn(t *testing.T) { + tests := []struct { + name string + session sessions.SessionStore + provider providers.MockProvider + validator func(string) bool + wantCode int + }{ + {"good", + sessions.MockSessionStore{ + Session: &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }}, + providers.MockProvider{ValidateResponse: true}, + trueValidator, + 403}, + // {"no session", + // sessions.MockSessionStore{ + // Session: &sessions.SessionState{ + // AccessToken: "AccessToken", + // RefreshToken: "RefreshToken", + // LifetimeDeadline: time.Now().Add(-10 * time.Second), + // RefreshDeadline: time.Now().Add(10 * time.Second), + // ValidDeadline: time.Now().Add(10 * time.Second), + // }}, + // providers.MockProvider{ValidateResponse: true}, + // trueValidator, + // 200}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authenticate{ + sessionStore: tt.session, + provider: tt.provider, + Validator: tt.validator, + } + r := httptest.NewRequest("GET", "/sign-in", nil) + w := httptest.NewRecorder() + + a.SignIn(w, r) + if status := w.Code; status != tt.wantCode { + t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode) + } + }) + } +} + +type mockCipher struct{} + +func (a mockCipher) Encrypt(s []byte) ([]byte, error) { + if string(s) == "error" { + return []byte(""), errors.New("error encrypting") + } + return []byte("OK"), nil +} + +func (a mockCipher) Decrypt(s []byte) ([]byte, error) { + if string(s) == "error" { + return []byte(""), errors.New("error encrypting") + } + return []byte("OK"), nil +} +func (a mockCipher) Marshal(s interface{}) (string, error) { return "ok", nil } +func (a mockCipher) Unmarshal(s string, i interface{}) error { + if string(s) == "unmarshal error" || string(s) == "error" { + return errors.New("error") + } + return nil +} +func TestAuthenticate_ProxyCallback(t *testing.T) { + + tests := []struct { + name string + + uri string + state string + authCode string + + sessionState *sessions.SessionState + sessionStore sessions.SessionStore + wantCode int + wantBody string + }{ + {"good", "https://corp.pomerium.io/", "state", "code", + &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }, + sessions.MockSessionStore{}, + 302, + "Found."}, + {"no state", + "https://corp.pomerium.io/", + "", + "code", + &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }, + sessions.MockSessionStore{}, + 403, + "no state parameter supplied"}, + {"no redirect_url", + "", + "state", + "code", + &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }, + sessions.MockSessionStore{}, + 403, + "no redirect_uri parameter"}, + {"malformed redirect_url", + "https://pomerium.com%zzzzz", + "state", + "code", + &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }, + sessions.MockSessionStore{}, + 400, + "malformed redirect_uri"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authenticate{ + sessionStore: tt.sessionStore, + cipher: mockCipher{}, + } + u, _ := url.Parse("https://pomerium.io/redirect") + params, _ := url.ParseQuery(u.RawQuery) + params.Set("code", tt.authCode) + params.Set("state", tt.state) + params.Set("redirect_uri", tt.uri) + + u.RawQuery = params.Encode() + + r := httptest.NewRequest("GET", u.String(), nil) + w := httptest.NewRecorder() + a.ProxyCallback(w, r, tt.sessionState) + if status := w.Code; status != tt.wantCode { + t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode) + } + if body := w.Body.String(); !strings.Contains(body, tt.wantBody) { + t.Errorf("handler returned wrong body Body: got \n%s \n%s", body, tt.wantBody) + } + }) + } +} + +func Test_getAuthCodeRedirectURL(t *testing.T) { + tests := []struct { + name string + redirectURL *url.URL + state string + authCode string + want string + }{ + {"https", uriParse("https://www.pomerium.io"), "state", "auth-code", "https://www.pomerium.io?code=auth-code&state=state"}, + {"http", uriParse("http://www.pomerium.io"), "state", "auth-code", "http://www.pomerium.io?code=auth-code&state=state"}, + {"no subdomain", uriParse("http://pomerium.io"), "state", "auth-code", "http://pomerium.io?code=auth-code&state=state"}, + {"no scheme make https", uriParse("pomerium.io"), "state", "auth-code", "https://pomerium.io?code=auth-code&state=state"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getAuthCodeRedirectURL(tt.redirectURL, tt.state, tt.authCode); got != tt.want { + t.Errorf("getAuthCodeRedirectURL() = %v, want %v", got, tt.want) + } + }) + } +} + +func uriParse(s string) *url.URL { + uri, _ := url.Parse(s) + return uri +} + +func TestAuthenticate_SignOut(t *testing.T) { + + tests := []struct { + name string + method string + + redirectURL string + sig string + ts string + + provider providers.Provider + sessionStore sessions.SessionStore + wantCode int + wantBody string + }{ + {"good post", + http.MethodPost, + "https://corp.pomerium.io/", + "sig", + "ts", + providers.MockProvider{}, + sessions.MockSessionStore{ + Session: &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }, + }, + http.StatusFound, + ""}, + {"failed revoke", + http.MethodPost, + "https://corp.pomerium.io/", + "sig", + "ts", + providers.MockProvider{RevokeError: errors.New("OH NO")}, + sessions.MockSessionStore{ + Session: &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }, + }, + http.StatusBadRequest, + "could not revoke"}, + + {"good get", + http.MethodGet, + "https://corp.pomerium.io/", + "sig", + "ts", + providers.MockProvider{}, + sessions.MockSessionStore{ + Session: &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }, + }, + http.StatusOK, + "This will also sign you out of other internal apps."}, + {"cannot load session", + http.MethodGet, + "https://corp.pomerium.io/", + "sig", + "ts", + providers.MockProvider{}, + sessions.MockSessionStore{ + LoadError: errors.New("uh oh"), + Session: &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }, + }, + http.StatusBadRequest, + "No session found to log out"}, + {"bad redirect url get", + http.MethodGet, + "https://pomerium.com%zzzzz", + "sig", + "ts", + providers.MockProvider{}, + sessions.MockSessionStore{ + Session: &sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }, + }, + http.StatusBadRequest, + "Error"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authenticate{ + sessionStore: tt.sessionStore, + provider: tt.provider, + cipher: mockCipher{}, + templates: templates.New(), + } + u, _ := url.Parse("/sign_out") + params, _ := url.ParseQuery(u.RawQuery) + params.Add("sig", tt.sig) + params.Add("ts", tt.ts) + params.Add("redirect_uri", tt.redirectURL) + u.RawQuery = params.Encode() + + r := httptest.NewRequest(tt.method, u.String(), nil) + w := httptest.NewRecorder() + + a.SignOut(w, r) + if status := w.Code; status != tt.wantCode { + t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode) + } + if body := w.Body.String(); !strings.Contains(body, tt.wantBody) { + t.Errorf("handler returned wrong body Body: got \n%s \n%s", body, tt.wantBody) + } + }) + } +} + +func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) string { + data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix())) + h := cryptutil.Hash(secret, data) + return base64.URLEncoding.EncodeToString(h) +} + +func TestAuthenticate_OAuthStart(t *testing.T) { + + tests := []struct { + name string + method string + + redirectURL string + sig string + ts string + allowedDomains []string + + provider providers.Provider + csrfStore sessions.MockCSRFStore + // sessionStore sessions.SessionStore + wantCode int + }{ + {"good", + http.MethodGet, + "https://corp.pomerium.io/", + redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), + fmt.Sprint(time.Now().Unix()), + []string{".pomerium.io"}, + providers.MockProvider{}, + sessions.MockCSRFStore{}, + http.StatusFound, + }, + {"bad timestamp", + http.MethodGet, + "https://corp.pomerium.io/", + redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), + fmt.Sprint(time.Now().Add(10 * time.Hour).Unix()), + []string{".pomerium.io"}, + providers.MockProvider{}, + sessions.MockCSRFStore{}, + http.StatusBadRequest, + }, + {"domain not in allowed domains", + http.MethodGet, + "https://corp.pomerium.io/", + redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), + fmt.Sprint(time.Now().Unix()), + []string{"not.pomerium.io"}, + providers.MockProvider{}, + sessions.MockCSRFStore{}, + http.StatusBadRequest, + }, + {"missing redirect", + http.MethodGet, + "", + redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), + fmt.Sprint(time.Now().Unix()), + []string{".pomerium.io"}, + providers.MockProvider{}, + sessions.MockCSRFStore{}, + http.StatusBadRequest, + }, + {"malformed redirect", + http.MethodGet, + "https://pomerium.com%zzzzz", + redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), + fmt.Sprint(time.Now().Unix()), + []string{".pomerium.io"}, + providers.MockProvider{}, + sessions.MockCSRFStore{}, + http.StatusBadRequest, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authenticate{ + ProxyRootDomains: tt.allowedDomains, + RedirectURL: uriParse("http://www.pomerium.io"), + csrfStore: tt.csrfStore, + provider: tt.provider, + SharedKey: "secret", + cipher: mockCipher{}, + } + u, _ := url.Parse("/oauth_start") + params, _ := url.ParseQuery(u.RawQuery) + params.Add("sig", tt.sig) + params.Add("ts", tt.ts) + params.Add("redirect_uri", tt.redirectURL) + + u.RawQuery = params.Encode() + + r := httptest.NewRequest(tt.method, u.String(), nil) + w := httptest.NewRecorder() + + a.OAuthStart(w, r) + if status := w.Code; status != tt.wantCode { + t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode) + } + }) + } +} + +func TestAuthenticate_getOAuthCallback(t *testing.T) { + + tests := []struct { + name string + method string + + // url params + paramErr string + code string + state string + validDomains []string + validator func(string) bool + + session sessions.SessionStore + provider providers.MockProvider + csrfStore sessions.MockCSRFStore + + want string + wantErr bool + }{ + {"good", + http.MethodGet, + "", + "code", + base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), + []string{"pomerium.io"}, + trueValidator, + sessions.MockSessionStore{}, + providers.MockProvider{ + AuthenticateResponse: sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }}, + sessions.MockCSRFStore{ + ResponseCSRF: "csrf", + Cookie: &http.Cookie{Value: "nonce"}}, + "https://corp.pomerium.io", + false, + }, + {"get csrf error", + http.MethodGet, + "", + "code", + base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), + []string{"pomerium.io"}, + trueValidator, + sessions.MockSessionStore{}, + providers.MockProvider{ + AuthenticateResponse: sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }}, + sessions.MockCSRFStore{ + ResponseCSRF: "csrf", + GetError: errors.New("error"), + Cookie: &http.Cookie{Value: "not nonce"}}, + "", + true, + }, + {"csrf nonce error", + http.MethodGet, + "", + "code", + base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), + []string{"pomerium.io"}, + trueValidator, + sessions.MockSessionStore{}, + providers.MockProvider{ + AuthenticateResponse: sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }}, + sessions.MockCSRFStore{ + ResponseCSRF: "csrf", + Cookie: &http.Cookie{Value: "not nonce"}}, + "", + true, + }, + {"failed authenticate", + http.MethodGet, + "", + "code", + base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), + []string{"pomerium.io"}, + trueValidator, + sessions.MockSessionStore{}, + providers.MockProvider{ + AuthenticateError: errors.New("error"), + }, + sessions.MockCSRFStore{ + ResponseCSRF: "csrf", + Cookie: &http.Cookie{Value: "nonce"}}, + "", + true, + }, + {"failed save session", + http.MethodGet, + "", + "code", + base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), + []string{"pomerium.io"}, + trueValidator, + sessions.MockSessionStore{SaveError: errors.New("error")}, + providers.MockProvider{ + AuthenticateResponse: sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }}, + sessions.MockCSRFStore{ + ResponseCSRF: "csrf", + Cookie: &http.Cookie{Value: "nonce"}}, + "", + true, + }, + {"failed email validation", + http.MethodGet, + "", + "code", + base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), + []string{"pomerium.io"}, + falseValidator, + sessions.MockSessionStore{}, + providers.MockProvider{ + AuthenticateResponse: sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }}, + sessions.MockCSRFStore{ + ResponseCSRF: "csrf", + Cookie: &http.Cookie{Value: "nonce"}}, + "", + true, + }, + + {"error returned", + http.MethodGet, + "idp error", + "code", + base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), + []string{"pomerium.io"}, + trueValidator, + sessions.MockSessionStore{}, + providers.MockProvider{ + AuthenticateResponse: sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }}, + sessions.MockCSRFStore{ + ResponseCSRF: "csrf", + Cookie: &http.Cookie{Value: "nonce"}}, + "", + true, + }, + {"empty code", + http.MethodGet, + "", + "", + base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), + []string{"pomerium.io"}, + trueValidator, + sessions.MockSessionStore{}, + providers.MockProvider{ + AuthenticateResponse: sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }}, + sessions.MockCSRFStore{ + ResponseCSRF: "csrf", + Cookie: &http.Cookie{Value: "nonce"}}, + "", + true, + }, + {"invalid state string", + http.MethodGet, + "", + "code", + "nonce:https://corp.pomerium.io", + []string{"pomerium.io"}, + trueValidator, + sessions.MockSessionStore{}, + providers.MockProvider{ + AuthenticateResponse: sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }}, + sessions.MockCSRFStore{ + ResponseCSRF: "csrf", + Cookie: &http.Cookie{Value: "nonce"}}, + "", + true, + }, + {"malformed state", + http.MethodGet, + "", + "code", + base64.URLEncoding.EncodeToString([]byte("nonce")), + []string{"pomerium.io"}, + trueValidator, + sessions.MockSessionStore{}, + providers.MockProvider{ + AuthenticateResponse: sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }}, + sessions.MockCSRFStore{ + ResponseCSRF: "csrf", + Cookie: &http.Cookie{Value: "nonce"}}, + "", + true, + }, + {"invalid redirect uri", + http.MethodGet, + "", + "code", + base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), + []string{"pomerium.io"}, + trueValidator, + sessions.MockSessionStore{}, + providers.MockProvider{ + AuthenticateResponse: sessions.SessionState{ + AccessToken: "AccessToken", + RefreshToken: "RefreshToken", + Email: "blah@blah.com", + LifetimeDeadline: time.Now().Add(10 * time.Second), + RefreshDeadline: time.Now().Add(10 * time.Second), + ValidDeadline: time.Now().Add(10 * time.Second), + }}, + sessions.MockCSRFStore{ + ResponseCSRF: "csrf", + Cookie: &http.Cookie{Value: "nonce"}}, + "", + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authenticate{ + sessionStore: tt.session, + csrfStore: tt.csrfStore, + provider: tt.provider, + ProxyRootDomains: tt.validDomains, + Validator: tt.validator, + } + u, _ := url.Parse("/oauthGet") + params, _ := url.ParseQuery(u.RawQuery) + params.Add("error", tt.paramErr) + params.Add("code", tt.code) + params.Add("state", tt.state) + + u.RawQuery = params.Encode() + + r := httptest.NewRequest(tt.method, u.String(), nil) + w := httptest.NewRecorder() + + got, err := a.getOAuthCallback(w, r) + if (err != nil) != tt.wantErr { + t.Errorf("Authenticate.getOAuthCallback() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Authenticate.getOAuthCallback() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authenticate/providers/mock_provider.go b/authenticate/providers/mock_provider.go new file mode 100644 index 000000000..43df1a52b --- /dev/null +++ b/authenticate/providers/mock_provider.go @@ -0,0 +1,41 @@ +package providers // import "github.com/pomerium/pomerium/internal/providers" + +import ( + "github.com/pomerium/pomerium/internal/sessions" // type Provider interface { + "golang.org/x/oauth2" +) + +// MockProvider provides a mocked implementation of the providers interface. +type MockProvider struct { + AuthenticateResponse sessions.SessionState + AuthenticateError error + ValidateResponse bool + ValidateError error + RefreshResponse *oauth2.Token + RefreshError error + RevokeError error + GetSignInURLResponse string +} + +// Authenticate is a mocked providers function. +func (mp MockProvider) Authenticate(code string) (*sessions.SessionState, error) { + return &mp.AuthenticateResponse, mp.AuthenticateError +} + +// Validate is a mocked providers function. +func (mp MockProvider) Validate(s string) (bool, error) { + return mp.ValidateResponse, mp.ValidateError +} + +// Refresh is a mocked providers function. +func (mp MockProvider) Refresh(s string) (*oauth2.Token, error) { + return mp.RefreshResponse, mp.RefreshError +} + +// Revoke is a mocked providers function. +func (mp MockProvider) Revoke(s string) error { + return mp.RevokeError +} + +// GetSignInURL is a mocked providers function. +func (mp MockProvider) GetSignInURL(s string) string { return mp.GetSignInURLResponse } diff --git a/internal/templates/templates.go b/internal/templates/templates.go index 7cdde09cc..46a22a6c4 100644 --- a/internal/templates/templates.go +++ b/internal/templates/templates.go @@ -176,9 +176,6 @@ footer {
- {{ if .Message }} -
{{.Message}}
- {{ end}}

Sign out of {{.Destination}}

diff --git a/proxy/authenticator/authenticator_test.go b/proxy/authenticator/authenticator_test.go index c362aba8c..2829afd6b 100644 --- a/proxy/authenticator/authenticator_test.go +++ b/proxy/authenticator/authenticator_test.go @@ -56,3 +56,24 @@ func TestMockAuthenticate(t *testing.T) { } } + +func TestNew(t *testing.T) { + tests := []struct { + name string + serviceName string + opts *Options + wantErr bool + }{ + {"grpc good", "grpc", &Options{Addr: "test", InternalAddr: "intranet.local", SharedSecret: "secret"}, false}, + {"grpc missing shared secret", "grpc", &Options{Addr: "test", InternalAddr: "intranet.local", SharedSecret: ""}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := New(tt.serviceName, tt.opts) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/proxy/authenticator/grpc_test.go b/proxy/authenticator/grpc_test.go index 6569412e9..595a0226e 100644 --- a/proxy/authenticator/grpc_test.go +++ b/proxy/authenticator/grpc_test.go @@ -192,8 +192,8 @@ func TestNewGRPC(t *testing.T) { {"no shared secret", &Options{}, true, "proxy/authenticator: grpc client requires shared secret"}, {"empty connection", &Options{Addr: "", SharedSecret: "shh"}, true, "proxy/authenticator: connection address required"}, {"empty connections", &Options{Addr: "", InternalAddr: "", SharedSecret: "shh"}, true, "proxy/authenticator: connection address required"}, - {"internal addr", &Options{Addr: "", InternalAddr: "intranet.local", SharedSecret: "shh"}, false, "proxy/authenticator: connection address required"}, - {"cert overide", &Options{Addr: "", InternalAddr: "intranet.local", OverideCertificateName: "*.local", SharedSecret: "shh"}, false, "proxy/authenticator: connection address required"}, + {"internal addr", &Options{Addr: "", InternalAddr: "intranet.local", SharedSecret: "shh"}, false, ""}, + {"cert overide", &Options{Addr: "", InternalAddr: "intranet.local", OverideCertificateName: "*.local", SharedSecret: "shh"}, false, ""}, // {"addr and internal ", &Options{Addr: "localhost", InternalAddr: "local.localhost", SharedSecret: "shh"}, nil, true, ""}, } diff --git a/proxy/handlers.go b/proxy/handlers.go index a4716825d..b6655ddec 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -121,7 +121,7 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) { // this value will be unique since we always use a randomized nonce as part of marshaling encryptedCSRF, err := p.cipher.Marshal(state) if err != nil { - log.FromRequest(r).Error().Err(err).Msg("failed to marshal csrf") + log.FromRequest(r).Error().Err(err).Msg("proxy: failed to marshal csrf") httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) return } @@ -131,7 +131,7 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) { // this value will be unique since we always use a randomized nonce as part of marshaling encryptedState, err := p.cipher.Marshal(state) if err != nil { - log.FromRequest(r).Error().Err(err).Msg("failed to encrypt cookie") + log.FromRequest(r).Error().Err(err).Msg("proxy: failed to encrypt cookie") httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) return } @@ -149,7 +149,7 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) { func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) { err := r.ParseForm() if err != nil { - log.FromRequest(r).Error().Err(err).Msg("failed parsing request form") + log.FromRequest(r).Error().Err(err).Msg("proxy: failed parsing request form") httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError) return } @@ -161,27 +161,23 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) { // We begin the process of redeeming the code for an access token. rr, err := p.AuthenticateClient.Redeem(r.Form.Get("code")) if err != nil { - log.FromRequest(r).Error().Err(err).Msg("error redeeming authorization code") + log.FromRequest(r).Error().Err(err).Msg("proxy: error redeeming authorization code") httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError) return } encryptedState := r.Form.Get("state") - log.Warn(). - Str("encryptedState", encryptedState). - Msg("OK") - stateParameter := &StateParameter{} err = p.cipher.Unmarshal(encryptedState, stateParameter) if err != nil { - log.FromRequest(r).Error().Err(err).Msg("could not unmarshal state") + log.FromRequest(r).Error().Err(err).Msg("proxy: could not unmarshal state") httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError) return } c, err := p.csrfStore.GetCSRF(r) if err != nil { - log.FromRequest(r).Error().Err(err).Msg("failed parsing csrf cookie") + log.FromRequest(r).Error().Err(err).Msg("proxy: failed parsing csrf cookie") httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest) return } @@ -191,7 +187,7 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) { csrfParameter := &StateParameter{} err = p.cipher.Unmarshal(encryptedCSRF, csrfParameter) if err != nil { - log.FromRequest(r).Error().Err(err).Msg("couldn't unmarshal CSRF") + log.FromRequest(r).Error().Err(err).Msg("proxy: couldn't unmarshal CSRF") httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError) return } @@ -283,7 +279,7 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error) } if session.LifetimePeriodExpired() { - log.FromRequest(r).Info().Msg("proxy.Authenticate: lifetime expired, restarting") + log.FromRequest(r).Info().Msg("proxy: lifetime expired") return sessions.ErrLifetimeExpired } if session.RefreshPeriodExpired() { @@ -295,12 +291,12 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error) log.FromRequest(r).Warn(). Str("RefreshToken", session.RefreshToken). Str("AccessToken", session.AccessToken). - Msg("proxy.Authenticate: refresh failure") + Msg("proxy: refresh failed") return err } session.AccessToken = accessToken session.RefreshDeadline = expiry - log.FromRequest(r).Info().Msg("proxy.Authenticate: refresh success") + log.FromRequest(r).Info().Msg("proxy: refresh success") } err = p.sessionStore.SaveSession(w, r, session) diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index f2ef40325..4bb41dfb7 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -359,12 +359,6 @@ func TestProxy_Proxy(t *testing.T) { RefreshToken: "RefreshToken", LifetimeDeadline: time.Now().Add(-10 * time.Second), } - // expiredDeadline := &sessions.SessionState{ - // AccessToken: "AccessToken", - // RefreshToken: "RefreshToken", - // LifetimeDeadline: time.Now().Add(10 * time.Second), - // RefreshDeadline: time.Now().Add(-10 * time.Second), - // } tests := []struct { name string