authenticate: add tests, fix signout (#45)

- authenticate: a bug where sign out failed to revoke the remote session
- docs: add code coverage to readme
- authenticate: Rename shorthand receiver variable name
- authenticate: consolidate sign in
This commit is contained in:
Bobby DeSimone 2019-02-14 00:01:50 -08:00 committed by GitHub
parent 35ee3247d7
commit 805f0198d2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 1061 additions and 163 deletions

View file

@ -4,7 +4,7 @@
# Pomerium
[![Travis CI](https://travis-ci.org/pomerium/pomerium.svg?branch=master)](https://travis-ci.org/pomerium/pomerium) [![Go Report Card](https://goreportcard.com/badge/github.com/pomerium/pomerium)](https://goreportcard.com/report/github.com/pomerium/pomerium) [![GoDoc](https://godoc.org/github.com/pomerium/pomerium?status.svg)][godocs] [![LICENSE](https://img.shields.io/github/license/pomerium/pomerium.svg)](https://github.com/pomerium/pomerium/blob/master/LICENSE)
[![Travis CI](https://travis-ci.org/pomerium/pomerium.svg?branch=master)](https://travis-ci.org/pomerium/pomerium) [![Go Report Card](https://goreportcard.com/badge/github.com/pomerium/pomerium)](https://goreportcard.com/report/github.com/pomerium/pomerium) [![GoDoc](https://godoc.org/github.com/pomerium/pomerium?status.svg)][godocs] [![LICENSE](https://img.shields.io/github/license/pomerium/pomerium.svg)](https://github.com/pomerium/pomerium/blob/master/LICENSE)[![codecov](https://codecov.io/gh/pomerium/pomerium/branch/master/graph/badge.svg)](https://codecov.io/gh/pomerium/pomerium)
Pomerium is a tool for managing secure access to internal applications and resources.

View file

@ -29,7 +29,7 @@ var securityHeaders = map[string]string{
}
// Handler returns the Http.Handlers for authenticate, callback, and refresh
func (p *Authenticate) Handler() http.Handler {
func (a *Authenticate) Handler() http.Handler {
// set up our standard middlewares
stdMiddleware := middleware.NewChain()
stdMiddleware = stdMiddleware.Append(middleware.Healthcheck("/ping", version.UserAgent()))
@ -51,100 +51,101 @@ func (p *Authenticate) Handler() http.Handler {
stdMiddleware = stdMiddleware.Append(middleware.RefererHandler("referer"))
stdMiddleware = stdMiddleware.Append(middleware.RequestIDHandler("req_id", "Request-Id"))
validateSignatureMiddleware := stdMiddleware.Append(
middleware.ValidateSignature(p.SharedKey),
middleware.ValidateRedirectURI(p.ProxyRootDomains))
middleware.ValidateSignature(a.SharedKey),
middleware.ValidateRedirectURI(a.ProxyRootDomains))
mux := http.NewServeMux()
mux.Handle("/robots.txt", stdMiddleware.ThenFunc(p.RobotsTxt))
mux.Handle("/robots.txt", stdMiddleware.ThenFunc(a.RobotsTxt))
// Identity Provider (IdP) callback endpoints and callbacks
mux.Handle("/start", stdMiddleware.ThenFunc(p.OAuthStart))
mux.Handle("/oauth2/callback", stdMiddleware.ThenFunc(p.OAuthCallback))
mux.Handle("/start", stdMiddleware.ThenFunc(a.OAuthStart))
mux.Handle("/oauth2/callback", stdMiddleware.ThenFunc(a.OAuthCallback))
// authenticate-server endpoints
mux.Handle("/sign_in", validateSignatureMiddleware.ThenFunc(p.SignIn))
mux.Handle("/sign_out", validateSignatureMiddleware.ThenFunc(p.SignOut)) // GET POST
mux.Handle("/sign_in", validateSignatureMiddleware.ThenFunc(a.SignIn))
mux.Handle("/sign_out", validateSignatureMiddleware.ThenFunc(a.SignOut)) // GET POST
return mux
}
// RobotsTxt handles the /robots.txt route.
func (p *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
}
func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*sessions.SessionState, error) {
session, err := p.sessionStore.LoadSession(r)
func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*sessions.SessionState, error) {
session, err := a.sessionStore.LoadSession(r)
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to load session")
p.sessionStore.ClearSession(w, r)
a.sessionStore.ClearSession(w, r)
return nil, err
}
// if long-lived lifetime has expired, clear session
if session.LifetimePeriodExpired() {
log.FromRequest(r).Warn().Msg("authenticate: lifetime expired")
p.sessionStore.ClearSession(w, r)
a.sessionStore.ClearSession(w, r)
return nil, sessions.ErrLifetimeExpired
}
// check if session refresh period is up
if session.RefreshPeriodExpired() {
newToken, err := p.provider.Refresh(session.RefreshToken)
newToken, err := a.provider.Refresh(session.RefreshToken)
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to refresh session")
p.sessionStore.ClearSession(w, r)
a.sessionStore.ClearSession(w, r)
return nil, err
}
session.AccessToken = newToken.AccessToken
session.RefreshDeadline = newToken.Expiry
err = p.sessionStore.SaveSession(w, r, session)
err = a.sessionStore.SaveSession(w, r, session)
if err != nil {
// We refreshed the session successfully, but failed to save it.
// This could be from failing to encode the session properly.
// But, we clear the session cookie and reject the request
log.FromRequest(r).Error().Err(err).Msg("could not save refreshed session")
p.sessionStore.ClearSession(w, r)
a.sessionStore.ClearSession(w, r)
return nil, err
}
} else {
// The session has not exceeded it's lifetime or requires refresh
ok, err := p.provider.Validate(session.IDToken)
ok, err := a.provider.Validate(session.IDToken)
if !ok || err != nil {
log.FromRequest(r).Error().Err(err).Msg("invalid session state")
p.sessionStore.ClearSession(w, r)
a.sessionStore.ClearSession(w, r)
return nil, httputil.ErrUserNotAuthorized
}
err = p.sessionStore.SaveSession(w, r, session)
err = a.sessionStore.SaveSession(w, r, session)
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("failed to save valid session")
p.sessionStore.ClearSession(w, r)
a.sessionStore.ClearSession(w, r)
return nil, err
}
}
// authenticate really should not be in the business of authorization
// todo(bdd) : remove when authorization module added
if !p.Validator(session.Email) {
if !a.Validator(session.Email) {
log.FromRequest(r).Error().Msg("invalid email user")
return nil, httputil.ErrUserNotAuthorized
}
log.Info().Msg("authenticate")
return session, nil
}
// SignIn handles the /sign_in endpoint. It attempts to authenticate the user,
// and if the user is not authenticated, it renders a sign in page.
func (p *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
session, err := p.authenticate(w, r)
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
session, err := a.authenticate(w, r)
switch err {
case nil:
// User is authenticated, redirect back to proxy
p.ProxyOAuthRedirect(w, r, session)
// session good, redirect back to proxy
log.FromRequest(r).Info().Msg("authenticate.SignIn : authenticated")
a.ProxyCallback(w, r, session)
case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
log.Info().Err(err).Msg("authenticate.SignIn : expected failure")
// session invalid, authenticate
log.FromRequest(r).Info().Err(err).Msg("authenticate.SignIn : expected failure")
if err != http.ErrNoCookie {
p.sessionStore.ClearSession(w, r)
a.sessionStore.ClearSession(w, r)
}
p.OAuthStart(w, r)
a.OAuthStart(w, r)
default:
log.Error().Err(err).Msg("authenticate: unexpected sign in error")
@ -152,10 +153,10 @@ func (p *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
}
}
// ProxyOAuthRedirect redirects the user back to proxy's redirection endpoint.
// This workflow corresponds to Section 3.1.2 of the OAuth2 RFC.
// See https://tools.ietf.org/html/rfc6749#section-3.1.2 for more specific information.
func (p *Authenticate) ProxyOAuthRedirect(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) {
// ProxyCallback redirects the user back to proxy service along with an encrypted payload, as
// url params, of the user's session state.
// See RFC6749 3.1.2 https://tools.ietf.org/html/rfc6749#section-3.1.2
func (a *Authenticate) ProxyCallback(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) {
err := r.ParseForm()
if err != nil {
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
@ -180,7 +181,7 @@ func (p *Authenticate) ProxyOAuthRedirect(w http.ResponseWriter, r *http.Request
return
}
// encrypt session state as json blob
encrypted, err := sessions.MarshalSession(session, p.cipher)
encrypted, err := sessions.MarshalSession(session, a.cipher)
if err != nil {
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
return
@ -193,75 +194,43 @@ func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string
params, _ := url.ParseQuery(u.RawQuery)
params.Set("code", authCode)
params.Set("state", state)
u.RawQuery = params.Encode()
if u.Scheme == "" {
u.Scheme = "https"
}
return u.String()
}
// SignOut signs the user out.
func (p *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
// SignOut signs the user out by trying to revoke the users remote identity provider session
// then removes the associated local session state.
// Handles both GET and POST of form.
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
return
}
// pretty safe to say that no matter what heppanes here, we want to revoke the local session
redirectURI := r.Form.Get("redirect_uri")
session, err := a.sessionStore.LoadSession(r)
if err != nil {
log.Error().Err(err).Msg("authenticate: signout failed to load session")
httputil.ErrorResponse(w, r, "No session found to log out", http.StatusBadRequest)
return
}
if r.Method == "GET" {
p.SignOutPage(w, r, "")
return
}
session, err := p.sessionStore.LoadSession(r)
switch err {
case nil:
break
case http.ErrNoCookie: // if there's no cookie in the session we can just redirect
log.Error().Err(err).Msg("authenticate.SignOut : no cookie")
http.Redirect(w, r, redirectURI, http.StatusFound)
return
default:
// a different error, clear the session cookie and redirect
log.Error().Err(err).Msg("authenticate.SignOut : error loading cookie session")
p.sessionStore.ClearSession(w, r)
http.Redirect(w, r, redirectURI, http.StatusFound)
return
}
err = p.provider.Revoke(session.AccessToken)
if err != nil {
log.Error().Err(err).Msg("authenticate.SignOut : error revoking session")
p.SignOutPage(w, r, "An error occurred during sign out. Please try again.")
return
}
p.sessionStore.ClearSession(w, r)
http.Redirect(w, r, redirectURI, http.StatusFound)
}
// SignOutPage renders a sign out page with a message
func (p *Authenticate) SignOutPage(w http.ResponseWriter, r *http.Request, message string) {
// validateRedirectURI middleware already ensures that this is a valid URL
redirectURI := r.Form.Get("redirect_uri")
session, err := p.sessionStore.LoadSession(r)
if err != nil {
http.Redirect(w, r, redirectURI, http.StatusFound)
return
}
signature := r.Form.Get("sig")
timestamp := r.Form.Get("ts")
destinationURL, err := url.Parse(redirectURI)
// An error message indicates that an internal server error occurred
if message != "" || err != nil {
log.Error().Err(err).Msg("authenticate.SignOutPage")
w.WriteHeader(http.StatusInternalServerError)
if err != nil {
log.Error().Err(err).Msg("authenticate: malformed destination url")
httputil.ErrorResponse(w, r, "Malformed destination URL", http.StatusBadRequest)
return
}
t := struct {
Redirect string
Signature string
Timestamp string
Message string
Destination string
Email string
Version string
@ -269,35 +238,45 @@ func (p *Authenticate) SignOutPage(w http.ResponseWriter, r *http.Request, messa
Redirect: redirectURI,
Signature: signature,
Timestamp: timestamp,
Message: message,
Destination: destinationURL.Host,
Email: session.Email,
Version: version.FullVersion(),
}
p.templates.ExecuteTemplate(w, "sign_out.html", t)
a.templates.ExecuteTemplate(w, "sign_out.html", t)
w.WriteHeader(http.StatusOK)
return
}
a.sessionStore.ClearSession(w, r)
err = a.provider.Revoke(session.AccessToken)
if err != nil {
log.Error().Err(err).Msg("authenticate: failed to revoke user session")
httputil.ErrorResponse(w, r, fmt.Sprintf("could not revoke session: %s ", err.Error()), http.StatusBadRequest)
return
}
http.Redirect(w, r, redirectURI, http.StatusFound)
}
// OAuthStart starts the authenticate process by redirecting to the provider. It provides a
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authenticate.
func (p *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
authRedirectURL, err := url.Parse(r.URL.Query().Get("redirect_uri"))
if err != nil {
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
return
}
authRedirectURL = p.RedirectURL.ResolveReference(r.URL)
authRedirectURL = a.RedirectURL.ResolveReference(r.URL)
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
p.csrfStore.SetCSRF(w, r, nonce)
a.csrfStore.SetCSRF(w, r, nonce)
// verify redirect uri is from the root domain
if !middleware.ValidRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) {
if !middleware.ValidRedirectURI(authRedirectURL.String(), a.ProxyRootDomains) {
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
return
}
// verify proxy url is from the root domain
proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri"))
if err != nil || !middleware.ValidRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) {
if err != nil || !middleware.ValidRedirectURI(proxyRedirectURL.String(), a.ProxyRootDomains) {
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
return
}
@ -305,7 +284,7 @@ func (p *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
// get the signature and timestamp values then compare hmac
proxyRedirectSig := authRedirectURL.Query().Get("sig")
ts := authRedirectURL.Query().Get("ts")
if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, p.SharedKey) {
if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, a.SharedKey) {
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
return
}
@ -313,16 +292,35 @@ func (p *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
// concat base64'd nonce and authenticate url to make state
state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String())))
// build the provider sign in url
signInURL := p.provider.GetSignInURL(state)
signInURL := a.provider.GetSignInURL(state)
http.Redirect(w, r, signInURL, http.StatusFound)
}
// OAuthCallback handles the callback from the identity provider. Displays an error page if there
// was an error. If successful, redirects back to the proxy-service via the redirect-url.
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
redirect, err := a.getOAuthCallback(w, r)
switch h := err.(type) {
case nil:
break
case httputil.HTTPError:
log.Error().Err(err).Msg("authenticate: oauth callback error")
httputil.ErrorResponse(w, r, h.Message, h.Code)
return
default:
log.Error().Err(err).Msg("authenticate: unexpected oauth callback error")
httputil.ErrorResponse(w, r, "Internal Error", http.StatusInternalServerError)
return
}
// redirect back to the proxy-service
http.Redirect(w, r, redirect, http.StatusFound)
}
// getOAuthCallback completes the oauth cycle from an identity provider's callback
func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) {
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) {
err := r.ParseForm()
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("authenticate: bad form on oauth callback")
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
}
errorString := r.Form.Get("error")
@ -336,7 +334,7 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Missing Code"}
}
session, err := p.provider.Authenticate(code)
session, err := a.provider.Authenticate(code)
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("authenticate: error redeeming authenticate code")
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
@ -353,50 +351,30 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
}
nonce := s[0]
redirect := s[1]
c, err := p.csrfStore.GetCSRF(r)
c, err := a.csrfStore.GetCSRF(r)
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("authenticate: bad csrf")
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Missing CSRF token"}
}
p.csrfStore.ClearCSRF(w, r)
a.csrfStore.ClearCSRF(w, r)
if c.Value != nonce {
log.FromRequest(r).Error().Err(err).Msg("authenticate: csrf mismatch")
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "CSRF failed"}
}
if !middleware.ValidRedirectURI(redirect, p.ProxyRootDomains) {
if !middleware.ValidRedirectURI(redirect, a.ProxyRootDomains) {
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Redirect URI"}
}
// Set cookie, or deny: validates the session email and group
if !p.Validator(session.Email) {
if !a.Validator(session.Email) {
log.FromRequest(r).Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied")
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "You don't have access"}
}
err = p.sessionStore.SaveSession(w, r, session)
err = a.sessionStore.SaveSession(w, r, session)
if err != nil {
log.Error().Err(err).Msg("internal error")
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Internal Error"}
}
return redirect, nil
}
// OAuthCallback handles the callback from the identity provider. Displays an error page if there
// was an error. If successful, redirects back to the proxy-service via the redirect-url.
func (p *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
redirect, err := p.getOAuthCallback(w, r)
switch h := err.(type) {
case nil:
break
case httputil.HTTPError:
log.Error().Err(err).Msg("authenticate: oauth callback error")
httputil.ErrorResponse(w, r, h.Message, h.Code)
return
default:
log.Error().Err(err).Msg("authenticate: unexpected oauth callback error")
httputil.ErrorResponse(w, r, "Internal Error", http.StatusInternalServerError)
return
}
// redirect back to the proxy-service
http.Redirect(w, r, redirect, http.StatusFound)
}

View file

@ -1,15 +1,27 @@
package authenticate
import (
"encoding/base64"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/pomerium/pomerium/authenticate/providers"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/templates"
"golang.org/x/oauth2"
)
// mocks for validator func
func trueValidator(s string) bool { return true }
func falseValidator(s string) bool { return false }
func testAuthenticate() *Authenticate {
var auth Authenticate
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
@ -37,3 +49,862 @@ func TestAuthenticate_RobotsTxt(t *testing.T) {
t.Errorf("handler returned wrong body: got %v want %v", rr.Body.String(), expected)
}
}
func TestAuthenticate_Handler(t *testing.T) {
auth := testAuthenticate()
h := auth.Handler()
if h == nil {
t.Error("handler cannot be nil")
}
req := httptest.NewRequest("GET", "/robots.txt", nil)
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
expected := fmt.Sprintf("User-agent: *\nDisallow: /")
body := rr.Body.String()
if body != expected {
t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
}
}
func TestAuthenticate_authenticate(t *testing.T) {
// sessions.MockSessionStore{Session: expiredLifetime}
goodSession := sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}
expiredSession := sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * -time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}
expiredRefresPeriod := sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * -time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}
tests := []struct {
name string
session sessions.SessionStore
provider providers.MockProvider
validator func(string) bool
want *sessions.SessionState
wantErr bool
}{
{"good", goodSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, false},
{"good but fails validation", goodSession, providers.MockProvider{ValidateResponse: true}, falseValidator, nil, true},
{"can't load session", sessions.MockSessionStore{LoadError: errors.New("error")}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
{"validation fails", goodSession, providers.MockProvider{ValidateResponse: false}, trueValidator, nil, true},
{"session fails after good validation", sessions.MockSessionStore{
SaveError: errors.New("error"),
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
{"lifetime expired", expiredSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
{"refresh expired",
expiredRefresPeriod,
providers.MockProvider{
ValidateResponse: true,
RefreshResponse: &oauth2.Token{
AccessToken: "new token",
Expiry: time.Now(),
},
},
trueValidator, nil, false},
{"refresh expired refresh error",
expiredRefresPeriod,
providers.MockProvider{
ValidateResponse: true,
RefreshError: errors.New("error"),
},
trueValidator, nil, true},
{"refresh expired failed save",
sessions.MockSessionStore{
SaveError: errors.New("error"),
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * -time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
providers.MockProvider{
ValidateResponse: true,
RefreshResponse: &oauth2.Token{
AccessToken: "new token",
Expiry: time.Now(),
},
},
trueValidator, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Authenticate{
sessionStore: tt.session,
provider: tt.provider,
Validator: tt.validator,
}
r := httptest.NewRequest("GET", "/auth", nil)
w := httptest.NewRecorder()
_, err := p.authenticate(w, r)
if (err != nil) != tt.wantErr {
t.Errorf("Authenticate.authenticate() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func TestAuthenticate_SignIn(t *testing.T) {
tests := []struct {
name string
session sessions.SessionStore
provider providers.MockProvider
validator func(string) bool
wantCode int
}{
{"good",
sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
providers.MockProvider{ValidateResponse: true},
trueValidator,
403},
// {"no session",
// sessions.MockSessionStore{
// Session: &sessions.SessionState{
// AccessToken: "AccessToken",
// RefreshToken: "RefreshToken",
// LifetimeDeadline: time.Now().Add(-10 * time.Second),
// RefreshDeadline: time.Now().Add(10 * time.Second),
// ValidDeadline: time.Now().Add(10 * time.Second),
// }},
// providers.MockProvider{ValidateResponse: true},
// trueValidator,
// 200},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Authenticate{
sessionStore: tt.session,
provider: tt.provider,
Validator: tt.validator,
}
r := httptest.NewRequest("GET", "/sign-in", nil)
w := httptest.NewRecorder()
a.SignIn(w, r)
if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
}
})
}
}
type mockCipher struct{}
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Marshal(s interface{}) (string, error) { return "ok", nil }
func (a mockCipher) Unmarshal(s string, i interface{}) error {
if string(s) == "unmarshal error" || string(s) == "error" {
return errors.New("error")
}
return nil
}
func TestAuthenticate_ProxyCallback(t *testing.T) {
tests := []struct {
name string
uri string
state string
authCode string
sessionState *sessions.SessionState
sessionStore sessions.SessionStore
wantCode int
wantBody string
}{
{"good", "https://corp.pomerium.io/", "state", "code",
&sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
sessions.MockSessionStore{},
302,
"<a href=\"https://corp.pomerium.io/?code=ok&amp;state=state\">Found</a>."},
{"no state",
"https://corp.pomerium.io/",
"",
"code",
&sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
sessions.MockSessionStore{},
403,
"no state parameter supplied"},
{"no redirect_url",
"",
"state",
"code",
&sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
sessions.MockSessionStore{},
403,
"no redirect_uri parameter"},
{"malformed redirect_url",
"https://pomerium.com%zzzzz",
"state",
"code",
&sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
sessions.MockSessionStore{},
400,
"malformed redirect_uri"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Authenticate{
sessionStore: tt.sessionStore,
cipher: mockCipher{},
}
u, _ := url.Parse("https://pomerium.io/redirect")
params, _ := url.ParseQuery(u.RawQuery)
params.Set("code", tt.authCode)
params.Set("state", tt.state)
params.Set("redirect_uri", tt.uri)
u.RawQuery = params.Encode()
r := httptest.NewRequest("GET", u.String(), nil)
w := httptest.NewRecorder()
a.ProxyCallback(w, r, tt.sessionState)
if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
}
if body := w.Body.String(); !strings.Contains(body, tt.wantBody) {
t.Errorf("handler returned wrong body Body: got \n%s \n%s", body, tt.wantBody)
}
})
}
}
func Test_getAuthCodeRedirectURL(t *testing.T) {
tests := []struct {
name string
redirectURL *url.URL
state string
authCode string
want string
}{
{"https", uriParse("https://www.pomerium.io"), "state", "auth-code", "https://www.pomerium.io?code=auth-code&state=state"},
{"http", uriParse("http://www.pomerium.io"), "state", "auth-code", "http://www.pomerium.io?code=auth-code&state=state"},
{"no subdomain", uriParse("http://pomerium.io"), "state", "auth-code", "http://pomerium.io?code=auth-code&state=state"},
{"no scheme make https", uriParse("pomerium.io"), "state", "auth-code", "https://pomerium.io?code=auth-code&state=state"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getAuthCodeRedirectURL(tt.redirectURL, tt.state, tt.authCode); got != tt.want {
t.Errorf("getAuthCodeRedirectURL() = %v, want %v", got, tt.want)
}
})
}
}
func uriParse(s string) *url.URL {
uri, _ := url.Parse(s)
return uri
}
func TestAuthenticate_SignOut(t *testing.T) {
tests := []struct {
name string
method string
redirectURL string
sig string
ts string
provider providers.Provider
sessionStore sessions.SessionStore
wantCode int
wantBody string
}{
{"good post",
http.MethodPost,
"https://corp.pomerium.io/",
"sig",
"ts",
providers.MockProvider{},
sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
},
http.StatusFound,
""},
{"failed revoke",
http.MethodPost,
"https://corp.pomerium.io/",
"sig",
"ts",
providers.MockProvider{RevokeError: errors.New("OH NO")},
sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
},
http.StatusBadRequest,
"could not revoke"},
{"good get",
http.MethodGet,
"https://corp.pomerium.io/",
"sig",
"ts",
providers.MockProvider{},
sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
},
http.StatusOK,
"This will also sign you out of other internal apps."},
{"cannot load session",
http.MethodGet,
"https://corp.pomerium.io/",
"sig",
"ts",
providers.MockProvider{},
sessions.MockSessionStore{
LoadError: errors.New("uh oh"),
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
},
http.StatusBadRequest,
"No session found to log out"},
{"bad redirect url get",
http.MethodGet,
"https://pomerium.com%zzzzz",
"sig",
"ts",
providers.MockProvider{},
sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
},
},
http.StatusBadRequest,
"Error"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Authenticate{
sessionStore: tt.sessionStore,
provider: tt.provider,
cipher: mockCipher{},
templates: templates.New(),
}
u, _ := url.Parse("/sign_out")
params, _ := url.ParseQuery(u.RawQuery)
params.Add("sig", tt.sig)
params.Add("ts", tt.ts)
params.Add("redirect_uri", tt.redirectURL)
u.RawQuery = params.Encode()
r := httptest.NewRequest(tt.method, u.String(), nil)
w := httptest.NewRecorder()
a.SignOut(w, r)
if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
}
if body := w.Body.String(); !strings.Contains(body, tt.wantBody) {
t.Errorf("handler returned wrong body Body: got \n%s \n%s", body, tt.wantBody)
}
})
}
}
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) string {
data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix()))
h := cryptutil.Hash(secret, data)
return base64.URLEncoding.EncodeToString(h)
}
func TestAuthenticate_OAuthStart(t *testing.T) {
tests := []struct {
name string
method string
redirectURL string
sig string
ts string
allowedDomains []string
provider providers.Provider
csrfStore sessions.MockCSRFStore
// sessionStore sessions.SessionStore
wantCode int
}{
{"good",
http.MethodGet,
"https://corp.pomerium.io/",
redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"),
fmt.Sprint(time.Now().Unix()),
[]string{".pomerium.io"},
providers.MockProvider{},
sessions.MockCSRFStore{},
http.StatusFound,
},
{"bad timestamp",
http.MethodGet,
"https://corp.pomerium.io/",
redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"),
fmt.Sprint(time.Now().Add(10 * time.Hour).Unix()),
[]string{".pomerium.io"},
providers.MockProvider{},
sessions.MockCSRFStore{},
http.StatusBadRequest,
},
{"domain not in allowed domains",
http.MethodGet,
"https://corp.pomerium.io/",
redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"),
fmt.Sprint(time.Now().Unix()),
[]string{"not.pomerium.io"},
providers.MockProvider{},
sessions.MockCSRFStore{},
http.StatusBadRequest,
},
{"missing redirect",
http.MethodGet,
"",
redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"),
fmt.Sprint(time.Now().Unix()),
[]string{".pomerium.io"},
providers.MockProvider{},
sessions.MockCSRFStore{},
http.StatusBadRequest,
},
{"malformed redirect",
http.MethodGet,
"https://pomerium.com%zzzzz",
redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"),
fmt.Sprint(time.Now().Unix()),
[]string{".pomerium.io"},
providers.MockProvider{},
sessions.MockCSRFStore{},
http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Authenticate{
ProxyRootDomains: tt.allowedDomains,
RedirectURL: uriParse("http://www.pomerium.io"),
csrfStore: tt.csrfStore,
provider: tt.provider,
SharedKey: "secret",
cipher: mockCipher{},
}
u, _ := url.Parse("/oauth_start")
params, _ := url.ParseQuery(u.RawQuery)
params.Add("sig", tt.sig)
params.Add("ts", tt.ts)
params.Add("redirect_uri", tt.redirectURL)
u.RawQuery = params.Encode()
r := httptest.NewRequest(tt.method, u.String(), nil)
w := httptest.NewRecorder()
a.OAuthStart(w, r)
if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
}
})
}
}
func TestAuthenticate_getOAuthCallback(t *testing.T) {
tests := []struct {
name string
method string
// url params
paramErr string
code string
state string
validDomains []string
validator func(string) bool
session sessions.SessionStore
provider providers.MockProvider
csrfStore sessions.MockCSRFStore
want string
wantErr bool
}{
{"good",
http.MethodGet,
"",
"code",
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
Cookie: &http.Cookie{Value: "nonce"}},
"https://corp.pomerium.io",
false,
},
{"get csrf error",
http.MethodGet,
"",
"code",
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
GetError: errors.New("error"),
Cookie: &http.Cookie{Value: "not nonce"}},
"",
true,
},
{"csrf nonce error",
http.MethodGet,
"",
"code",
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
Cookie: &http.Cookie{Value: "not nonce"}},
"",
true,
},
{"failed authenticate",
http.MethodGet,
"",
"code",
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateError: errors.New("error"),
},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
Cookie: &http.Cookie{Value: "nonce"}},
"",
true,
},
{"failed save session",
http.MethodGet,
"",
"code",
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{SaveError: errors.New("error")},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
Cookie: &http.Cookie{Value: "nonce"}},
"",
true,
},
{"failed email validation",
http.MethodGet,
"",
"code",
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
falseValidator,
sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
Cookie: &http.Cookie{Value: "nonce"}},
"",
true,
},
{"error returned",
http.MethodGet,
"idp error",
"code",
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
Cookie: &http.Cookie{Value: "nonce"}},
"",
true,
},
{"empty code",
http.MethodGet,
"",
"",
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
Cookie: &http.Cookie{Value: "nonce"}},
"",
true,
},
{"invalid state string",
http.MethodGet,
"",
"code",
"nonce:https://corp.pomerium.io",
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
Cookie: &http.Cookie{Value: "nonce"}},
"",
true,
},
{"malformed state",
http.MethodGet,
"",
"code",
base64.URLEncoding.EncodeToString([]byte("nonce")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
Cookie: &http.Cookie{Value: "nonce"}},
"",
true,
},
{"invalid redirect uri",
http.MethodGet,
"",
"code",
base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")),
[]string{"pomerium.io"},
trueValidator,
sessions.MockSessionStore{},
providers.MockProvider{
AuthenticateResponse: sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Email: "blah@blah.com",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}},
sessions.MockCSRFStore{
ResponseCSRF: "csrf",
Cookie: &http.Cookie{Value: "nonce"}},
"",
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Authenticate{
sessionStore: tt.session,
csrfStore: tt.csrfStore,
provider: tt.provider,
ProxyRootDomains: tt.validDomains,
Validator: tt.validator,
}
u, _ := url.Parse("/oauthGet")
params, _ := url.ParseQuery(u.RawQuery)
params.Add("error", tt.paramErr)
params.Add("code", tt.code)
params.Add("state", tt.state)
u.RawQuery = params.Encode()
r := httptest.NewRequest(tt.method, u.String(), nil)
w := httptest.NewRecorder()
got, err := a.getOAuthCallback(w, r)
if (err != nil) != tt.wantErr {
t.Errorf("Authenticate.getOAuthCallback() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Authenticate.getOAuthCallback() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -0,0 +1,41 @@
package providers // import "github.com/pomerium/pomerium/internal/providers"
import (
"github.com/pomerium/pomerium/internal/sessions" // type Provider interface {
"golang.org/x/oauth2"
)
// MockProvider provides a mocked implementation of the providers interface.
type MockProvider struct {
AuthenticateResponse sessions.SessionState
AuthenticateError error
ValidateResponse bool
ValidateError error
RefreshResponse *oauth2.Token
RefreshError error
RevokeError error
GetSignInURLResponse string
}
// Authenticate is a mocked providers function.
func (mp MockProvider) Authenticate(code string) (*sessions.SessionState, error) {
return &mp.AuthenticateResponse, mp.AuthenticateError
}
// Validate is a mocked providers function.
func (mp MockProvider) Validate(s string) (bool, error) {
return mp.ValidateResponse, mp.ValidateError
}
// Refresh is a mocked providers function.
func (mp MockProvider) Refresh(s string) (*oauth2.Token, error) {
return mp.RefreshResponse, mp.RefreshError
}
// Revoke is a mocked providers function.
func (mp MockProvider) Revoke(s string) error {
return mp.RevokeError
}
// GetSignInURL is a mocked providers function.
func (mp MockProvider) GetSignInURL(s string) string { return mp.GetSignInURLResponse }

View file

@ -176,9 +176,6 @@ footer {
</head>
<body>
<div class="container">
{{ if .Message }}
<div class="message">{{.Message}}</div>
{{ end}}
<div class="content">
<header>
<h1>Sign out of <b>{{.Destination}}</b></h1>

View file

@ -56,3 +56,24 @@ func TestMockAuthenticate(t *testing.T) {
}
}
func TestNew(t *testing.T) {
tests := []struct {
name string
serviceName string
opts *Options
wantErr bool
}{
{"grpc good", "grpc", &Options{Addr: "test", InternalAddr: "intranet.local", SharedSecret: "secret"}, false},
{"grpc missing shared secret", "grpc", &Options{Addr: "test", InternalAddr: "intranet.local", SharedSecret: ""}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := New(tt.serviceName, tt.opts)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}

View file

@ -192,8 +192,8 @@ func TestNewGRPC(t *testing.T) {
{"no shared secret", &Options{}, true, "proxy/authenticator: grpc client requires shared secret"},
{"empty connection", &Options{Addr: "", SharedSecret: "shh"}, true, "proxy/authenticator: connection address required"},
{"empty connections", &Options{Addr: "", InternalAddr: "", SharedSecret: "shh"}, true, "proxy/authenticator: connection address required"},
{"internal addr", &Options{Addr: "", InternalAddr: "intranet.local", SharedSecret: "shh"}, false, "proxy/authenticator: connection address required"},
{"cert overide", &Options{Addr: "", InternalAddr: "intranet.local", OverideCertificateName: "*.local", SharedSecret: "shh"}, false, "proxy/authenticator: connection address required"},
{"internal addr", &Options{Addr: "", InternalAddr: "intranet.local", SharedSecret: "shh"}, false, ""},
{"cert overide", &Options{Addr: "", InternalAddr: "intranet.local", OverideCertificateName: "*.local", SharedSecret: "shh"}, false, ""},
// {"addr and internal ", &Options{Addr: "localhost", InternalAddr: "local.localhost", SharedSecret: "shh"}, nil, true, ""},
}

View file

@ -121,7 +121,7 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
// this value will be unique since we always use a randomized nonce as part of marshaling
encryptedCSRF, err := p.cipher.Marshal(state)
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("failed to marshal csrf")
log.FromRequest(r).Error().Err(err).Msg("proxy: failed to marshal csrf")
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
return
}
@ -131,7 +131,7 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
// this value will be unique since we always use a randomized nonce as part of marshaling
encryptedState, err := p.cipher.Marshal(state)
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("failed to encrypt cookie")
log.FromRequest(r).Error().Err(err).Msg("proxy: failed to encrypt cookie")
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
return
}
@ -149,7 +149,7 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("failed parsing request form")
log.FromRequest(r).Error().Err(err).Msg("proxy: failed parsing request form")
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
return
}
@ -161,27 +161,23 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
// We begin the process of redeeming the code for an access token.
rr, err := p.AuthenticateClient.Redeem(r.Form.Get("code"))
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("error redeeming authorization code")
log.FromRequest(r).Error().Err(err).Msg("proxy: error redeeming authorization code")
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
return
}
encryptedState := r.Form.Get("state")
log.Warn().
Str("encryptedState", encryptedState).
Msg("OK")
stateParameter := &StateParameter{}
err = p.cipher.Unmarshal(encryptedState, stateParameter)
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("could not unmarshal state")
log.FromRequest(r).Error().Err(err).Msg("proxy: could not unmarshal state")
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
return
}
c, err := p.csrfStore.GetCSRF(r)
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("failed parsing csrf cookie")
log.FromRequest(r).Error().Err(err).Msg("proxy: failed parsing csrf cookie")
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
return
}
@ -191,7 +187,7 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
csrfParameter := &StateParameter{}
err = p.cipher.Unmarshal(encryptedCSRF, csrfParameter)
if err != nil {
log.FromRequest(r).Error().Err(err).Msg("couldn't unmarshal CSRF")
log.FromRequest(r).Error().Err(err).Msg("proxy: couldn't unmarshal CSRF")
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
return
}
@ -283,7 +279,7 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
}
if session.LifetimePeriodExpired() {
log.FromRequest(r).Info().Msg("proxy.Authenticate: lifetime expired, restarting")
log.FromRequest(r).Info().Msg("proxy: lifetime expired")
return sessions.ErrLifetimeExpired
}
if session.RefreshPeriodExpired() {
@ -295,12 +291,12 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
log.FromRequest(r).Warn().
Str("RefreshToken", session.RefreshToken).
Str("AccessToken", session.AccessToken).
Msg("proxy.Authenticate: refresh failure")
Msg("proxy: refresh failed")
return err
}
session.AccessToken = accessToken
session.RefreshDeadline = expiry
log.FromRequest(r).Info().Msg("proxy.Authenticate: refresh success")
log.FromRequest(r).Info().Msg("proxy: refresh success")
}
err = p.sessionStore.SaveSession(w, r, session)

View file

@ -359,12 +359,6 @@ func TestProxy_Proxy(t *testing.T) {
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(-10 * time.Second),
}
// expiredDeadline := &sessions.SessionState{
// AccessToken: "AccessToken",
// RefreshToken: "RefreshToken",
// LifetimeDeadline: time.Now().Add(10 * time.Second),
// RefreshDeadline: time.Now().Add(-10 * time.Second),
// }
tests := []struct {
name string