mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-31 18:07:17 +02:00
authenticate: add tests, fix signout (#45)
- authenticate: a bug where sign out failed to revoke the remote session - docs: add code coverage to readme - authenticate: Rename shorthand receiver variable name - authenticate: consolidate sign in
This commit is contained in:
parent
35ee3247d7
commit
805f0198d2
9 changed files with 1061 additions and 163 deletions
|
@ -4,7 +4,7 @@
|
||||||
|
|
||||||
# Pomerium
|
# Pomerium
|
||||||
|
|
||||||
[](https://travis-ci.org/pomerium/pomerium) [](https://goreportcard.com/report/github.com/pomerium/pomerium) [][godocs] [](https://github.com/pomerium/pomerium/blob/master/LICENSE)
|
[](https://travis-ci.org/pomerium/pomerium) [](https://goreportcard.com/report/github.com/pomerium/pomerium) [][godocs] [](https://github.com/pomerium/pomerium/blob/master/LICENSE)[](https://codecov.io/gh/pomerium/pomerium)
|
||||||
|
|
||||||
Pomerium is a tool for managing secure access to internal applications and resources.
|
Pomerium is a tool for managing secure access to internal applications and resources.
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ var securityHeaders = map[string]string{
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handler returns the Http.Handlers for authenticate, callback, and refresh
|
// Handler returns the Http.Handlers for authenticate, callback, and refresh
|
||||||
func (p *Authenticate) Handler() http.Handler {
|
func (a *Authenticate) Handler() http.Handler {
|
||||||
// set up our standard middlewares
|
// set up our standard middlewares
|
||||||
stdMiddleware := middleware.NewChain()
|
stdMiddleware := middleware.NewChain()
|
||||||
stdMiddleware = stdMiddleware.Append(middleware.Healthcheck("/ping", version.UserAgent()))
|
stdMiddleware = stdMiddleware.Append(middleware.Healthcheck("/ping", version.UserAgent()))
|
||||||
|
@ -51,100 +51,101 @@ func (p *Authenticate) Handler() http.Handler {
|
||||||
stdMiddleware = stdMiddleware.Append(middleware.RefererHandler("referer"))
|
stdMiddleware = stdMiddleware.Append(middleware.RefererHandler("referer"))
|
||||||
stdMiddleware = stdMiddleware.Append(middleware.RequestIDHandler("req_id", "Request-Id"))
|
stdMiddleware = stdMiddleware.Append(middleware.RequestIDHandler("req_id", "Request-Id"))
|
||||||
validateSignatureMiddleware := stdMiddleware.Append(
|
validateSignatureMiddleware := stdMiddleware.Append(
|
||||||
middleware.ValidateSignature(p.SharedKey),
|
middleware.ValidateSignature(a.SharedKey),
|
||||||
middleware.ValidateRedirectURI(p.ProxyRootDomains))
|
middleware.ValidateRedirectURI(a.ProxyRootDomains))
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.Handle("/robots.txt", stdMiddleware.ThenFunc(p.RobotsTxt))
|
mux.Handle("/robots.txt", stdMiddleware.ThenFunc(a.RobotsTxt))
|
||||||
// Identity Provider (IdP) callback endpoints and callbacks
|
// Identity Provider (IdP) callback endpoints and callbacks
|
||||||
mux.Handle("/start", stdMiddleware.ThenFunc(p.OAuthStart))
|
mux.Handle("/start", stdMiddleware.ThenFunc(a.OAuthStart))
|
||||||
mux.Handle("/oauth2/callback", stdMiddleware.ThenFunc(p.OAuthCallback))
|
mux.Handle("/oauth2/callback", stdMiddleware.ThenFunc(a.OAuthCallback))
|
||||||
// authenticate-server endpoints
|
// authenticate-server endpoints
|
||||||
mux.Handle("/sign_in", validateSignatureMiddleware.ThenFunc(p.SignIn))
|
mux.Handle("/sign_in", validateSignatureMiddleware.ThenFunc(a.SignIn))
|
||||||
mux.Handle("/sign_out", validateSignatureMiddleware.ThenFunc(p.SignOut)) // GET POST
|
mux.Handle("/sign_out", validateSignatureMiddleware.ThenFunc(a.SignOut)) // GET POST
|
||||||
|
|
||||||
return mux
|
return mux
|
||||||
}
|
}
|
||||||
|
|
||||||
// RobotsTxt handles the /robots.txt route.
|
// RobotsTxt handles the /robots.txt route.
|
||||||
func (p *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
|
func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*sessions.SessionState, error) {
|
func (a *Authenticate) authenticate(w http.ResponseWriter, r *http.Request) (*sessions.SessionState, error) {
|
||||||
session, err := p.sessionStore.LoadSession(r)
|
session, err := a.sessionStore.LoadSession(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to load session")
|
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to load session")
|
||||||
p.sessionStore.ClearSession(w, r)
|
a.sessionStore.ClearSession(w, r)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// if long-lived lifetime has expired, clear session
|
// if long-lived lifetime has expired, clear session
|
||||||
if session.LifetimePeriodExpired() {
|
if session.LifetimePeriodExpired() {
|
||||||
log.FromRequest(r).Warn().Msg("authenticate: lifetime expired")
|
log.FromRequest(r).Warn().Msg("authenticate: lifetime expired")
|
||||||
p.sessionStore.ClearSession(w, r)
|
a.sessionStore.ClearSession(w, r)
|
||||||
return nil, sessions.ErrLifetimeExpired
|
return nil, sessions.ErrLifetimeExpired
|
||||||
}
|
}
|
||||||
// check if session refresh period is up
|
// check if session refresh period is up
|
||||||
if session.RefreshPeriodExpired() {
|
if session.RefreshPeriodExpired() {
|
||||||
newToken, err := p.provider.Refresh(session.RefreshToken)
|
newToken, err := a.provider.Refresh(session.RefreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to refresh session")
|
log.FromRequest(r).Error().Err(err).Msg("authenticate: failed to refresh session")
|
||||||
p.sessionStore.ClearSession(w, r)
|
a.sessionStore.ClearSession(w, r)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
session.AccessToken = newToken.AccessToken
|
session.AccessToken = newToken.AccessToken
|
||||||
session.RefreshDeadline = newToken.Expiry
|
session.RefreshDeadline = newToken.Expiry
|
||||||
err = p.sessionStore.SaveSession(w, r, session)
|
err = a.sessionStore.SaveSession(w, r, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// We refreshed the session successfully, but failed to save it.
|
// We refreshed the session successfully, but failed to save it.
|
||||||
// This could be from failing to encode the session properly.
|
// This could be from failing to encode the session properly.
|
||||||
// But, we clear the session cookie and reject the request
|
// But, we clear the session cookie and reject the request
|
||||||
log.FromRequest(r).Error().Err(err).Msg("could not save refreshed session")
|
log.FromRequest(r).Error().Err(err).Msg("could not save refreshed session")
|
||||||
p.sessionStore.ClearSession(w, r)
|
a.sessionStore.ClearSession(w, r)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// The session has not exceeded it's lifetime or requires refresh
|
// The session has not exceeded it's lifetime or requires refresh
|
||||||
ok, err := p.provider.Validate(session.IDToken)
|
ok, err := a.provider.Validate(session.IDToken)
|
||||||
if !ok || err != nil {
|
if !ok || err != nil {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("invalid session state")
|
log.FromRequest(r).Error().Err(err).Msg("invalid session state")
|
||||||
p.sessionStore.ClearSession(w, r)
|
a.sessionStore.ClearSession(w, r)
|
||||||
return nil, httputil.ErrUserNotAuthorized
|
return nil, httputil.ErrUserNotAuthorized
|
||||||
}
|
}
|
||||||
err = p.sessionStore.SaveSession(w, r, session)
|
err = a.sessionStore.SaveSession(w, r, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("failed to save valid session")
|
log.FromRequest(r).Error().Err(err).Msg("failed to save valid session")
|
||||||
p.sessionStore.ClearSession(w, r)
|
a.sessionStore.ClearSession(w, r)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// authenticate really should not be in the business of authorization
|
// authenticate really should not be in the business of authorization
|
||||||
// todo(bdd) : remove when authorization module added
|
// todo(bdd) : remove when authorization module added
|
||||||
if !p.Validator(session.Email) {
|
if !a.Validator(session.Email) {
|
||||||
log.FromRequest(r).Error().Msg("invalid email user")
|
log.FromRequest(r).Error().Msg("invalid email user")
|
||||||
return nil, httputil.ErrUserNotAuthorized
|
return nil, httputil.ErrUserNotAuthorized
|
||||||
}
|
}
|
||||||
log.Info().Msg("authenticate")
|
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignIn handles the /sign_in endpoint. It attempts to authenticate the user,
|
// SignIn handles the /sign_in endpoint. It attempts to authenticate the user,
|
||||||
// and if the user is not authenticated, it renders a sign in page.
|
// and if the user is not authenticated, it renders a sign in page.
|
||||||
func (p *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
||||||
session, err := p.authenticate(w, r)
|
session, err := a.authenticate(w, r)
|
||||||
switch err {
|
switch err {
|
||||||
case nil:
|
case nil:
|
||||||
// User is authenticated, redirect back to proxy
|
// session good, redirect back to proxy
|
||||||
p.ProxyOAuthRedirect(w, r, session)
|
log.FromRequest(r).Info().Msg("authenticate.SignIn : authenticated")
|
||||||
|
a.ProxyCallback(w, r, session)
|
||||||
case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
|
case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
|
||||||
log.Info().Err(err).Msg("authenticate.SignIn : expected failure")
|
// session invalid, authenticate
|
||||||
|
log.FromRequest(r).Info().Err(err).Msg("authenticate.SignIn : expected failure")
|
||||||
if err != http.ErrNoCookie {
|
if err != http.ErrNoCookie {
|
||||||
p.sessionStore.ClearSession(w, r)
|
a.sessionStore.ClearSession(w, r)
|
||||||
}
|
}
|
||||||
p.OAuthStart(w, r)
|
a.OAuthStart(w, r)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
log.Error().Err(err).Msg("authenticate: unexpected sign in error")
|
log.Error().Err(err).Msg("authenticate: unexpected sign in error")
|
||||||
|
@ -152,10 +153,10 @@ func (p *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProxyOAuthRedirect redirects the user back to proxy's redirection endpoint.
|
// ProxyCallback redirects the user back to proxy service along with an encrypted payload, as
|
||||||
// This workflow corresponds to Section 3.1.2 of the OAuth2 RFC.
|
// url params, of the user's session state.
|
||||||
// See https://tools.ietf.org/html/rfc6749#section-3.1.2 for more specific information.
|
// See RFC6749 3.1.2 https://tools.ietf.org/html/rfc6749#section-3.1.2
|
||||||
func (p *Authenticate) ProxyOAuthRedirect(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) {
|
func (a *Authenticate) ProxyCallback(w http.ResponseWriter, r *http.Request, session *sessions.SessionState) {
|
||||||
err := r.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
||||||
|
@ -180,7 +181,7 @@ func (p *Authenticate) ProxyOAuthRedirect(w http.ResponseWriter, r *http.Request
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// encrypt session state as json blob
|
// encrypt session state as json blob
|
||||||
encrypted, err := sessions.MarshalSession(session, p.cipher)
|
encrypted, err := sessions.MarshalSession(session, a.cipher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
@ -193,111 +194,89 @@ func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string
|
||||||
params, _ := url.ParseQuery(u.RawQuery)
|
params, _ := url.ParseQuery(u.RawQuery)
|
||||||
params.Set("code", authCode)
|
params.Set("code", authCode)
|
||||||
params.Set("state", state)
|
params.Set("state", state)
|
||||||
|
|
||||||
u.RawQuery = params.Encode()
|
u.RawQuery = params.Encode()
|
||||||
|
|
||||||
if u.Scheme == "" {
|
if u.Scheme == "" {
|
||||||
u.Scheme = "https"
|
u.Scheme = "https"
|
||||||
}
|
}
|
||||||
|
|
||||||
return u.String()
|
return u.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignOut signs the user out.
|
// SignOut signs the user out by trying to revoke the users remote identity provider session
|
||||||
func (p *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
// then removes the associated local session state.
|
||||||
|
// Handles both GET and POST of form.
|
||||||
|
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||||
|
err := r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// pretty safe to say that no matter what heppanes here, we want to revoke the local session
|
||||||
redirectURI := r.Form.Get("redirect_uri")
|
redirectURI := r.Form.Get("redirect_uri")
|
||||||
|
session, err := a.sessionStore.LoadSession(r)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("authenticate: signout failed to load session")
|
||||||
|
httputil.ErrorResponse(w, r, "No session found to log out", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
if r.Method == "GET" {
|
if r.Method == "GET" {
|
||||||
p.SignOutPage(w, r, "")
|
signature := r.Form.Get("sig")
|
||||||
|
timestamp := r.Form.Get("ts")
|
||||||
|
destinationURL, err := url.Parse(redirectURI)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("authenticate: malformed destination url")
|
||||||
|
httputil.ErrorResponse(w, r, "Malformed destination URL", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t := struct {
|
||||||
|
Redirect string
|
||||||
|
Signature string
|
||||||
|
Timestamp string
|
||||||
|
Destination string
|
||||||
|
Email string
|
||||||
|
Version string
|
||||||
|
}{
|
||||||
|
Redirect: redirectURI,
|
||||||
|
Signature: signature,
|
||||||
|
Timestamp: timestamp,
|
||||||
|
Destination: destinationURL.Host,
|
||||||
|
Email: session.Email,
|
||||||
|
Version: version.FullVersion(),
|
||||||
|
}
|
||||||
|
a.templates.ExecuteTemplate(w, "sign_out.html", t)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
a.sessionStore.ClearSession(w, r)
|
||||||
session, err := p.sessionStore.LoadSession(r)
|
err = a.provider.Revoke(session.AccessToken)
|
||||||
switch err {
|
|
||||||
case nil:
|
|
||||||
break
|
|
||||||
case http.ErrNoCookie: // if there's no cookie in the session we can just redirect
|
|
||||||
log.Error().Err(err).Msg("authenticate.SignOut : no cookie")
|
|
||||||
http.Redirect(w, r, redirectURI, http.StatusFound)
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
// a different error, clear the session cookie and redirect
|
|
||||||
log.Error().Err(err).Msg("authenticate.SignOut : error loading cookie session")
|
|
||||||
p.sessionStore.ClearSession(w, r)
|
|
||||||
http.Redirect(w, r, redirectURI, http.StatusFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = p.provider.Revoke(session.AccessToken)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("authenticate.SignOut : error revoking session")
|
log.Error().Err(err).Msg("authenticate: failed to revoke user session")
|
||||||
p.SignOutPage(w, r, "An error occurred during sign out. Please try again.")
|
httputil.ErrorResponse(w, r, fmt.Sprintf("could not revoke session: %s ", err.Error()), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.sessionStore.ClearSession(w, r)
|
|
||||||
http.Redirect(w, r, redirectURI, http.StatusFound)
|
http.Redirect(w, r, redirectURI, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignOutPage renders a sign out page with a message
|
|
||||||
func (p *Authenticate) SignOutPage(w http.ResponseWriter, r *http.Request, message string) {
|
|
||||||
// validateRedirectURI middleware already ensures that this is a valid URL
|
|
||||||
redirectURI := r.Form.Get("redirect_uri")
|
|
||||||
session, err := p.sessionStore.LoadSession(r)
|
|
||||||
if err != nil {
|
|
||||||
http.Redirect(w, r, redirectURI, http.StatusFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
signature := r.Form.Get("sig")
|
|
||||||
timestamp := r.Form.Get("ts")
|
|
||||||
destinationURL, err := url.Parse(redirectURI)
|
|
||||||
|
|
||||||
// An error message indicates that an internal server error occurred
|
|
||||||
if message != "" || err != nil {
|
|
||||||
log.Error().Err(err).Msg("authenticate.SignOutPage")
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
t := struct {
|
|
||||||
Redirect string
|
|
||||||
Signature string
|
|
||||||
Timestamp string
|
|
||||||
Message string
|
|
||||||
Destination string
|
|
||||||
Email string
|
|
||||||
Version string
|
|
||||||
}{
|
|
||||||
Redirect: redirectURI,
|
|
||||||
Signature: signature,
|
|
||||||
Timestamp: timestamp,
|
|
||||||
Message: message,
|
|
||||||
Destination: destinationURL.Host,
|
|
||||||
Email: session.Email,
|
|
||||||
Version: version.FullVersion(),
|
|
||||||
}
|
|
||||||
p.templates.ExecuteTemplate(w, "sign_out.html", t)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OAuthStart starts the authenticate process by redirecting to the provider. It provides a
|
// OAuthStart starts the authenticate process by redirecting to the provider. It provides a
|
||||||
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authenticate.
|
// `redirectURI`, allowing the provider to redirect back to the sso proxy after authenticate.
|
||||||
func (p *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
||||||
authRedirectURL, err := url.Parse(r.URL.Query().Get("redirect_uri"))
|
authRedirectURL, err := url.Parse(r.URL.Query().Get("redirect_uri"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
authRedirectURL = p.RedirectURL.ResolveReference(r.URL)
|
authRedirectURL = a.RedirectURL.ResolveReference(r.URL)
|
||||||
|
|
||||||
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
|
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
|
||||||
p.csrfStore.SetCSRF(w, r, nonce)
|
a.csrfStore.SetCSRF(w, r, nonce)
|
||||||
|
|
||||||
// verify redirect uri is from the root domain
|
// verify redirect uri is from the root domain
|
||||||
if !middleware.ValidRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) {
|
if !middleware.ValidRedirectURI(authRedirectURL.String(), a.ProxyRootDomains) {
|
||||||
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// verify proxy url is from the root domain
|
// verify proxy url is from the root domain
|
||||||
proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri"))
|
proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri"))
|
||||||
if err != nil || !middleware.ValidRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) {
|
if err != nil || !middleware.ValidRedirectURI(proxyRedirectURL.String(), a.ProxyRootDomains) {
|
||||||
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -305,7 +284,7 @@ func (p *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
||||||
// get the signature and timestamp values then compare hmac
|
// get the signature and timestamp values then compare hmac
|
||||||
proxyRedirectSig := authRedirectURL.Query().Get("sig")
|
proxyRedirectSig := authRedirectURL.Query().Get("sig")
|
||||||
ts := authRedirectURL.Query().Get("ts")
|
ts := authRedirectURL.Query().Get("ts")
|
||||||
if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, p.SharedKey) {
|
if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, a.SharedKey) {
|
||||||
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
httputil.ErrorResponse(w, r, "Invalid redirect parameter", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -313,16 +292,35 @@ func (p *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
||||||
// concat base64'd nonce and authenticate url to make state
|
// concat base64'd nonce and authenticate url to make state
|
||||||
state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String())))
|
state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String())))
|
||||||
// build the provider sign in url
|
// build the provider sign in url
|
||||||
signInURL := p.provider.GetSignInURL(state)
|
signInURL := a.provider.GetSignInURL(state)
|
||||||
|
|
||||||
http.Redirect(w, r, signInURL, http.StatusFound)
|
http.Redirect(w, r, signInURL, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OAuthCallback handles the callback from the identity provider. Displays an error page if there
|
||||||
|
// was an error. If successful, redirects back to the proxy-service via the redirect-url.
|
||||||
|
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
redirect, err := a.getOAuthCallback(w, r)
|
||||||
|
switch h := err.(type) {
|
||||||
|
case nil:
|
||||||
|
break
|
||||||
|
case httputil.HTTPError:
|
||||||
|
log.Error().Err(err).Msg("authenticate: oauth callback error")
|
||||||
|
httputil.ErrorResponse(w, r, h.Message, h.Code)
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
log.Error().Err(err).Msg("authenticate: unexpected oauth callback error")
|
||||||
|
httputil.ErrorResponse(w, r, "Internal Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// redirect back to the proxy-service
|
||||||
|
http.Redirect(w, r, redirect, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
// getOAuthCallback completes the oauth cycle from an identity provider's callback
|
// getOAuthCallback completes the oauth cycle from an identity provider's callback
|
||||||
func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) {
|
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||||
err := r.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: bad form on oauth callback")
|
|
||||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
|
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
|
||||||
}
|
}
|
||||||
errorString := r.Form.Get("error")
|
errorString := r.Form.Get("error")
|
||||||
|
@ -336,7 +334,7 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
||||||
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Missing Code"}
|
return "", httputil.HTTPError{Code: http.StatusBadRequest, Message: "Missing Code"}
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := p.provider.Authenticate(code)
|
session, err := a.provider.Authenticate(code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: error redeeming authenticate code")
|
log.FromRequest(r).Error().Err(err).Msg("authenticate: error redeeming authenticate code")
|
||||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
|
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: err.Error()}
|
||||||
|
@ -353,50 +351,30 @@ func (p *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
||||||
}
|
}
|
||||||
nonce := s[0]
|
nonce := s[0]
|
||||||
redirect := s[1]
|
redirect := s[1]
|
||||||
c, err := p.csrfStore.GetCSRF(r)
|
c, err := a.csrfStore.GetCSRF(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: bad csrf")
|
log.FromRequest(r).Error().Err(err).Msg("authenticate: bad csrf")
|
||||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Missing CSRF token"}
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Missing CSRF token"}
|
||||||
}
|
}
|
||||||
p.csrfStore.ClearCSRF(w, r)
|
a.csrfStore.ClearCSRF(w, r)
|
||||||
if c.Value != nonce {
|
if c.Value != nonce {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("authenticate: csrf mismatch")
|
log.FromRequest(r).Error().Err(err).Msg("authenticate: csrf mismatch")
|
||||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "CSRF failed"}
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "CSRF failed"}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !middleware.ValidRedirectURI(redirect, p.ProxyRootDomains) {
|
if !middleware.ValidRedirectURI(redirect, a.ProxyRootDomains) {
|
||||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Redirect URI"}
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "Invalid Redirect URI"}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set cookie, or deny: validates the session email and group
|
// Set cookie, or deny: validates the session email and group
|
||||||
if !p.Validator(session.Email) {
|
if !a.Validator(session.Email) {
|
||||||
log.FromRequest(r).Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied")
|
log.FromRequest(r).Error().Err(err).Str("email", session.Email).Msg("invalid email permissions denied")
|
||||||
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "You don't have access"}
|
return "", httputil.HTTPError{Code: http.StatusForbidden, Message: "You don't have access"}
|
||||||
}
|
}
|
||||||
err = p.sessionStore.SaveSession(w, r, session)
|
err = a.sessionStore.SaveSession(w, r, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("internal error")
|
log.Error().Err(err).Msg("internal error")
|
||||||
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Internal Error"}
|
return "", httputil.HTTPError{Code: http.StatusInternalServerError, Message: "Internal Error"}
|
||||||
}
|
}
|
||||||
return redirect, nil
|
return redirect, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuthCallback handles the callback from the identity provider. Displays an error page if there
|
|
||||||
// was an error. If successful, redirects back to the proxy-service via the redirect-url.
|
|
||||||
func (p *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
|
||||||
redirect, err := p.getOAuthCallback(w, r)
|
|
||||||
switch h := err.(type) {
|
|
||||||
case nil:
|
|
||||||
break
|
|
||||||
case httputil.HTTPError:
|
|
||||||
log.Error().Err(err).Msg("authenticate: oauth callback error")
|
|
||||||
httputil.ErrorResponse(w, r, h.Message, h.Code)
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
log.Error().Err(err).Msg("authenticate: unexpected oauth callback error")
|
|
||||||
httputil.ErrorResponse(w, r, "Internal Error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// redirect back to the proxy-service
|
|
||||||
http.Redirect(w, r, redirect, http.StatusFound)
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,15 +1,27 @@
|
||||||
package authenticate
|
package authenticate
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/authenticate/providers"
|
||||||
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/templates"
|
"github.com/pomerium/pomerium/internal/templates"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// mocks for validator func
|
||||||
|
func trueValidator(s string) bool { return true }
|
||||||
|
func falseValidator(s string) bool { return false }
|
||||||
|
|
||||||
func testAuthenticate() *Authenticate {
|
func testAuthenticate() *Authenticate {
|
||||||
var auth Authenticate
|
var auth Authenticate
|
||||||
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
|
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
|
||||||
|
@ -37,3 +49,862 @@ func TestAuthenticate_RobotsTxt(t *testing.T) {
|
||||||
t.Errorf("handler returned wrong body: got %v want %v", rr.Body.String(), expected)
|
t.Errorf("handler returned wrong body: got %v want %v", rr.Body.String(), expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthenticate_Handler(t *testing.T) {
|
||||||
|
auth := testAuthenticate()
|
||||||
|
|
||||||
|
h := auth.Handler()
|
||||||
|
if h == nil {
|
||||||
|
t.Error("handler cannot be nil")
|
||||||
|
}
|
||||||
|
req := httptest.NewRequest("GET", "/robots.txt", nil)
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rr, req)
|
||||||
|
expected := fmt.Sprintf("User-agent: *\nDisallow: /")
|
||||||
|
|
||||||
|
body := rr.Body.String()
|
||||||
|
if body != expected {
|
||||||
|
t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticate_authenticate(t *testing.T) {
|
||||||
|
// sessions.MockSessionStore{Session: expiredLifetime}
|
||||||
|
goodSession := sessions.MockSessionStore{
|
||||||
|
Session: &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}}
|
||||||
|
expiredSession := sessions.MockSessionStore{
|
||||||
|
Session: &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * -time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}}
|
||||||
|
expiredRefresPeriod := sessions.MockSessionStore{
|
||||||
|
Session: &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * -time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
session sessions.SessionStore
|
||||||
|
provider providers.MockProvider
|
||||||
|
validator func(string) bool
|
||||||
|
want *sessions.SessionState
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"good", goodSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, false},
|
||||||
|
{"good but fails validation", goodSession, providers.MockProvider{ValidateResponse: true}, falseValidator, nil, true},
|
||||||
|
{"can't load session", sessions.MockSessionStore{LoadError: errors.New("error")}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
|
||||||
|
{"validation fails", goodSession, providers.MockProvider{ValidateResponse: false}, trueValidator, nil, true},
|
||||||
|
{"session fails after good validation", sessions.MockSessionStore{
|
||||||
|
SaveError: errors.New("error"),
|
||||||
|
Session: &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}}, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
|
||||||
|
{"lifetime expired", expiredSession, providers.MockProvider{ValidateResponse: true}, trueValidator, nil, true},
|
||||||
|
{"refresh expired",
|
||||||
|
expiredRefresPeriod,
|
||||||
|
providers.MockProvider{
|
||||||
|
ValidateResponse: true,
|
||||||
|
RefreshResponse: &oauth2.Token{
|
||||||
|
AccessToken: "new token",
|
||||||
|
Expiry: time.Now(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
trueValidator, nil, false},
|
||||||
|
{"refresh expired refresh error",
|
||||||
|
expiredRefresPeriod,
|
||||||
|
providers.MockProvider{
|
||||||
|
ValidateResponse: true,
|
||||||
|
RefreshError: errors.New("error"),
|
||||||
|
},
|
||||||
|
trueValidator, nil, true},
|
||||||
|
{"refresh expired failed save",
|
||||||
|
sessions.MockSessionStore{
|
||||||
|
SaveError: errors.New("error"),
|
||||||
|
Session: &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * -time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}},
|
||||||
|
providers.MockProvider{
|
||||||
|
ValidateResponse: true,
|
||||||
|
RefreshResponse: &oauth2.Token{
|
||||||
|
AccessToken: "new token",
|
||||||
|
Expiry: time.Now(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
trueValidator, nil, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := &Authenticate{
|
||||||
|
sessionStore: tt.session,
|
||||||
|
provider: tt.provider,
|
||||||
|
Validator: tt.validator,
|
||||||
|
}
|
||||||
|
r := httptest.NewRequest("GET", "/auth", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
_, err := p.authenticate(w, r)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Authenticate.authenticate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
session sessions.SessionStore
|
||||||
|
provider providers.MockProvider
|
||||||
|
validator func(string) bool
|
||||||
|
wantCode int
|
||||||
|
}{
|
||||||
|
{"good",
|
||||||
|
sessions.MockSessionStore{
|
||||||
|
Session: &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}},
|
||||||
|
providers.MockProvider{ValidateResponse: true},
|
||||||
|
trueValidator,
|
||||||
|
403},
|
||||||
|
// {"no session",
|
||||||
|
// sessions.MockSessionStore{
|
||||||
|
// Session: &sessions.SessionState{
|
||||||
|
// AccessToken: "AccessToken",
|
||||||
|
// RefreshToken: "RefreshToken",
|
||||||
|
// LifetimeDeadline: time.Now().Add(-10 * time.Second),
|
||||||
|
// RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
// ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
// }},
|
||||||
|
// providers.MockProvider{ValidateResponse: true},
|
||||||
|
// trueValidator,
|
||||||
|
// 200},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := &Authenticate{
|
||||||
|
sessionStore: tt.session,
|
||||||
|
provider: tt.provider,
|
||||||
|
Validator: tt.validator,
|
||||||
|
}
|
||||||
|
r := httptest.NewRequest("GET", "/sign-in", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
a.SignIn(w, r)
|
||||||
|
if status := w.Code; status != tt.wantCode {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockCipher struct{}
|
||||||
|
|
||||||
|
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
|
||||||
|
if string(s) == "error" {
|
||||||
|
return []byte(""), errors.New("error encrypting")
|
||||||
|
}
|
||||||
|
return []byte("OK"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
|
||||||
|
if string(s) == "error" {
|
||||||
|
return []byte(""), errors.New("error encrypting")
|
||||||
|
}
|
||||||
|
return []byte("OK"), nil
|
||||||
|
}
|
||||||
|
func (a mockCipher) Marshal(s interface{}) (string, error) { return "ok", nil }
|
||||||
|
func (a mockCipher) Unmarshal(s string, i interface{}) error {
|
||||||
|
if string(s) == "unmarshal error" || string(s) == "error" {
|
||||||
|
return errors.New("error")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func TestAuthenticate_ProxyCallback(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
|
||||||
|
uri string
|
||||||
|
state string
|
||||||
|
authCode string
|
||||||
|
|
||||||
|
sessionState *sessions.SessionState
|
||||||
|
sessionStore sessions.SessionStore
|
||||||
|
wantCode int
|
||||||
|
wantBody string
|
||||||
|
}{
|
||||||
|
{"good", "https://corp.pomerium.io/", "state", "code",
|
||||||
|
&sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
},
|
||||||
|
sessions.MockSessionStore{},
|
||||||
|
302,
|
||||||
|
"<a href=\"https://corp.pomerium.io/?code=ok&state=state\">Found</a>."},
|
||||||
|
{"no state",
|
||||||
|
"https://corp.pomerium.io/",
|
||||||
|
"",
|
||||||
|
"code",
|
||||||
|
&sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
},
|
||||||
|
sessions.MockSessionStore{},
|
||||||
|
403,
|
||||||
|
"no state parameter supplied"},
|
||||||
|
{"no redirect_url",
|
||||||
|
"",
|
||||||
|
"state",
|
||||||
|
"code",
|
||||||
|
&sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
},
|
||||||
|
sessions.MockSessionStore{},
|
||||||
|
403,
|
||||||
|
"no redirect_uri parameter"},
|
||||||
|
{"malformed redirect_url",
|
||||||
|
"https://pomerium.com%zzzzz",
|
||||||
|
"state",
|
||||||
|
"code",
|
||||||
|
&sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
},
|
||||||
|
sessions.MockSessionStore{},
|
||||||
|
400,
|
||||||
|
"malformed redirect_uri"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := &Authenticate{
|
||||||
|
sessionStore: tt.sessionStore,
|
||||||
|
cipher: mockCipher{},
|
||||||
|
}
|
||||||
|
u, _ := url.Parse("https://pomerium.io/redirect")
|
||||||
|
params, _ := url.ParseQuery(u.RawQuery)
|
||||||
|
params.Set("code", tt.authCode)
|
||||||
|
params.Set("state", tt.state)
|
||||||
|
params.Set("redirect_uri", tt.uri)
|
||||||
|
|
||||||
|
u.RawQuery = params.Encode()
|
||||||
|
|
||||||
|
r := httptest.NewRequest("GET", u.String(), nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
a.ProxyCallback(w, r, tt.sessionState)
|
||||||
|
if status := w.Code; status != tt.wantCode {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
|
||||||
|
}
|
||||||
|
if body := w.Body.String(); !strings.Contains(body, tt.wantBody) {
|
||||||
|
t.Errorf("handler returned wrong body Body: got \n%s \n%s", body, tt.wantBody)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_getAuthCodeRedirectURL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
redirectURL *url.URL
|
||||||
|
state string
|
||||||
|
authCode string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"https", uriParse("https://www.pomerium.io"), "state", "auth-code", "https://www.pomerium.io?code=auth-code&state=state"},
|
||||||
|
{"http", uriParse("http://www.pomerium.io"), "state", "auth-code", "http://www.pomerium.io?code=auth-code&state=state"},
|
||||||
|
{"no subdomain", uriParse("http://pomerium.io"), "state", "auth-code", "http://pomerium.io?code=auth-code&state=state"},
|
||||||
|
{"no scheme make https", uriParse("pomerium.io"), "state", "auth-code", "https://pomerium.io?code=auth-code&state=state"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := getAuthCodeRedirectURL(tt.redirectURL, tt.state, tt.authCode); got != tt.want {
|
||||||
|
t.Errorf("getAuthCodeRedirectURL() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func uriParse(s string) *url.URL {
|
||||||
|
uri, _ := url.Parse(s)
|
||||||
|
return uri
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
|
||||||
|
redirectURL string
|
||||||
|
sig string
|
||||||
|
ts string
|
||||||
|
|
||||||
|
provider providers.Provider
|
||||||
|
sessionStore sessions.SessionStore
|
||||||
|
wantCode int
|
||||||
|
wantBody string
|
||||||
|
}{
|
||||||
|
{"good post",
|
||||||
|
http.MethodPost,
|
||||||
|
"https://corp.pomerium.io/",
|
||||||
|
"sig",
|
||||||
|
"ts",
|
||||||
|
providers.MockProvider{},
|
||||||
|
sessions.MockSessionStore{
|
||||||
|
Session: &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "blah@blah.com",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
http.StatusFound,
|
||||||
|
""},
|
||||||
|
{"failed revoke",
|
||||||
|
http.MethodPost,
|
||||||
|
"https://corp.pomerium.io/",
|
||||||
|
"sig",
|
||||||
|
"ts",
|
||||||
|
providers.MockProvider{RevokeError: errors.New("OH NO")},
|
||||||
|
sessions.MockSessionStore{
|
||||||
|
Session: &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "blah@blah.com",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"could not revoke"},
|
||||||
|
|
||||||
|
{"good get",
|
||||||
|
http.MethodGet,
|
||||||
|
"https://corp.pomerium.io/",
|
||||||
|
"sig",
|
||||||
|
"ts",
|
||||||
|
providers.MockProvider{},
|
||||||
|
sessions.MockSessionStore{
|
||||||
|
Session: &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "blah@blah.com",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
http.StatusOK,
|
||||||
|
"This will also sign you out of other internal apps."},
|
||||||
|
{"cannot load session",
|
||||||
|
http.MethodGet,
|
||||||
|
"https://corp.pomerium.io/",
|
||||||
|
"sig",
|
||||||
|
"ts",
|
||||||
|
providers.MockProvider{},
|
||||||
|
sessions.MockSessionStore{
|
||||||
|
LoadError: errors.New("uh oh"),
|
||||||
|
Session: &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "blah@blah.com",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"No session found to log out"},
|
||||||
|
{"bad redirect url get",
|
||||||
|
http.MethodGet,
|
||||||
|
"https://pomerium.com%zzzzz",
|
||||||
|
"sig",
|
||||||
|
"ts",
|
||||||
|
providers.MockProvider{},
|
||||||
|
sessions.MockSessionStore{
|
||||||
|
Session: &sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "blah@blah.com",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"Error"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := &Authenticate{
|
||||||
|
sessionStore: tt.sessionStore,
|
||||||
|
provider: tt.provider,
|
||||||
|
cipher: mockCipher{},
|
||||||
|
templates: templates.New(),
|
||||||
|
}
|
||||||
|
u, _ := url.Parse("/sign_out")
|
||||||
|
params, _ := url.ParseQuery(u.RawQuery)
|
||||||
|
params.Add("sig", tt.sig)
|
||||||
|
params.Add("ts", tt.ts)
|
||||||
|
params.Add("redirect_uri", tt.redirectURL)
|
||||||
|
u.RawQuery = params.Encode()
|
||||||
|
|
||||||
|
r := httptest.NewRequest(tt.method, u.String(), nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
a.SignOut(w, r)
|
||||||
|
if status := w.Code; status != tt.wantCode {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
|
||||||
|
}
|
||||||
|
if body := w.Body.String(); !strings.Contains(body, tt.wantBody) {
|
||||||
|
t.Errorf("handler returned wrong body Body: got \n%s \n%s", body, tt.wantBody)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) string {
|
||||||
|
data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix()))
|
||||||
|
h := cryptutil.Hash(secret, data)
|
||||||
|
return base64.URLEncoding.EncodeToString(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticate_OAuthStart(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
|
||||||
|
redirectURL string
|
||||||
|
sig string
|
||||||
|
ts string
|
||||||
|
allowedDomains []string
|
||||||
|
|
||||||
|
provider providers.Provider
|
||||||
|
csrfStore sessions.MockCSRFStore
|
||||||
|
// sessionStore sessions.SessionStore
|
||||||
|
wantCode int
|
||||||
|
}{
|
||||||
|
{"good",
|
||||||
|
http.MethodGet,
|
||||||
|
"https://corp.pomerium.io/",
|
||||||
|
redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"),
|
||||||
|
fmt.Sprint(time.Now().Unix()),
|
||||||
|
[]string{".pomerium.io"},
|
||||||
|
providers.MockProvider{},
|
||||||
|
sessions.MockCSRFStore{},
|
||||||
|
http.StatusFound,
|
||||||
|
},
|
||||||
|
{"bad timestamp",
|
||||||
|
http.MethodGet,
|
||||||
|
"https://corp.pomerium.io/",
|
||||||
|
redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"),
|
||||||
|
fmt.Sprint(time.Now().Add(10 * time.Hour).Unix()),
|
||||||
|
[]string{".pomerium.io"},
|
||||||
|
providers.MockProvider{},
|
||||||
|
sessions.MockCSRFStore{},
|
||||||
|
http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{"domain not in allowed domains",
|
||||||
|
http.MethodGet,
|
||||||
|
"https://corp.pomerium.io/",
|
||||||
|
redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"),
|
||||||
|
fmt.Sprint(time.Now().Unix()),
|
||||||
|
[]string{"not.pomerium.io"},
|
||||||
|
providers.MockProvider{},
|
||||||
|
sessions.MockCSRFStore{},
|
||||||
|
http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{"missing redirect",
|
||||||
|
http.MethodGet,
|
||||||
|
"",
|
||||||
|
redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"),
|
||||||
|
fmt.Sprint(time.Now().Unix()),
|
||||||
|
[]string{".pomerium.io"},
|
||||||
|
providers.MockProvider{},
|
||||||
|
sessions.MockCSRFStore{},
|
||||||
|
http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{"malformed redirect",
|
||||||
|
http.MethodGet,
|
||||||
|
"https://pomerium.com%zzzzz",
|
||||||
|
redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"),
|
||||||
|
fmt.Sprint(time.Now().Unix()),
|
||||||
|
[]string{".pomerium.io"},
|
||||||
|
providers.MockProvider{},
|
||||||
|
sessions.MockCSRFStore{},
|
||||||
|
http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := &Authenticate{
|
||||||
|
ProxyRootDomains: tt.allowedDomains,
|
||||||
|
RedirectURL: uriParse("http://www.pomerium.io"),
|
||||||
|
csrfStore: tt.csrfStore,
|
||||||
|
provider: tt.provider,
|
||||||
|
SharedKey: "secret",
|
||||||
|
cipher: mockCipher{},
|
||||||
|
}
|
||||||
|
u, _ := url.Parse("/oauth_start")
|
||||||
|
params, _ := url.ParseQuery(u.RawQuery)
|
||||||
|
params.Add("sig", tt.sig)
|
||||||
|
params.Add("ts", tt.ts)
|
||||||
|
params.Add("redirect_uri", tt.redirectURL)
|
||||||
|
|
||||||
|
u.RawQuery = params.Encode()
|
||||||
|
|
||||||
|
r := httptest.NewRequest(tt.method, u.String(), nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
a.OAuthStart(w, r)
|
||||||
|
if status := w.Code; status != tt.wantCode {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticate_getOAuthCallback(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
|
||||||
|
// url params
|
||||||
|
paramErr string
|
||||||
|
code string
|
||||||
|
state string
|
||||||
|
validDomains []string
|
||||||
|
validator func(string) bool
|
||||||
|
|
||||||
|
session sessions.SessionStore
|
||||||
|
provider providers.MockProvider
|
||||||
|
csrfStore sessions.MockCSRFStore
|
||||||
|
|
||||||
|
want string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"good",
|
||||||
|
http.MethodGet,
|
||||||
|
"",
|
||||||
|
"code",
|
||||||
|
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||||
|
[]string{"pomerium.io"},
|
||||||
|
trueValidator,
|
||||||
|
sessions.MockSessionStore{},
|
||||||
|
providers.MockProvider{
|
||||||
|
AuthenticateResponse: sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "blah@blah.com",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}},
|
||||||
|
sessions.MockCSRFStore{
|
||||||
|
ResponseCSRF: "csrf",
|
||||||
|
Cookie: &http.Cookie{Value: "nonce"}},
|
||||||
|
"https://corp.pomerium.io",
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{"get csrf error",
|
||||||
|
http.MethodGet,
|
||||||
|
"",
|
||||||
|
"code",
|
||||||
|
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||||
|
[]string{"pomerium.io"},
|
||||||
|
trueValidator,
|
||||||
|
sessions.MockSessionStore{},
|
||||||
|
providers.MockProvider{
|
||||||
|
AuthenticateResponse: sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "blah@blah.com",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}},
|
||||||
|
sessions.MockCSRFStore{
|
||||||
|
ResponseCSRF: "csrf",
|
||||||
|
GetError: errors.New("error"),
|
||||||
|
Cookie: &http.Cookie{Value: "not nonce"}},
|
||||||
|
"",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{"csrf nonce error",
|
||||||
|
http.MethodGet,
|
||||||
|
"",
|
||||||
|
"code",
|
||||||
|
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||||
|
[]string{"pomerium.io"},
|
||||||
|
trueValidator,
|
||||||
|
sessions.MockSessionStore{},
|
||||||
|
providers.MockProvider{
|
||||||
|
AuthenticateResponse: sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "blah@blah.com",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}},
|
||||||
|
sessions.MockCSRFStore{
|
||||||
|
ResponseCSRF: "csrf",
|
||||||
|
Cookie: &http.Cookie{Value: "not nonce"}},
|
||||||
|
"",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{"failed authenticate",
|
||||||
|
http.MethodGet,
|
||||||
|
"",
|
||||||
|
"code",
|
||||||
|
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||||
|
[]string{"pomerium.io"},
|
||||||
|
trueValidator,
|
||||||
|
sessions.MockSessionStore{},
|
||||||
|
providers.MockProvider{
|
||||||
|
AuthenticateError: errors.New("error"),
|
||||||
|
},
|
||||||
|
sessions.MockCSRFStore{
|
||||||
|
ResponseCSRF: "csrf",
|
||||||
|
Cookie: &http.Cookie{Value: "nonce"}},
|
||||||
|
"",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{"failed save session",
|
||||||
|
http.MethodGet,
|
||||||
|
"",
|
||||||
|
"code",
|
||||||
|
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||||
|
[]string{"pomerium.io"},
|
||||||
|
trueValidator,
|
||||||
|
sessions.MockSessionStore{SaveError: errors.New("error")},
|
||||||
|
providers.MockProvider{
|
||||||
|
AuthenticateResponse: sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "blah@blah.com",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}},
|
||||||
|
sessions.MockCSRFStore{
|
||||||
|
ResponseCSRF: "csrf",
|
||||||
|
Cookie: &http.Cookie{Value: "nonce"}},
|
||||||
|
"",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{"failed email validation",
|
||||||
|
http.MethodGet,
|
||||||
|
"",
|
||||||
|
"code",
|
||||||
|
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||||
|
[]string{"pomerium.io"},
|
||||||
|
falseValidator,
|
||||||
|
sessions.MockSessionStore{},
|
||||||
|
providers.MockProvider{
|
||||||
|
AuthenticateResponse: sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "blah@blah.com",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}},
|
||||||
|
sessions.MockCSRFStore{
|
||||||
|
ResponseCSRF: "csrf",
|
||||||
|
Cookie: &http.Cookie{Value: "nonce"}},
|
||||||
|
"",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
|
||||||
|
{"error returned",
|
||||||
|
http.MethodGet,
|
||||||
|
"idp error",
|
||||||
|
"code",
|
||||||
|
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||||
|
[]string{"pomerium.io"},
|
||||||
|
trueValidator,
|
||||||
|
sessions.MockSessionStore{},
|
||||||
|
providers.MockProvider{
|
||||||
|
AuthenticateResponse: sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "blah@blah.com",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}},
|
||||||
|
sessions.MockCSRFStore{
|
||||||
|
ResponseCSRF: "csrf",
|
||||||
|
Cookie: &http.Cookie{Value: "nonce"}},
|
||||||
|
"",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{"empty code",
|
||||||
|
http.MethodGet,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")),
|
||||||
|
[]string{"pomerium.io"},
|
||||||
|
trueValidator,
|
||||||
|
sessions.MockSessionStore{},
|
||||||
|
providers.MockProvider{
|
||||||
|
AuthenticateResponse: sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "blah@blah.com",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}},
|
||||||
|
sessions.MockCSRFStore{
|
||||||
|
ResponseCSRF: "csrf",
|
||||||
|
Cookie: &http.Cookie{Value: "nonce"}},
|
||||||
|
"",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{"invalid state string",
|
||||||
|
http.MethodGet,
|
||||||
|
"",
|
||||||
|
"code",
|
||||||
|
"nonce:https://corp.pomerium.io",
|
||||||
|
[]string{"pomerium.io"},
|
||||||
|
trueValidator,
|
||||||
|
sessions.MockSessionStore{},
|
||||||
|
providers.MockProvider{
|
||||||
|
AuthenticateResponse: sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "blah@blah.com",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}},
|
||||||
|
sessions.MockCSRFStore{
|
||||||
|
ResponseCSRF: "csrf",
|
||||||
|
Cookie: &http.Cookie{Value: "nonce"}},
|
||||||
|
"",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{"malformed state",
|
||||||
|
http.MethodGet,
|
||||||
|
"",
|
||||||
|
"code",
|
||||||
|
base64.URLEncoding.EncodeToString([]byte("nonce")),
|
||||||
|
[]string{"pomerium.io"},
|
||||||
|
trueValidator,
|
||||||
|
sessions.MockSessionStore{},
|
||||||
|
providers.MockProvider{
|
||||||
|
AuthenticateResponse: sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "blah@blah.com",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}},
|
||||||
|
sessions.MockCSRFStore{
|
||||||
|
ResponseCSRF: "csrf",
|
||||||
|
Cookie: &http.Cookie{Value: "nonce"}},
|
||||||
|
"",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{"invalid redirect uri",
|
||||||
|
http.MethodGet,
|
||||||
|
"",
|
||||||
|
"code",
|
||||||
|
base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")),
|
||||||
|
[]string{"pomerium.io"},
|
||||||
|
trueValidator,
|
||||||
|
sessions.MockSessionStore{},
|
||||||
|
providers.MockProvider{
|
||||||
|
AuthenticateResponse: sessions.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "blah@blah.com",
|
||||||
|
LifetimeDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
ValidDeadline: time.Now().Add(10 * time.Second),
|
||||||
|
}},
|
||||||
|
sessions.MockCSRFStore{
|
||||||
|
ResponseCSRF: "csrf",
|
||||||
|
Cookie: &http.Cookie{Value: "nonce"}},
|
||||||
|
"",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := &Authenticate{
|
||||||
|
sessionStore: tt.session,
|
||||||
|
csrfStore: tt.csrfStore,
|
||||||
|
provider: tt.provider,
|
||||||
|
ProxyRootDomains: tt.validDomains,
|
||||||
|
Validator: tt.validator,
|
||||||
|
}
|
||||||
|
u, _ := url.Parse("/oauthGet")
|
||||||
|
params, _ := url.ParseQuery(u.RawQuery)
|
||||||
|
params.Add("error", tt.paramErr)
|
||||||
|
params.Add("code", tt.code)
|
||||||
|
params.Add("state", tt.state)
|
||||||
|
|
||||||
|
u.RawQuery = params.Encode()
|
||||||
|
|
||||||
|
r := httptest.NewRequest(tt.method, u.String(), nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
got, err := a.getOAuthCallback(w, r)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Authenticate.getOAuthCallback() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("Authenticate.getOAuthCallback() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
41
authenticate/providers/mock_provider.go
Normal file
41
authenticate/providers/mock_provider.go
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
package providers // import "github.com/pomerium/pomerium/internal/providers"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions" // type Provider interface {
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockProvider provides a mocked implementation of the providers interface.
|
||||||
|
type MockProvider struct {
|
||||||
|
AuthenticateResponse sessions.SessionState
|
||||||
|
AuthenticateError error
|
||||||
|
ValidateResponse bool
|
||||||
|
ValidateError error
|
||||||
|
RefreshResponse *oauth2.Token
|
||||||
|
RefreshError error
|
||||||
|
RevokeError error
|
||||||
|
GetSignInURLResponse string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate is a mocked providers function.
|
||||||
|
func (mp MockProvider) Authenticate(code string) (*sessions.SessionState, error) {
|
||||||
|
return &mp.AuthenticateResponse, mp.AuthenticateError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate is a mocked providers function.
|
||||||
|
func (mp MockProvider) Validate(s string) (bool, error) {
|
||||||
|
return mp.ValidateResponse, mp.ValidateError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh is a mocked providers function.
|
||||||
|
func (mp MockProvider) Refresh(s string) (*oauth2.Token, error) {
|
||||||
|
return mp.RefreshResponse, mp.RefreshError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke is a mocked providers function.
|
||||||
|
func (mp MockProvider) Revoke(s string) error {
|
||||||
|
return mp.RevokeError
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSignInURL is a mocked providers function.
|
||||||
|
func (mp MockProvider) GetSignInURL(s string) string { return mp.GetSignInURLResponse }
|
|
@ -176,9 +176,6 @@ footer {
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div class="container">
|
<div class="container">
|
||||||
{{ if .Message }}
|
|
||||||
<div class="message">{{.Message}}</div>
|
|
||||||
{{ end}}
|
|
||||||
<div class="content">
|
<div class="content">
|
||||||
<header>
|
<header>
|
||||||
<h1>Sign out of <b>{{.Destination}}</b></h1>
|
<h1>Sign out of <b>{{.Destination}}</b></h1>
|
||||||
|
|
|
@ -56,3 +56,24 @@ func TestMockAuthenticate(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
serviceName string
|
||||||
|
opts *Options
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"grpc good", "grpc", &Options{Addr: "test", InternalAddr: "intranet.local", SharedSecret: "secret"}, false},
|
||||||
|
{"grpc missing shared secret", "grpc", &Options{Addr: "test", InternalAddr: "intranet.local", SharedSecret: ""}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := New(tt.serviceName, tt.opts)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -192,8 +192,8 @@ func TestNewGRPC(t *testing.T) {
|
||||||
{"no shared secret", &Options{}, true, "proxy/authenticator: grpc client requires shared secret"},
|
{"no shared secret", &Options{}, true, "proxy/authenticator: grpc client requires shared secret"},
|
||||||
{"empty connection", &Options{Addr: "", SharedSecret: "shh"}, true, "proxy/authenticator: connection address required"},
|
{"empty connection", &Options{Addr: "", SharedSecret: "shh"}, true, "proxy/authenticator: connection address required"},
|
||||||
{"empty connections", &Options{Addr: "", InternalAddr: "", SharedSecret: "shh"}, true, "proxy/authenticator: connection address required"},
|
{"empty connections", &Options{Addr: "", InternalAddr: "", SharedSecret: "shh"}, true, "proxy/authenticator: connection address required"},
|
||||||
{"internal addr", &Options{Addr: "", InternalAddr: "intranet.local", SharedSecret: "shh"}, false, "proxy/authenticator: connection address required"},
|
{"internal addr", &Options{Addr: "", InternalAddr: "intranet.local", SharedSecret: "shh"}, false, ""},
|
||||||
{"cert overide", &Options{Addr: "", InternalAddr: "intranet.local", OverideCertificateName: "*.local", SharedSecret: "shh"}, false, "proxy/authenticator: connection address required"},
|
{"cert overide", &Options{Addr: "", InternalAddr: "intranet.local", OverideCertificateName: "*.local", SharedSecret: "shh"}, false, ""},
|
||||||
|
|
||||||
// {"addr and internal ", &Options{Addr: "localhost", InternalAddr: "local.localhost", SharedSecret: "shh"}, nil, true, ""},
|
// {"addr and internal ", &Options{Addr: "localhost", InternalAddr: "local.localhost", SharedSecret: "shh"}, nil, true, ""},
|
||||||
}
|
}
|
||||||
|
|
|
@ -121,7 +121,7 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
||||||
// this value will be unique since we always use a randomized nonce as part of marshaling
|
// this value will be unique since we always use a randomized nonce as part of marshaling
|
||||||
encryptedCSRF, err := p.cipher.Marshal(state)
|
encryptedCSRF, err := p.cipher.Marshal(state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("failed to marshal csrf")
|
log.FromRequest(r).Error().Err(err).Msg("proxy: failed to marshal csrf")
|
||||||
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -131,7 +131,7 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
||||||
// this value will be unique since we always use a randomized nonce as part of marshaling
|
// this value will be unique since we always use a randomized nonce as part of marshaling
|
||||||
encryptedState, err := p.cipher.Marshal(state)
|
encryptedState, err := p.cipher.Marshal(state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("failed to encrypt cookie")
|
log.FromRequest(r).Error().Err(err).Msg("proxy: failed to encrypt cookie")
|
||||||
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -149,7 +149,7 @@ func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
||||||
func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
err := r.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("failed parsing request form")
|
log.FromRequest(r).Error().Err(err).Msg("proxy: failed parsing request form")
|
||||||
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
httputil.ErrorResponse(w, r, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -161,27 +161,23 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
// We begin the process of redeeming the code for an access token.
|
// We begin the process of redeeming the code for an access token.
|
||||||
rr, err := p.AuthenticateClient.Redeem(r.Form.Get("code"))
|
rr, err := p.AuthenticateClient.Redeem(r.Form.Get("code"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("error redeeming authorization code")
|
log.FromRequest(r).Error().Err(err).Msg("proxy: error redeeming authorization code")
|
||||||
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
encryptedState := r.Form.Get("state")
|
encryptedState := r.Form.Get("state")
|
||||||
log.Warn().
|
|
||||||
Str("encryptedState", encryptedState).
|
|
||||||
Msg("OK")
|
|
||||||
|
|
||||||
stateParameter := &StateParameter{}
|
stateParameter := &StateParameter{}
|
||||||
err = p.cipher.Unmarshal(encryptedState, stateParameter)
|
err = p.cipher.Unmarshal(encryptedState, stateParameter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("could not unmarshal state")
|
log.FromRequest(r).Error().Err(err).Msg("proxy: could not unmarshal state")
|
||||||
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := p.csrfStore.GetCSRF(r)
|
c, err := p.csrfStore.GetCSRF(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("failed parsing csrf cookie")
|
log.FromRequest(r).Error().Err(err).Msg("proxy: failed parsing csrf cookie")
|
||||||
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
|
httputil.ErrorResponse(w, r, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -191,7 +187,7 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
csrfParameter := &StateParameter{}
|
csrfParameter := &StateParameter{}
|
||||||
err = p.cipher.Unmarshal(encryptedCSRF, csrfParameter)
|
err = p.cipher.Unmarshal(encryptedCSRF, csrfParameter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Error().Err(err).Msg("couldn't unmarshal CSRF")
|
log.FromRequest(r).Error().Err(err).Msg("proxy: couldn't unmarshal CSRF")
|
||||||
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
httputil.ErrorResponse(w, r, "Internal error", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -283,7 +279,7 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
if session.LifetimePeriodExpired() {
|
if session.LifetimePeriodExpired() {
|
||||||
log.FromRequest(r).Info().Msg("proxy.Authenticate: lifetime expired, restarting")
|
log.FromRequest(r).Info().Msg("proxy: lifetime expired")
|
||||||
return sessions.ErrLifetimeExpired
|
return sessions.ErrLifetimeExpired
|
||||||
}
|
}
|
||||||
if session.RefreshPeriodExpired() {
|
if session.RefreshPeriodExpired() {
|
||||||
|
@ -295,12 +291,12 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
|
||||||
log.FromRequest(r).Warn().
|
log.FromRequest(r).Warn().
|
||||||
Str("RefreshToken", session.RefreshToken).
|
Str("RefreshToken", session.RefreshToken).
|
||||||
Str("AccessToken", session.AccessToken).
|
Str("AccessToken", session.AccessToken).
|
||||||
Msg("proxy.Authenticate: refresh failure")
|
Msg("proxy: refresh failed")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
session.AccessToken = accessToken
|
session.AccessToken = accessToken
|
||||||
session.RefreshDeadline = expiry
|
session.RefreshDeadline = expiry
|
||||||
log.FromRequest(r).Info().Msg("proxy.Authenticate: refresh success")
|
log.FromRequest(r).Info().Msg("proxy: refresh success")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.sessionStore.SaveSession(w, r, session)
|
err = p.sessionStore.SaveSession(w, r, session)
|
||||||
|
|
|
@ -359,12 +359,6 @@ func TestProxy_Proxy(t *testing.T) {
|
||||||
RefreshToken: "RefreshToken",
|
RefreshToken: "RefreshToken",
|
||||||
LifetimeDeadline: time.Now().Add(-10 * time.Second),
|
LifetimeDeadline: time.Now().Add(-10 * time.Second),
|
||||||
}
|
}
|
||||||
// expiredDeadline := &sessions.SessionState{
|
|
||||||
// AccessToken: "AccessToken",
|
|
||||||
// RefreshToken: "RefreshToken",
|
|
||||||
// LifetimeDeadline: time.Now().Add(10 * time.Second),
|
|
||||||
// RefreshDeadline: time.Now().Add(-10 * time.Second),
|
|
||||||
// }
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue