mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 15:47:36 +02:00
all: refactor handler logic
- all: prefer `FormValues` to `ParseForm` with subsequent `Form.Get`s - all: refactor authentication stack to be checked by middleware, and accessible via request context. - all: replace http.ServeMux with gorilla/mux’s router - all: replace custom CSRF checks with gorilla/csrf middleware - authenticate: extract callback path as constant. - internal/config: implement stringer interface for policy - internal/cryptutil: add helper func `NewBase64Key` - internal/cryptutil: rename `GenerateKey` to `NewKey` - internal/cryptutil: rename `GenerateRandomString` to `NewRandomStringN` - internal/middleware: removed alice in favor of gorilla/mux - internal/sessions: remove unused `ValidateRedirectURI` and `ValidateClientSecret` - internal/sessions: replace custom CSRF with gorilla/csrf fork that supports custom handler protection - internal/urlutil: add `SignedRedirectURL` to create hmac'd URLs - internal/urlutil: add `ValidateURL` helper to parse URL options - internal/urlutil: add `GetAbsoluteURL` which takes a request and returns its absolute URL. - proxy: remove holdover state verification checks; we no longer are setting sessions in any proxy routes so we don’t need them. - proxy: replace un-named http.ServeMux with named domain routes. Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
a793249386
commit
dc12947241
37 changed files with 1132 additions and 1384 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -76,4 +76,5 @@ yarn.lock
|
||||||
node_modules
|
node_modules
|
||||||
i18n/*
|
i18n/*
|
||||||
docs/.vuepress/dist/
|
docs/.vuepress/dist/
|
||||||
.firebase/
|
.firebase/
|
||||||
|
.changes.md
|
25
3RD-PARTY
25
3RD-PARTY
|
@ -87,31 +87,6 @@ https://github.com/bitly/oauth2_proxy/blob/master/LICENSE
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
THE SOFTWARE.
|
THE SOFTWARE.
|
||||||
|
|
||||||
alice
|
|
||||||
SPDX-License-Identifier: MIT
|
|
||||||
https://github.com/justinas/alice/blob/master/LICENSE
|
|
||||||
|
|
||||||
The MIT License (MIT)
|
|
||||||
|
|
||||||
Copyright (c) 2014 Justinas Stankevicius
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
||||||
this software and associated documentation files (the "Software"), to deal in
|
|
||||||
the Software without restriction, including without limitation the rights to
|
|
||||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
|
||||||
the Software, and to permit persons to whom the Software is furnished to do so,
|
|
||||||
subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
|
||||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
|
||||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
|
||||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
|
||||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
||||||
|
|
||||||
goji
|
goji
|
||||||
SPDX-License-Identifier: MIT
|
SPDX-License-Identifier: MIT
|
||||||
https://github.com/zenazn/goji/blob/master/LICENSE
|
https://github.com/zenazn/goji/blob/master/LICENSE
|
||||||
|
|
|
@ -15,6 +15,8 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const callbackPath = "/oauth2/callback"
|
||||||
|
|
||||||
// ValidateOptions checks that configuration are complete and valid.
|
// ValidateOptions checks that configuration are complete and valid.
|
||||||
// Returns on first error found.
|
// Returns on first error found.
|
||||||
func ValidateOptions(o config.Options) error {
|
func ValidateOptions(o config.Options) error {
|
||||||
|
@ -24,11 +26,8 @@ func ValidateOptions(o config.Options) error {
|
||||||
if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil {
|
if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil {
|
||||||
return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %v", err)
|
return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %v", err)
|
||||||
}
|
}
|
||||||
if o.AuthenticateURL == nil {
|
if err := urlutil.ValidateURL(o.AuthenticateURL); err != nil {
|
||||||
return errors.New("authenticate: 'AUTHENTICATE_SERVICE_URL' is required")
|
return fmt.Errorf("authenticate: invalid 'AUTHENTICATE_SERVICE_URL': %v", err)
|
||||||
}
|
|
||||||
if _, err := urlutil.ParseAndValidateURL(o.AuthenticateURL.String()); err != nil {
|
|
||||||
return fmt.Errorf("authenticate: couldn't parse 'AUTHENTICATE_SERVICE_URL': %v", err)
|
|
||||||
}
|
}
|
||||||
if o.ClientID == "" {
|
if o.ClientID == "" {
|
||||||
return errors.New("authenticate: 'IDP_CLIENT_ID' is required")
|
return errors.New("authenticate: 'IDP_CLIENT_ID' is required")
|
||||||
|
@ -44,8 +43,10 @@ type Authenticate struct {
|
||||||
SharedKey string
|
SharedKey string
|
||||||
RedirectURL *url.URL
|
RedirectURL *url.URL
|
||||||
|
|
||||||
|
cookieName string
|
||||||
|
cookieDomain string
|
||||||
|
cookieSecret []byte
|
||||||
templates *template.Template
|
templates *template.Template
|
||||||
csrfStore sessions.CSRFStore
|
|
||||||
sessionStore sessions.SessionStore
|
sessionStore sessions.SessionStore
|
||||||
cipher cryptutil.Cipher
|
cipher cryptutil.Cipher
|
||||||
provider identity.Authenticator
|
provider identity.Authenticator
|
||||||
|
@ -61,6 +62,9 @@ func New(opts config.Options) (*Authenticate, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if opts.CookieDomain == "" {
|
||||||
|
opts.CookieDomain = sessions.ParentSubdomain(opts.AuthenticateURL.String())
|
||||||
|
}
|
||||||
cookieStore, err := sessions.NewCookieStore(
|
cookieStore, err := sessions.NewCookieStore(
|
||||||
&sessions.CookieStoreOptions{
|
&sessions.CookieStoreOptions{
|
||||||
Name: opts.CookieName,
|
Name: opts.CookieName,
|
||||||
|
@ -74,7 +78,7 @@ func New(opts config.Options) (*Authenticate, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
|
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
|
||||||
redirectURL.Path = "/oauth2/callback"
|
redirectURL.Path = callbackPath
|
||||||
provider, err := identity.New(
|
provider, err := identity.New(
|
||||||
opts.Provider,
|
opts.Provider,
|
||||||
&identity.Provider{
|
&identity.Provider{
|
||||||
|
@ -94,9 +98,11 @@ func New(opts config.Options) (*Authenticate, error) {
|
||||||
SharedKey: opts.SharedKey,
|
SharedKey: opts.SharedKey,
|
||||||
RedirectURL: redirectURL,
|
RedirectURL: redirectURL,
|
||||||
templates: templates.New(),
|
templates: templates.New(),
|
||||||
csrfStore: cookieStore,
|
|
||||||
sessionStore: cookieStore,
|
sessionStore: cookieStore,
|
||||||
cipher: cipher,
|
cipher: cipher,
|
||||||
provider: provider,
|
provider: provider,
|
||||||
|
cookieSecret: decodedCookieSecret,
|
||||||
|
cookieName: opts.CookieName,
|
||||||
|
cookieDomain: opts.CookieDomain,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/csrf"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
@ -31,24 +32,68 @@ var CSPHeaders = map[string]string{
|
||||||
|
|
||||||
// Handler returns the authenticate service's HTTP multiplexer, and routes.
|
// Handler returns the authenticate service's HTTP multiplexer, and routes.
|
||||||
func (a *Authenticate) Handler() http.Handler {
|
func (a *Authenticate) Handler() http.Handler {
|
||||||
// validation middleware chain
|
r := httputil.NewRouter()
|
||||||
c := middleware.NewChain()
|
r.Use(middleware.SetHeaders(CSPHeaders))
|
||||||
c = c.Append(middleware.SetHeaders(CSPHeaders))
|
r.Use(csrf.Protect(
|
||||||
mux := http.NewServeMux()
|
a.cookieSecret,
|
||||||
mux.Handle("/robots.txt", c.ThenFunc(a.RobotsTxt))
|
csrf.Path("/"),
|
||||||
|
csrf.Domain(a.cookieDomain),
|
||||||
|
csrf.UnsafePaths([]string{"/oauth2/callback"}), // enforce CSRF on "safe" handler
|
||||||
|
csrf.FormValueName("state"), // rfc6749 section-10.12
|
||||||
|
csrf.CookieName(fmt.Sprintf("%s_csrf", a.cookieName)),
|
||||||
|
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
|
||||||
|
))
|
||||||
|
|
||||||
|
r.HandleFunc("/robots.txt", a.RobotsTxt).Methods(http.MethodGet)
|
||||||
// Identity Provider (IdP) endpoints
|
// Identity Provider (IdP) endpoints
|
||||||
mux.Handle("/oauth2", c.ThenFunc(a.OAuthStart))
|
r.HandleFunc("/oauth2/callback", a.OAuthCallback).Methods(http.MethodGet)
|
||||||
mux.Handle("/oauth2/callback", c.ThenFunc(a.OAuthCallback))
|
r.HandleFunc("/api/v1/token", a.ExchangeToken)
|
||||||
|
|
||||||
// Proxy service endpoints
|
// Proxy service endpoints
|
||||||
validationMiddlewares := c.Append(
|
v := r.PathPrefix("/.pomerium").Subrouter()
|
||||||
middleware.ValidateSignature(a.SharedKey),
|
v.Use(middleware.ValidateSignature(a.SharedKey))
|
||||||
middleware.ValidateRedirectURI(a.RedirectURL),
|
v.Use(middleware.ValidateRedirectURI(a.RedirectURL))
|
||||||
)
|
v.Use(sessions.RetrieveSession(a.sessionStore))
|
||||||
mux.Handle("/sign_in", validationMiddlewares.ThenFunc(a.SignIn))
|
v.Use(a.VerifySession)
|
||||||
mux.Handle("/sign_out", validationMiddlewares.ThenFunc(a.SignOut)) // POST
|
|
||||||
// Direct user access endpoints
|
v.HandleFunc("/sign_in", a.SignIn)
|
||||||
mux.Handle("/api/v1/token", c.ThenFunc(a.ExchangeToken))
|
v.HandleFunc("/sign_out", a.SignOut)
|
||||||
return mux
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifySession is the middleware used to enforce a valid authentication
|
||||||
|
// session state is attached to the users's request context.
|
||||||
|
func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, err := sessions.FromContext(r.Context())
|
||||||
|
if errors.Is(err, sessions.ErrExpired) {
|
||||||
|
if err := a.refresh(w, r, state); err != nil {
|
||||||
|
log.FromRequest(r).Debug().Str("cause", err.Error()).Msg("authenticate: couldn't refresh session")
|
||||||
|
a.sessionStore.ClearSession(w, r)
|
||||||
|
a.redirectToIdentityProvider(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if err != nil {
|
||||||
|
log.FromRequest(r).Err(err).Msg("authenticate: unexpected session state")
|
||||||
|
a.sessionStore.ClearSession(w, r)
|
||||||
|
a.redirectToIdentityProvider(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) error {
|
||||||
|
newSession, err := a.provider.Refresh(r.Context(), s)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("authenticate: refresh failed: %w", err)
|
||||||
|
}
|
||||||
|
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
|
||||||
|
return fmt.Errorf("authenticate: refresh save failed: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RobotsTxt handles the /robots.txt route.
|
// RobotsTxt handles the /robots.txt route.
|
||||||
|
@ -59,87 +104,22 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticate) loadExisting(w http.ResponseWriter, r *http.Request) (*sessions.State, error) {
|
|
||||||
session, err := a.sessionStore.LoadSession(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = session.Valid()
|
|
||||||
if err == nil {
|
|
||||||
return session, nil
|
|
||||||
} else if !errors.Is(err, sessions.ErrExpired) {
|
|
||||||
return nil, fmt.Errorf("authenticate: non-refreshable error: %w", err)
|
|
||||||
} else {
|
|
||||||
return a.refresh(w, r, session)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (*sessions.State, error) {
|
|
||||||
newSession, err := a.provider.Refresh(r.Context(), s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("authenticate: refresh failed: %w", err)
|
|
||||||
}
|
|
||||||
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
|
|
||||||
return nil, fmt.Errorf("authenticate: refresh save failed: %w", err)
|
|
||||||
}
|
|
||||||
return newSession, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignIn handles to authenticating a user.
|
// SignIn handles to authenticating a user.
|
||||||
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
||||||
session, err := a.loadExisting(w, r)
|
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
|
||||||
if err != nil {
|
|
||||||
log.FromRequest(r).Debug().Err(err).Msg("authenticate: need new session")
|
|
||||||
a.sessionStore.ClearSession(w, r)
|
|
||||||
a.OAuthStart(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := r.ParseForm(); err != nil {
|
|
||||||
httputil.ErrorResponse(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
state := r.Form.Get("state")
|
|
||||||
if state == "" {
|
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("sign in state empty", http.StatusBadRequest, nil))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
redirectURL, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri"))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// encrypt session state as json blob
|
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||||
encrypted, err := sessions.MarshalSession(session, a.cipher)
|
|
||||||
if err != nil {
|
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("couldn't marshal session", http.StatusInternalServerError, err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
http.Redirect(w, r, getAuthCodeRedirectURL(redirectURL, state, encrypted), http.StatusFound)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getAuthCodeRedirectURL(redirectURL *url.URL, state, authCode string) string {
|
|
||||||
// ParseQuery err handled by go's mux stack
|
|
||||||
params, _ := url.ParseQuery(redirectURL.RawQuery)
|
|
||||||
params.Set("code", authCode)
|
|
||||||
params.Set("state", state)
|
|
||||||
redirectURL.RawQuery = params.Encode()
|
|
||||||
return redirectURL.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignOut signs the user out and attempts to revoke the user's identity session
|
// SignOut signs the user out and attempts to revoke the user's identity session
|
||||||
// Handles both GET and POST.
|
// Handles both GET and POST.
|
||||||
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := r.ParseForm(); err != nil {
|
session, err := sessions.FromContext(r.Context())
|
||||||
httputil.ErrorResponse(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
redirectURI := r.Form.Get("redirect_uri")
|
|
||||||
session, err := a.sessionStore.LoadSession(r)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("authenticate: no session to signout, redirect and clear")
|
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
|
||||||
http.Redirect(w, r, redirectURI, http.StatusFound)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
a.sessionStore.ClearSession(w, r)
|
a.sessionStore.ClearSession(w, r)
|
||||||
|
@ -148,46 +128,30 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("could not revoke user session", http.StatusBadRequest, err))
|
httputil.ErrorResponse(w, r, httputil.Error("could not revoke user session", http.StatusBadRequest, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
http.Redirect(w, r, redirectURI, http.StatusFound)
|
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuthStart starts the authenticate process by redirecting to the identity provider.
|
// redirectToIdentityProvider starts the authenticate process by redirecting the
|
||||||
|
// user to their respective identity provider.
|
||||||
|
//
|
||||||
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest
|
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest
|
||||||
// https://tools.ietf.org/html/rfc6749#section-4.2.1
|
// https://tools.ietf.org/html/rfc6749#section-4.2.1
|
||||||
func (a *Authenticate) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||||
authRedirectURL := a.RedirectURL.ResolveReference(r.URL)
|
redirectURL := a.RedirectURL.ResolveReference(r.URL)
|
||||||
|
nonce := csrf.Token(r)
|
||||||
// Nonce is the opaque, cryptographically binding value used to maintain
|
state := fmt.Sprintf("%v:%v", nonce, redirectURL.String())
|
||||||
// state between the request and the callback.
|
encodedState := base64.URLEncoding.EncodeToString([]byte(state))
|
||||||
// OIDC : 3.1.2.1. Authentication Request
|
http.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound)
|
||||||
nonce := fmt.Sprintf("%x", cryptutil.GenerateKey())
|
|
||||||
a.csrfStore.SetCSRF(w, r, nonce)
|
|
||||||
// Redirection URI to which the response will be sent. This URI MUST exactly
|
|
||||||
// match one of the Redirection URI values for the Client pre-registered at
|
|
||||||
// at your identity provider
|
|
||||||
proxyRedirectURL, err := urlutil.ParseAndValidateURL(authRedirectURL.Query().Get("redirect_uri"))
|
|
||||||
if err != nil || !middleware.SameDomain(proxyRedirectURL, a.RedirectURL) {
|
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("proxy url not from the root domain", http.StatusBadRequest, err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the signature and timestamp values then compare hmac
|
|
||||||
proxyRedirectSig := authRedirectURL.Query().Get("sig")
|
|
||||||
ts := authRedirectURL.Query().Get("ts")
|
|
||||||
if !middleware.ValidSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, a.SharedKey) {
|
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, nil))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// State is the opaque value used to maintain state between the request and
|
|
||||||
// the callback; contains both the nonce and redirect URI
|
|
||||||
state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String())))
|
|
||||||
|
|
||||||
// build the provider sign in url
|
|
||||||
signInURL := a.provider.GetSignInURL(state)
|
|
||||||
http.Redirect(w, r, signInURL, http.StatusFound)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuthCallback handles the callback from the identity provider.
|
// OAuthCallback handles the callback from the identity provider.
|
||||||
|
//
|
||||||
|
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowSteps
|
||||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse
|
// https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse
|
||||||
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
redirect, err := a.getOAuthCallback(w, r)
|
redirect, err := a.getOAuthCallback(w, r)
|
||||||
|
@ -195,57 +159,49 @@ func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
httputil.ErrorResponse(w, r, fmt.Errorf("oauth callback : %w", err))
|
httputil.ErrorResponse(w, r, fmt.Errorf("oauth callback : %w", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// redirect back to the proxy-service via sign_in
|
|
||||||
http.Redirect(w, r, redirect.String(), http.StatusFound)
|
http.Redirect(w, r, redirect.String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
|
func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) (*url.URL, error) {
|
||||||
if err := r.ParseForm(); err != nil {
|
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
|
||||||
return nil, httputil.Error("invalid signature", http.StatusBadRequest, err)
|
//
|
||||||
|
// first, check if the identity provider returned an error
|
||||||
|
if idpError := r.FormValue("error"); idpError != "" {
|
||||||
|
return nil, httputil.Error(idpError, http.StatusBadRequest, fmt.Errorf("identity provider: %v", idpError))
|
||||||
}
|
}
|
||||||
// OIDC : 3.1.2.6. Authentication Error Response
|
// fail if no session redemption code is returned
|
||||||
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthError
|
code := r.FormValue("code")
|
||||||
if idpError := r.Form.Get("error"); idpError != "" {
|
|
||||||
return nil, httputil.Error("provider returned an error", http.StatusBadRequest, fmt.Errorf("provider error: %v", idpError))
|
|
||||||
}
|
|
||||||
code := r.Form.Get("code")
|
|
||||||
if code == "" {
|
if code == "" {
|
||||||
return nil, httputil.Error("provider didn't reply with code", http.StatusBadRequest, nil)
|
return nil, httputil.Error("identity provider returned empty code", http.StatusBadRequest, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate the returned code with the identity provider
|
// Successful Authentication Response: rfc6749#section-4.1.2 & OIDC#3.1.2.5
|
||||||
|
//
|
||||||
|
// Exchange the supplied Authorization Code for a valid user session.
|
||||||
session, err := a.provider.Authenticate(r.Context(), code)
|
session, err := a.provider.Authenticate(r.Context(), code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
|
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
|
||||||
}
|
}
|
||||||
|
// state includes a csrf nonce (validated by middleware) and redirect uri
|
||||||
// OIDC : 3.1.2.5. Successful Authentication Response
|
bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state"))
|
||||||
// Opaque value used to maintain state between the request and the callback.
|
|
||||||
bytes, err := base64.URLEncoding.DecodeString(r.Form.Get("state"))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed decoding state: %w", err)
|
return nil, httputil.Error("malformed state", http.StatusBadRequest, err)
|
||||||
}
|
}
|
||||||
s := strings.SplitN(string(bytes), ":", 2)
|
// split state into its it's components (nonce:redirect_uri)
|
||||||
if len(s) != 2 {
|
statePayload := strings.SplitN(string(bytes), ":", 2)
|
||||||
return nil, fmt.Errorf("invalid state size: %d", len(s))
|
if len(statePayload) != 2 {
|
||||||
|
return nil, fmt.Errorf("state malformed, size: %d", len(statePayload))
|
||||||
}
|
}
|
||||||
// state contains the csrf nonce and redirect uri
|
// parse redirect_uri; ignore csrf nonce (validity asserted by middleware)
|
||||||
nonce := s[0]
|
redirectURL, err := urlutil.ParseAndValidateURL(statePayload[1])
|
||||||
redirect := s[1]
|
|
||||||
c, err := a.csrfStore.GetCSRF(r)
|
|
||||||
defer a.csrfStore.ClearCSRF(w, r)
|
|
||||||
if err != nil || c.Value != nonce {
|
|
||||||
return nil, fmt.Errorf("csrf failure: %w", err)
|
|
||||||
}
|
|
||||||
redirectURL, err := urlutil.ParseAndValidateURL(redirect)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, httputil.Error(fmt.Sprintf("invalid redirect uri %s", redirect), http.StatusBadRequest, err)
|
return nil, httputil.Error("invalid redirect uri", http.StatusBadRequest, err)
|
||||||
}
|
|
||||||
// sanity check, we are redirecting back to the same subdomain right?
|
|
||||||
if !middleware.SameDomain(redirectURL, a.RedirectURL) {
|
|
||||||
return nil, httputil.Error(fmt.Sprintf("invalid redirect domain %v, %v", redirectURL, a.RedirectURL), http.StatusBadRequest, nil)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// todo(bdd): if we want to be _extra_ sure, we can validate that the
|
||||||
|
// redirectURL hmac is valid. But the nonce should cover the integrity...
|
||||||
|
|
||||||
|
// OK. Looks good so let's persist our user session
|
||||||
if err := a.sessionStore.SaveSession(w, r, session); err != nil {
|
if err := a.sessionStore.SaveSession(w, r, session); err != nil {
|
||||||
return nil, fmt.Errorf("failed saving new session: %w", err)
|
return nil, fmt.Errorf("failed saving new session: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -256,11 +212,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
||||||
// and exchanges that token for a pomerium session. The provided token's
|
// and exchanges that token for a pomerium session. The provided token's
|
||||||
// audience ('aud') attribute must match Pomerium's client_id.
|
// audience ('aud') attribute must match Pomerium's client_id.
|
||||||
func (a *Authenticate) ExchangeToken(w http.ResponseWriter, r *http.Request) {
|
func (a *Authenticate) ExchangeToken(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := r.ParseForm(); err != nil {
|
code := r.FormValue("id_token")
|
||||||
httputil.ErrorResponse(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
code := r.Form.Get("id_token")
|
|
||||||
if code == "" {
|
if code == "" {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("missing id token", http.StatusBadRequest, nil))
|
httputil.ErrorResponse(w, r, httputil.Error("missing id token", http.StatusBadRequest, nil))
|
||||||
return
|
return
|
||||||
|
|
|
@ -21,6 +21,7 @@ func testAuthenticate() *Authenticate {
|
||||||
var auth Authenticate
|
var auth Authenticate
|
||||||
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
|
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
|
||||||
auth.SharedKey = "IzY7MOZwzfOkmELXgozHDKTxoT3nOYhwkcmUVINsRww="
|
auth.SharedKey = "IzY7MOZwzfOkmELXgozHDKTxoT3nOYhwkcmUVINsRww="
|
||||||
|
auth.cookieSecret = []byte(auth.SharedKey)
|
||||||
auth.templates = templates.New()
|
auth.templates = templates.New()
|
||||||
return &auth
|
return &auth
|
||||||
}
|
}
|
||||||
|
@ -51,6 +52,7 @@ func TestAuthenticate_Handler(t *testing.T) {
|
||||||
t.Error("handler cannot be nil")
|
t.Error("handler cannot be nil")
|
||||||
}
|
}
|
||||||
req := httptest.NewRequest("GET", "/robots.txt", nil)
|
req := httptest.NewRequest("GET", "/robots.txt", nil)
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
h.ServeHTTP(rr, req)
|
h.ServeHTTP(rr, req)
|
||||||
|
@ -63,6 +65,7 @@ func TestAuthenticate_Handler(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthenticate_SignIn(t *testing.T) {
|
func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
state string
|
state string
|
||||||
|
@ -76,36 +79,35 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
{"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
|
{"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
|
||||||
{"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockCipher{}, http.StatusFound},
|
{"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockCipher{}, http.StatusFound},
|
||||||
{"session expired good refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound},
|
{"session expired good refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound},
|
||||||
{"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusBadRequest}, // mocking hmac is meh
|
{"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusFound}, // mocking hmac is meh
|
||||||
{"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
{"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound},
|
||||||
|
|
||||||
// {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
|
// {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
|
||||||
{"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
{"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
|
||||||
{"malformed form", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
|
{"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
|
||||||
{"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
|
||||||
{"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
{"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
||||||
// actually caught by go's handler, but we should keep the test.
|
// actually caught by go's handler, but we should keep the test.
|
||||||
{"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
|
{"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
||||||
{"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusInternalServerError},
|
{"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusFound},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
sessionStore: tt.session,
|
sessionStore: tt.session,
|
||||||
provider: tt.provider,
|
provider: tt.provider,
|
||||||
RedirectURL: uriParse("https://some.example"),
|
RedirectURL: uriParseHelper("https://some.example"),
|
||||||
csrfStore: &sessions.MockCSRFStore{},
|
|
||||||
SharedKey: "secret",
|
SharedKey: "secret",
|
||||||
cipher: tt.cipher,
|
cipher: tt.cipher,
|
||||||
}
|
}
|
||||||
uri := &url.URL{Host: "corp.some.example", Scheme: "https", Path: "/"}
|
uri := &url.URL{Host: "corp.some.example", Scheme: "https", Path: "/"}
|
||||||
if tt.name == "malformed form" {
|
uri.RawQuery = fmt.Sprintf("%s&redirect_uri=%s", tt.state, tt.redirectURI)
|
||||||
uri.RawQuery = "example=%zzzzz"
|
|
||||||
} else {
|
|
||||||
uri.RawQuery = fmt.Sprintf("%s&redirect_uri=%s", tt.state, tt.redirectURI)
|
|
||||||
}
|
|
||||||
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
|
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
|
||||||
r.Header.Set("Accept", "application/json")
|
r.Header.Set("Accept", "application/json")
|
||||||
|
state, _ := tt.session.LoadSession(r)
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = sessions.NewContext(ctx, state, nil)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
a.SignIn(w, r)
|
a.SignIn(w, r)
|
||||||
|
@ -117,61 +119,18 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockCipher struct{}
|
func uriParseHelper(s string) *url.URL {
|
||||||
|
|
||||||
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
|
|
||||||
if string(s) == "error" {
|
|
||||||
return []byte(""), errors.New("error encrypting")
|
|
||||||
}
|
|
||||||
return []byte("OK"), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
|
|
||||||
if string(s) == "error" {
|
|
||||||
return []byte(""), errors.New("error encrypting")
|
|
||||||
}
|
|
||||||
return []byte("OK"), nil
|
|
||||||
}
|
|
||||||
func (a mockCipher) Marshal(s interface{}) (string, error) { return "ok", nil }
|
|
||||||
func (a mockCipher) Unmarshal(s string, i interface{}) error {
|
|
||||||
if s == "unmarshal error" || s == "error" {
|
|
||||||
return errors.New("error")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_getAuthCodeRedirectURL(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
redirectURL *url.URL
|
|
||||||
state string
|
|
||||||
authCode string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"https", uriParse("https://www.pomerium.io"), "state", "auth-code", "https://www.pomerium.io?code=auth-code&state=state"},
|
|
||||||
{"http", uriParse("http://www.pomerium.io"), "state", "auth-code", "http://www.pomerium.io?code=auth-code&state=state"},
|
|
||||||
{"no subdomain", uriParse("http://pomerium.io"), "state", "auth-code", "http://pomerium.io?code=auth-code&state=state"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := getAuthCodeRedirectURL(tt.redirectURL, tt.state, tt.authCode); got != tt.want {
|
|
||||||
t.Errorf("getAuthCodeRedirectURL() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func uriParse(s string) *url.URL {
|
|
||||||
uri, _ := url.Parse(s)
|
uri, _ := url.Parse(s)
|
||||||
return uri
|
return uri
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthenticate_SignOut(t *testing.T) {
|
func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
method string
|
method string
|
||||||
|
|
||||||
|
ctxError error
|
||||||
redirectURL string
|
redirectURL string
|
||||||
sig string
|
sig string
|
||||||
ts string
|
ts string
|
||||||
|
@ -181,17 +140,16 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
wantCode int
|
wantCode int
|
||||||
wantBody string
|
wantBody string
|
||||||
}{
|
}{
|
||||||
{"good post", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""},
|
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""},
|
||||||
{"failed revoke", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"},
|
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"},
|
||||||
{"malformed form", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusInternalServerError, ""},
|
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, ""},
|
||||||
{"load session error", http.MethodPost, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""},
|
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, ""},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
sessionStore: tt.sessionStore,
|
sessionStore: tt.sessionStore,
|
||||||
provider: tt.provider,
|
provider: tt.provider,
|
||||||
cipher: mockCipher{},
|
|
||||||
templates: templates.New(),
|
templates: templates.New(),
|
||||||
}
|
}
|
||||||
u, _ := url.Parse("/sign_out")
|
u, _ := url.Parse("/sign_out")
|
||||||
|
@ -200,10 +158,11 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
params.Add("ts", tt.ts)
|
params.Add("ts", tt.ts)
|
||||||
params.Add("redirect_uri", tt.redirectURL)
|
params.Add("redirect_uri", tt.redirectURL)
|
||||||
u.RawQuery = params.Encode()
|
u.RawQuery = params.Encode()
|
||||||
if tt.name == "malformed form" {
|
|
||||||
u.RawQuery = "example=%zzzzz"
|
|
||||||
}
|
|
||||||
r := httptest.NewRequest(tt.method, u.String(), nil)
|
r := httptest.NewRequest(tt.method, u.String(), nil)
|
||||||
|
state, _ := tt.sessionStore.LoadSession(r)
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
a.SignOut(w, r)
|
a.SignOut(w, r)
|
||||||
|
@ -217,64 +176,8 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) string {
|
|
||||||
data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix()))
|
|
||||||
h := cryptutil.Hash(secret, data)
|
|
||||||
return base64.URLEncoding.EncodeToString(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthenticate_OAuthStart(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
method string
|
|
||||||
redirectURLSetting string
|
|
||||||
|
|
||||||
redirectURL string
|
|
||||||
sig string
|
|
||||||
ts string
|
|
||||||
|
|
||||||
provider identity.Authenticator
|
|
||||||
csrfStore sessions.MockCSRFStore
|
|
||||||
// sessionStore sessions.SessionStore
|
|
||||||
wantCode int
|
|
||||||
}{
|
|
||||||
{"good", http.MethodGet, "https://corp.pomerium.io/", "https://corp.pomerium.io/", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusFound},
|
|
||||||
{"bad timestamp", http.MethodGet, "https://corp.pomerium.io/", "https://corp.pomerium.io/", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Add(10 * time.Hour).Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest},
|
|
||||||
{"missing redirect", http.MethodGet, "https://corp.pomerium.io/", "", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest},
|
|
||||||
{"malformed redirect", http.MethodGet, "https://corp.pomerium.io/", "https://pomerium.com%zzzzz", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest},
|
|
||||||
{"different domains", http.MethodGet, "https://corp.notpomerium.io/", "https://corp.pomerium.io/", redirectURLSignature("https://corp.pomerium.io/", time.Now(), "secret"), fmt.Sprint(time.Now().Unix()), identity.MockProvider{}, sessions.MockCSRFStore{}, http.StatusBadRequest},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
a := &Authenticate{
|
|
||||||
RedirectURL: uriParse(tt.redirectURLSetting),
|
|
||||||
csrfStore: tt.csrfStore,
|
|
||||||
provider: tt.provider,
|
|
||||||
SharedKey: "secret",
|
|
||||||
cipher: mockCipher{},
|
|
||||||
}
|
|
||||||
u, _ := url.Parse("/oauth_start")
|
|
||||||
params, _ := url.ParseQuery(u.RawQuery)
|
|
||||||
params.Add("sig", tt.sig)
|
|
||||||
params.Add("ts", tt.ts)
|
|
||||||
params.Add("redirect_uri", tt.redirectURL)
|
|
||||||
|
|
||||||
u.RawQuery = params.Encode()
|
|
||||||
|
|
||||||
r := httptest.NewRequest(tt.method, u.String(), nil)
|
|
||||||
r.Header.Set("Accept", "application/json")
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
a.OAuthStart(w, r)
|
|
||||||
if status := w.Code; status != tt.wantCode {
|
|
||||||
t.Errorf("handler returned wrong status code: got %v want %v\n%v", status, tt.wantCode, w.Body.String())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthenticate_OAuthCallback(t *testing.T) {
|
func TestAuthenticate_OAuthCallback(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
method string
|
method string
|
||||||
|
@ -286,24 +189,20 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
||||||
authenticateURL string
|
authenticateURL string
|
||||||
session sessions.SessionStore
|
session sessions.SessionStore
|
||||||
provider identity.MockProvider
|
provider identity.MockProvider
|
||||||
csrfStore sessions.MockCSRFStore
|
|
||||||
|
|
||||||
want string
|
want string
|
||||||
wantCode int
|
wantCode int
|
||||||
}{
|
}{
|
||||||
{"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusFound},
|
{"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusFound},
|
||||||
{"get csrf error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", GetError: errors.New("error"), Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError},
|
{"failed authenticate", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
|
||||||
{"csrf nonce error", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "not nonce"}}, "", http.StatusInternalServerError},
|
{"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusInternalServerError},
|
||||||
{"failed authenticate", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
|
{"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
|
||||||
{"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
|
{"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
|
||||||
{"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
|
{"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
|
||||||
{"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
|
{"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"invalid state string", http.MethodGet, "", "code", "nonce:https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
|
{"bad base64 state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")) + "%", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"malformed state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusInternalServerError},
|
{"too many state delimeters", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io:wait")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
|
{"too few state delimeters", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusInternalServerError},
|
||||||
{"malformed form", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "", http.StatusBadRequest},
|
|
||||||
{"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
|
||||||
{"different domains", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://some.example.notpomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, sessions.MockCSRFStore{ResponseCSRF: "csrf", Cookie: &http.Cookie{Value: "nonce"}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -311,7 +210,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
RedirectURL: authURL,
|
RedirectURL: authURL,
|
||||||
sessionStore: tt.session,
|
sessionStore: tt.session,
|
||||||
csrfStore: tt.csrfStore,
|
|
||||||
provider: tt.provider,
|
provider: tt.provider,
|
||||||
}
|
}
|
||||||
u, _ := url.Parse("/oauthGet")
|
u, _ := url.Parse("/oauthGet")
|
||||||
|
@ -322,9 +220,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
||||||
|
|
||||||
u.RawQuery = params.Encode()
|
u.RawQuery = params.Encode()
|
||||||
|
|
||||||
if tt.name == "malformed form" {
|
|
||||||
u.RawQuery = "example=%zzzzz"
|
|
||||||
}
|
|
||||||
r := httptest.NewRequest(tt.method, u.String(), nil)
|
r := httptest.NewRequest(tt.method, u.String(), nil)
|
||||||
r.Header.Set("Accept", "application/json")
|
r.Header.Set("Accept", "application/json")
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
@ -339,6 +234,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthenticate_ExchangeToken(t *testing.T) {
|
func TestAuthenticate_ExchangeToken(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
method string
|
method string
|
||||||
|
@ -384,3 +280,55 @@ func TestAuthenticate_ExchangeToken(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||||
|
fmt.Fprintln(w, "RVSI FILIVS CAISAR")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
session sessions.SessionStore
|
||||||
|
ctxError error
|
||||||
|
provider identity.Authenticator
|
||||||
|
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK},
|
||||||
|
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
||||||
|
{"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusOK},
|
||||||
|
{"expired,refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
|
||||||
|
{"expired,save error", &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusFound},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
a := Authenticate{
|
||||||
|
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||||
|
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
||||||
|
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
|
||||||
|
sessionStore: tt.session,
|
||||||
|
provider: tt.provider,
|
||||||
|
}
|
||||||
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
state, _ := tt.session.LoadSession(r)
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
got := a.VerifySession(fn)
|
||||||
|
got.ServeHTTP(w, r)
|
||||||
|
if status := w.Code; status != tt.wantStatus {
|
||||||
|
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||||
|
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ import (
|
||||||
func ValidateOptions(o config.Options) error {
|
func ValidateOptions(o config.Options) error {
|
||||||
decoded, err := base64.StdEncoding.DecodeString(o.SharedKey)
|
decoded, err := base64.StdEncoding.DecodeString(o.SharedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("authorize: `SHARED_SECRET` setting is invalid base64: %v", err)
|
return fmt.Errorf("authorize: `SHARED_SECRET` malformed base64: %v", err)
|
||||||
}
|
}
|
||||||
if len(decoded) != 32 {
|
if len(decoded) != 32 {
|
||||||
return fmt.Errorf("authorize: `SHARED_SECRET` want 32 but got %d bytes", len(decoded))
|
return fmt.Errorf("authorize: `SHARED_SECRET` want 32 but got %d bytes", len(decoded))
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
@ -44,9 +45,9 @@ func main() {
|
||||||
setupTracing(opt)
|
setupTracing(opt)
|
||||||
setupHTTPRedirectServer(opt)
|
setupHTTPRedirectServer(opt)
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
r := newGlobalRouter(opt)
|
||||||
grpcServer := setupGRPCServer(opt)
|
grpcServer := setupGRPCServer(opt)
|
||||||
_, err = newAuthenticateService(*opt, mux)
|
_, err = newAuthenticateService(*opt, r.Host(urlutil.StripPort(opt.AuthenticateURL.Host)).Subrouter())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("cmd/pomerium: authenticate")
|
log.Fatal().Err(err).Msg("cmd/pomerium: authenticate")
|
||||||
}
|
}
|
||||||
|
@ -56,7 +57,7 @@ func main() {
|
||||||
log.Fatal().Err(err).Msg("cmd/pomerium: authorize")
|
log.Fatal().Err(err).Msg("cmd/pomerium: authorize")
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy, err := newProxyService(*opt, mux)
|
proxy, err := newProxyService(*opt, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("cmd/pomerium: proxy")
|
log.Fatal().Err(err).Msg("cmd/pomerium: proxy")
|
||||||
}
|
}
|
||||||
|
@ -70,8 +71,7 @@ func main() {
|
||||||
log.Info().Str("file", e.Name).Msg("cmd/pomerium: config file changed")
|
log.Info().Str("file", e.Name).Msg("cmd/pomerium: config file changed")
|
||||||
opt = config.HandleConfigUpdate(*configFile, opt, []config.OptionsUpdater{authz, proxy})
|
opt = config.HandleConfigUpdate(*configFile, opt, []config.OptionsUpdater{authz, proxy})
|
||||||
})
|
})
|
||||||
|
srv, err := httputil.NewTLSServer(configToServerOptions(opt), r, grpcServer)
|
||||||
srv, err := httputil.NewTLSServer(configToServerOptions(opt), mainHandler(opt, mux), grpcServer)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("cmd/pomerium: couldn't start pomerium")
|
log.Fatal().Err(err).Msg("cmd/pomerium: couldn't start pomerium")
|
||||||
}
|
}
|
||||||
|
@ -80,7 +80,7 @@ func main() {
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAuthenticateService(opt config.Options, mux *http.ServeMux) (*authenticate.Authenticate, error) {
|
func newAuthenticateService(opt config.Options, r *mux.Router) (*authenticate.Authenticate, error) {
|
||||||
if !config.IsAuthenticate(opt.Services) {
|
if !config.IsAuthenticate(opt.Services) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -88,7 +88,7 @@ func newAuthenticateService(opt config.Options, mux *http.ServeMux) (*authentica
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mux.Handle(urlutil.StripPort(opt.AuthenticateURL.Host)+"/", service.Handler())
|
r.PathPrefix("/").Handler(service.Handler())
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ func newAuthorizeService(opt config.Options, rpc *grpc.Server) (*authorize.Autho
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newProxyService(opt config.Options, mux *http.ServeMux) (*proxy.Proxy, error) {
|
func newProxyService(opt config.Options, r *mux.Router) (*proxy.Proxy, error) {
|
||||||
if !config.IsProxy(opt.Services) {
|
if !config.IsProxy(opt.Services) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -112,15 +112,15 @@ func newProxyService(opt config.Options, mux *http.ServeMux) (*proxy.Proxy, erro
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mux.Handle("/", service.Handler())
|
r.PathPrefix("/").Handler(service.Handler())
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func mainHandler(o *config.Options, mux http.Handler) http.Handler {
|
func newGlobalRouter(o *config.Options) *mux.Router {
|
||||||
c := middleware.NewChain()
|
mux := httputil.NewRouter()
|
||||||
c = c.Append(metrics.HTTPMetricsHandler(o.Services))
|
mux.Use(metrics.HTTPMetricsHandler(o.Services))
|
||||||
c = c.Append(log.NewHandler(log.Logger))
|
mux.Use(log.NewHandler(log.Logger))
|
||||||
c = c.Append(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
mux.Use(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
||||||
log.FromRequest(r).Debug().
|
log.FromRequest(r).Debug().
|
||||||
Dur("duration", duration).
|
Dur("duration", duration).
|
||||||
Int("size", size).
|
Int("size", size).
|
||||||
|
@ -133,15 +133,15 @@ func mainHandler(o *config.Options, mux http.Handler) http.Handler {
|
||||||
Msg("http-request")
|
Msg("http-request")
|
||||||
}))
|
}))
|
||||||
if len(o.Headers) != 0 {
|
if len(o.Headers) != 0 {
|
||||||
c = c.Append(middleware.SetHeaders(o.Headers))
|
mux.Use(middleware.SetHeaders(o.Headers))
|
||||||
}
|
}
|
||||||
c = c.Append(log.ForwardedAddrHandler("fwd_ip"))
|
mux.Use(log.ForwardedAddrHandler("fwd_ip"))
|
||||||
c = c.Append(log.RemoteAddrHandler("ip"))
|
mux.Use(log.RemoteAddrHandler("ip"))
|
||||||
c = c.Append(log.UserAgentHandler("user_agent"))
|
mux.Use(log.UserAgentHandler("user_agent"))
|
||||||
c = c.Append(log.RefererHandler("referer"))
|
mux.Use(log.RefererHandler("referer"))
|
||||||
c = c.Append(log.RequestIDHandler("req_id", "Request-Id"))
|
mux.Use(log.RequestIDHandler("req_id", "Request-Id"))
|
||||||
c = c.Append(middleware.Healthcheck("/ping", version.UserAgent()))
|
mux.Use(middleware.Healthcheck("/ping", version.UserAgent()))
|
||||||
return c.Then(mux)
|
return mux
|
||||||
}
|
}
|
||||||
|
|
||||||
func configToServerOptions(opt *config.Options) *httputil.ServerOptions {
|
func configToServerOptions(opt *config.Options) *httputil.ServerOptions {
|
||||||
|
|
|
@ -21,7 +21,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_newAuthenticateService(t *testing.T) {
|
func Test_newAuthenticateService(t *testing.T) {
|
||||||
mux := http.NewServeMux()
|
mux := httputil.NewRouter()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -127,7 +127,7 @@ func Test_newProxyeService(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
mux := http.NewServeMux()
|
mux := httputil.NewRouter()
|
||||||
testOpts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
|
testOpts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -161,7 +161,7 @@ func Test_newProxyeService(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_mainHandler(t *testing.T) {
|
func Test_newGlobalRouter(t *testing.T) {
|
||||||
o := config.Options{
|
o := config.Options{
|
||||||
Services: "all",
|
Services: "all",
|
||||||
Headers: map[string]string{
|
Headers: map[string]string{
|
||||||
|
@ -172,7 +172,6 @@ func Test_mainHandler(t *testing.T) {
|
||||||
"Content-Security-Policy": "default-src 'none'; style-src 'self' 'sha256-pSTVzZsFAqd2U3QYu+BoBDtuJWaPM/+qMy/dBRrhb5Y='; img-src 'self';",
|
"Content-Security-Policy": "default-src 'none'; style-src 'self' 'sha256-pSTVzZsFAqd2U3QYu+BoBDtuJWaPM/+qMy/dBRrhb5Y='; img-src 'self';",
|
||||||
"Referrer-Policy": "Same-origin",
|
"Referrer-Policy": "Same-origin",
|
||||||
}}
|
}}
|
||||||
mux := http.NewServeMux()
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/404", nil)
|
req := httptest.NewRequest(http.MethodGet, "/404", nil)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -181,8 +180,9 @@ func Test_mainHandler(t *testing.T) {
|
||||||
io.WriteString(w, `OK`)
|
io.WriteString(w, `OK`)
|
||||||
})
|
})
|
||||||
|
|
||||||
mux.Handle("/404", h)
|
out := newGlobalRouter(&o)
|
||||||
out := mainHandler(&o, mux)
|
out.Handle("/404", h)
|
||||||
|
|
||||||
out.ServeHTTP(rr, req)
|
out.ServeHTTP(rr, req)
|
||||||
expected := fmt.Sprintf("OK")
|
expected := fmt.Sprintf("OK")
|
||||||
body := rr.Body.String()
|
body := rr.Body.String()
|
||||||
|
|
|
@ -5,6 +5,23 @@
|
||||||
### New
|
### New
|
||||||
|
|
||||||
- Add ability to override HTTPS backend's TLS Server Name. [GH-297](https://github.com/pomerium/pomerium/pull/297)
|
- Add ability to override HTTPS backend's TLS Server Name. [GH-297](https://github.com/pomerium/pomerium/pull/297)
|
||||||
|
- Add ability to set pomerium's encrypted session in a auth bearer token, or query param.
|
||||||
|
|
||||||
|
### Security
|
||||||
|
|
||||||
|
- Under certain circumstances, where debug logging was enabled, pomerium's shared secret could be leaked to http access logs as a query param.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Fixed an issue where CSRF would fail if multiple tabs were open. [GH-306](https://github.com/pomerium/pomerium/issues/306)
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- Authenticate service no longer uses gRPC.
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
|
||||||
|
- Removed `AUTHENTICATE_INTERNAL_URL`/`authenticate_internal_url` which is no longer used.
|
||||||
|
|
||||||
## v0.3.0
|
## v0.3.0
|
||||||
|
|
||||||
|
|
|
@ -228,8 +228,8 @@ Each unit work is called a Span in a trace. Spans include metadata about the wor
|
||||||
|
|
||||||
| Config Key | Description | Required |
|
| Config Key | Description | Required |
|
||||||
| :--------------- | :---------------------------------------------------------------- | -------- |
|
| :--------------- | :---------------------------------------------------------------- | -------- |
|
||||||
| tracing_provider | The name of the tracing provider. (e.g. jaeger) | ✅ |
|
| tracing_provider | The name of the tracing provider. (e.g. jaeger) | ✅ |
|
||||||
| tracing_debug | Will disable [sampling](https://opencensus.io/tracing/sampling/). | ❌ |
|
| tracing_debug | Will disable [sampling](https://opencensus.io/tracing/sampling/). | ❌ |
|
||||||
|
|
||||||
### Jaeger
|
### Jaeger
|
||||||
|
|
||||||
|
@ -243,8 +243,8 @@ Each unit work is called a Span in a trace. Spans include metadata about the wor
|
||||||
|
|
||||||
| Config Key | Description | Required |
|
| Config Key | Description | Required |
|
||||||
| :-------------------------------- | :------------------------------------------ | -------- |
|
| :-------------------------------- | :------------------------------------------ | -------- |
|
||||||
| tracing_jaeger_collector_endpoint | Url to the Jaeger HTTP Thrift collector. | ✅ |
|
| tracing_jaeger_collector_endpoint | Url to the Jaeger HTTP Thrift collector. | ✅ |
|
||||||
| tracing_jaeger_agent_endpoint | Send spans to jaeger-agent at this address. | ✅ |
|
| tracing_jaeger_agent_endpoint | Send spans to jaeger-agent at this address. | ✅ |
|
||||||
|
|
||||||
#### Example
|
#### Example
|
||||||
|
|
||||||
|
@ -478,11 +478,11 @@ Authenticate Service URL is the externally accessible URL for the authenticate s
|
||||||
- Config File Key: `authorize_service_url`
|
- Config File Key: `authorize_service_url`
|
||||||
- Type: `URL`
|
- Type: `URL`
|
||||||
- Required
|
- Required
|
||||||
- Example: `https://access.corp.example.com` or `https://pomerium-authorize-service.default.svc.cluster.local`
|
- Example: `https://authorize.corp.example.com` or `https://pomerium-authorize-service.default.svc.cluster.local`
|
||||||
|
|
||||||
Authorize Service URL is the location of the internally accessible authorize service. NOTE: Unlike authenticate, authorize has no publicly accessible http handlers so this setting is purely for gRPC communication.
|
Authorize Service URL is the location of the internally accessible authorize service. NOTE: Unlike authenticate, authorize has no publicly accessible http handlers so this setting is purely for gRPC communication.
|
||||||
|
|
||||||
If your load balancer does not support gRPC pass-through you'll need to set this value to an internally routable location (`https://pomerium-authorize-service.default.svc.cluster.local`) instead of an externally routable one (`https://access.corp.example.com`).
|
If your load balancer does not support gRPC pass-through you'll need to set this value to an internally routable location (`https://pomerium-authorize-service.default.svc.cluster.local`) instead of an externally routable one (`https://authorize.corp.example.com`).
|
||||||
|
|
||||||
## Override Certificate Name
|
## Override Certificate Name
|
||||||
|
|
||||||
|
|
4
go.mod
4
go.mod
|
@ -9,10 +9,12 @@ require (
|
||||||
github.com/fsnotify/fsnotify v1.4.7
|
github.com/fsnotify/fsnotify v1.4.7
|
||||||
github.com/golang/mock v1.3.1
|
github.com/golang/mock v1.3.1
|
||||||
github.com/golang/protobuf v1.3.1
|
github.com/golang/protobuf v1.3.1
|
||||||
github.com/google/go-cmp v0.3.0
|
github.com/google/go-cmp v0.3.1
|
||||||
|
github.com/gorilla/mux v1.6.2
|
||||||
github.com/magiconair/properties v1.8.1 // indirect
|
github.com/magiconair/properties v1.8.1 // indirect
|
||||||
github.com/mitchellh/hashstructure v1.0.0
|
github.com/mitchellh/hashstructure v1.0.0
|
||||||
github.com/pelletier/go-toml v1.4.0 // indirect
|
github.com/pelletier/go-toml v1.4.0 // indirect
|
||||||
|
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30
|
||||||
github.com/pomerium/go-oidc v2.0.0+incompatible
|
github.com/pomerium/go-oidc v2.0.0+incompatible
|
||||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||||
github.com/prometheus/client_golang v0.9.3
|
github.com/prometheus/client_golang v0.9.3
|
||||||
|
|
8
go.sum
8
go.sum
|
@ -65,11 +65,16 @@ github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ
|
||||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||||
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
|
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
|
||||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||||
|
github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg=
|
||||||
|
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||||
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
||||||
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||||
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
||||||
|
github.com/gorilla/mux v1.6.2 h1:Pgr17XVTNXAk3q/r4CpKzC5xBM/qW1uVLV+IhRZpIIk=
|
||||||
github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
|
github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
|
||||||
|
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
||||||
|
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||||
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
|
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
|
||||||
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0 h1:Iju5GlWwrvL6UBg4zJJt3btmonfrMlCDdsejg4CZE7c=
|
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0 h1:Iju5GlWwrvL6UBg4zJJt3btmonfrMlCDdsejg4CZE7c=
|
||||||
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
|
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
|
||||||
|
@ -115,9 +120,12 @@ github.com/pelletier/go-toml v1.4.0 h1:u3Z1r+oOXJIkxqw34zVhyPgjBsm6X2wn21NWs/HfS
|
||||||
github.com/pelletier/go-toml v1.4.0/go.mod h1:PN7xzY2wHTK0K9p34ErDQMlFxa51Fk0OUruD3k1mMwo=
|
github.com/pelletier/go-toml v1.4.0/go.mod h1:PN7xzY2wHTK0K9p34ErDQMlFxa51Fk0OUruD3k1mMwo=
|
||||||
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
|
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
|
||||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30 h1:jggCv6hZvcxjGa3gqkYY2EUuOkITI9Znugz/f3QJfRQ=
|
||||||
|
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0=
|
||||||
github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o=
|
github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o=
|
||||||
github.com/pomerium/go-oidc v2.0.0+incompatible/go.mod h1:DRsGVw6MOgxbfq4Y57jKOE8lbEfayxeiY0A8/4vxjBM=
|
github.com/pomerium/go-oidc v2.0.0+incompatible/go.mod h1:DRsGVw6MOgxbfq4Y57jKOE8lbEfayxeiY0A8/4vxjBM=
|
||||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
|
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
|
||||||
|
|
|
@ -217,7 +217,7 @@ func (o *Options) Validate() error {
|
||||||
// shared key must be set for all modes other than "all"
|
// shared key must be set for all modes other than "all"
|
||||||
if o.SharedKey == "" {
|
if o.SharedKey == "" {
|
||||||
if o.Services == "all" {
|
if o.Services == "all" {
|
||||||
o.SharedKey = cryptutil.GenerateRandomString(32)
|
o.SharedKey = cryptutil.NewBase64Key()
|
||||||
} else {
|
} else {
|
||||||
return errors.New("shared-key cannot be empty")
|
return errors.New("shared-key cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
|
@ -116,3 +116,9 @@ func (p *Policy) Validate() error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (p *Policy) String() string {
|
||||||
|
if p.Source == nil || p.Destination == nil {
|
||||||
|
return fmt.Sprintf("%s → %s", p.From, p.To)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s → %s", p.Source.String(), p.Destination.String())
|
||||||
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package config // import "github.com/pomerium/pomerium/internal/config"
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -44,3 +44,28 @@ func Test_Validate(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPolicy_String(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
From string
|
||||||
|
To string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"good", "https://pomerium.io", "https://localhost", "https://pomerium.io → https://localhost"},
|
||||||
|
{"failed to validate", "https://pomerium.io", "localhost", "https://pomerium.io → localhost"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := &Policy{
|
||||||
|
From: tt.From,
|
||||||
|
To: tt.To,
|
||||||
|
}
|
||||||
|
p.Validate()
|
||||||
|
if got := p.String(); got != tt.want {
|
||||||
|
t.Errorf("Policy.String() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -13,20 +13,27 @@ import (
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DefaultKeySize is the default key size in bytes.
|
||||||
const DefaultKeySize = 32
|
const DefaultKeySize = 32
|
||||||
|
|
||||||
// GenerateKey generates a random 32-byte key.
|
// NewKey generates a random 32-byte key.
|
||||||
//
|
//
|
||||||
// Panics if source of randomness fails.
|
// Panics if source of randomness fails.
|
||||||
func GenerateKey() []byte {
|
func NewKey() []byte {
|
||||||
return randomBytes(DefaultKeySize)
|
return randomBytes(DefaultKeySize)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateRandomString returns base64 encoded securely generated random string
|
// NewBase64Key generates a random base64 encoded 32-byte key.
|
||||||
// of a given set of bytes.
|
|
||||||
//
|
//
|
||||||
// Panics if source of randomness fails.
|
// Panics if source of randomness fails.
|
||||||
func GenerateRandomString(c int) string {
|
func NewBase64Key() string {
|
||||||
|
return NewRandomStringN(DefaultKeySize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRandomStringN returns base64 encoded random string of a given num of bytes.
|
||||||
|
//
|
||||||
|
// Panics if source of randomness fails.
|
||||||
|
func NewRandomStringN(c int) string {
|
||||||
return base64.StdEncoding.EncodeToString(randomBytes(c))
|
return base64.StdEncoding.EncodeToString(randomBytes(c))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
package cryptutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
@ -13,7 +13,7 @@ import (
|
||||||
func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
||||||
plaintext := []byte("my plain text value")
|
plaintext := []byte("my plain text value")
|
||||||
|
|
||||||
key := GenerateKey()
|
key := NewKey()
|
||||||
c, err := NewCipher(key)
|
c, err := NewCipher(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected err: %v", err)
|
t.Fatalf("unexpected err: %v", err)
|
||||||
|
@ -47,7 +47,7 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMarshalAndUnmarshalStruct(t *testing.T) {
|
func TestMarshalAndUnmarshalStruct(t *testing.T) {
|
||||||
key := GenerateKey()
|
key := NewKey()
|
||||||
|
|
||||||
c, err := NewCipher(key)
|
c, err := NewCipher(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -102,7 +102,7 @@ func TestMarshalAndUnmarshalStruct(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCipherDataRace(t *testing.T) {
|
func TestCipherDataRace(t *testing.T) {
|
||||||
cipher, err := NewCipher(GenerateKey())
|
cipher, err := NewCipher(NewKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected generating cipher err: %v", err)
|
t.Fatalf("unexpected generating cipher err: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -183,21 +183,21 @@ func TestGenerateRandomString(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
o := GenerateRandomString(tt.c)
|
o := NewRandomStringN(tt.c)
|
||||||
b, err := base64.StdEncoding.DecodeString(o)
|
b, err := base64.StdEncoding.DecodeString(o)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
got := len(b)
|
got := len(b)
|
||||||
if got != tt.want {
|
if got != tt.want {
|
||||||
t.Errorf("GenerateRandomString() = %d, want %d", got, tt.want)
|
t.Errorf("NewRandomStringN() = %d, want %d", got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestXChaCha20Cipher_Marshal(t *testing.T) {
|
func TestXChaCha20Cipher_Marshal(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
s interface{}
|
s interface{}
|
||||||
|
@ -225,7 +225,7 @@ func TestXChaCha20Cipher_Marshal(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
||||||
c, err := NewCipher(GenerateKey())
|
c, err := NewCipher(NewKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected err: %v", err)
|
t.Fatalf("unexpected err: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -239,15 +239,15 @@ func TestXChaCha20Cipher_Marshal(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewCipher(t *testing.T) {
|
func TestNewCipher(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
secret []byte
|
secret []byte
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"simple 32 byte key", GenerateKey(), false},
|
{"simple 32 byte key", NewKey(), false},
|
||||||
{"key too short", []byte("what is entropy"), true},
|
{"key too short", []byte("what is entropy"), true},
|
||||||
{"key too long", []byte(GenerateRandomString(33)), true},
|
{"key too long", []byte(NewRandomStringN(33)), true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -261,16 +261,16 @@ func TestNewCipher(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewCipherFromBase64(t *testing.T) {
|
func TestNewCipherFromBase64(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
s string
|
s string
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"simple 32 byte key", base64.StdEncoding.EncodeToString(GenerateKey()), false},
|
{"simple 32 byte key", base64.StdEncoding.EncodeToString(NewKey()), false},
|
||||||
{"key too short", base64.StdEncoding.EncodeToString([]byte("what is entropy")), true},
|
{"key too short", base64.StdEncoding.EncodeToString([]byte("what is entropy")), true},
|
||||||
{"key too long", GenerateRandomString(33), true},
|
{"key too long", NewRandomStringN(33), true},
|
||||||
{"bad base 64", string(GenerateKey()), true},
|
{"bad base 64", string(NewKey()), true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -282,3 +282,26 @@ func TestNewCipherFromBase64(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewBase64Key(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"simple", 32},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
o := NewBase64Key()
|
||||||
|
b, err := base64.StdEncoding.DecodeString(o)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
got := len(b)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("NewBase64Key() = %d, want %d", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
19
internal/httputil/router.go
Normal file
19
internal/httputil/router.go
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/pomerium/csrf"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewRouter returns a new router instance.
|
||||||
|
func NewRouter() *mux.Router {
|
||||||
|
return mux.NewRouter()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRFFailureHandler sets a HTTP 403 Forbidden status and writes the
|
||||||
|
// CSRF failure reason to the response.
|
||||||
|
func CSRFFailureHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ErrorResponse(w, r, Error("CSRF Failure", http.StatusForbidden, csrf.FailureReason(r)))
|
||||||
|
}
|
37
internal/httputil/router_test.go
Normal file
37
internal/httputil/router_test.go
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCSRFFailureHandler(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
|
||||||
|
wantBody string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{"basic csrf failure", "{\"error\":\"CSRF Failure\"}\n", http.StatusForbidden},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
CSRFFailureHandler(w, r)
|
||||||
|
gotBody := w.Body.String()
|
||||||
|
gotStatus := w.Result().StatusCode
|
||||||
|
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||||
|
t.Errorf("RetrieveSession() = %s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||||
|
t.Errorf("RetrieveSession() = %s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,109 +0,0 @@
|
||||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
|
||||||
|
|
||||||
import "net/http"
|
|
||||||
|
|
||||||
// Constructor is a type alias for func(http.Handler) http.Handler
|
|
||||||
type Constructor func(http.Handler) http.Handler
|
|
||||||
|
|
||||||
// Chain acts as a list of http.Handler constructors.
|
|
||||||
// Chain is effectively immutable:
|
|
||||||
// once created, it will always hold
|
|
||||||
// the same set of constructors in the same order.
|
|
||||||
type Chain struct {
|
|
||||||
constructors []Constructor
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewChain creates a new chain,
|
|
||||||
// memorizing the given list of middleware constructors.
|
|
||||||
// New serves no other function,
|
|
||||||
// constructors are only called upon a call to Then().
|
|
||||||
func NewChain(constructors ...Constructor) Chain {
|
|
||||||
return Chain{append([]Constructor(nil), constructors...)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then chains the middleware and returns the final http.Handler.
|
|
||||||
// NewChain(m1, m2, m3).Then(h)
|
|
||||||
// is equivalent to:
|
|
||||||
// m1(m2(m3(h)))
|
|
||||||
// When the request comes in, it will be passed to m1, then m2, then m3
|
|
||||||
// and finally, the given handler
|
|
||||||
// (assuming every middleware calls the following one).
|
|
||||||
//
|
|
||||||
// A chain can be safely reused by calling Then() several times.
|
|
||||||
// stdStack := middleware.NewChain(ratelimitHandler, csrfHandler)
|
|
||||||
// indexPipe = stdStack.Then(indexHandler)
|
|
||||||
// authPipe = stdStack.Then(authHandler)
|
|
||||||
// Note that constructors are called on every call to Then()
|
|
||||||
// and thus several instances of the same middleware will be created
|
|
||||||
// when a chain is reused in this way.
|
|
||||||
// For proper middleware, this should cause no problems.
|
|
||||||
//
|
|
||||||
// Then() treats nil as http.DefaultServeMux.
|
|
||||||
func (c Chain) Then(h http.Handler) http.Handler {
|
|
||||||
if h == nil {
|
|
||||||
h = http.DefaultServeMux
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range c.constructors {
|
|
||||||
h = c.constructors[len(c.constructors)-1-i](h)
|
|
||||||
}
|
|
||||||
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
// ThenFunc works identically to Then, but takes
|
|
||||||
// a HandlerFunc instead of a Handler.
|
|
||||||
//
|
|
||||||
// The following two statements are equivalent:
|
|
||||||
// c.Then(http.HandlerFunc(fn))
|
|
||||||
// c.ThenFunc(fn)
|
|
||||||
//
|
|
||||||
// ThenFunc provides all the guarantees of Then.
|
|
||||||
func (c Chain) ThenFunc(fn http.HandlerFunc) http.Handler {
|
|
||||||
if fn == nil {
|
|
||||||
return c.Then(nil)
|
|
||||||
}
|
|
||||||
return c.Then(fn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Append extends a chain, adding the specified constructors
|
|
||||||
// as the last ones in the request flow.
|
|
||||||
//
|
|
||||||
// Append returns a new chain, leaving the original one untouched.
|
|
||||||
//
|
|
||||||
// stdChain := middleware.NewChain(m1, m2)
|
|
||||||
// extChain := stdChain.Append(m3, m4)
|
|
||||||
// // requests in stdChain go m1 -> m2
|
|
||||||
// // requests in extChain go m1 -> m2 -> m3 -> m4
|
|
||||||
func (c Chain) Append(constructors ...Constructor) Chain {
|
|
||||||
newCons := make([]Constructor, 0, len(c.constructors)+len(constructors))
|
|
||||||
newCons = append(newCons, c.constructors...)
|
|
||||||
newCons = append(newCons, constructors...)
|
|
||||||
|
|
||||||
return Chain{newCons}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extend extends a chain by adding the specified chain
|
|
||||||
// as the last one in the request flow.
|
|
||||||
//
|
|
||||||
// Extend returns a new chain, leaving the original one untouched.
|
|
||||||
//
|
|
||||||
// stdChain := middleware.NewChain(m1, m2)
|
|
||||||
// ext1Chain := middleware.NewChain(m3, m4)
|
|
||||||
// ext2Chain := stdChain.Extend(ext1Chain)
|
|
||||||
// // requests in stdChain go m1 -> m2
|
|
||||||
// // requests in ext1Chain go m3 -> m4
|
|
||||||
// // requests in ext2Chain go m1 -> m2 -> m3 -> m4
|
|
||||||
//
|
|
||||||
// Another example:
|
|
||||||
// aHtmlAfterNosurf := middleware.NewChain(m2)
|
|
||||||
// aHtml := middleware.NewChain(m1, func(h http.Handler) http.Handler {
|
|
||||||
// csrf := nosurf.NewChain(h)
|
|
||||||
// csrf.SetFailureHandler(aHtmlAfterNosurf.ThenFunc(csrfFail))
|
|
||||||
// return csrf
|
|
||||||
// }).Extend(aHtmlAfterNosurf)
|
|
||||||
// // requests to aHtml hitting nosurfs success handler go m1 -> nosurf -> m2 -> target-handler
|
|
||||||
// // requests to aHtml hitting nosurfs failure handler go m1 -> nosurf -> m2 -> csrfFail
|
|
||||||
func (c Chain) Extend(chain Chain) Chain {
|
|
||||||
return c.Append(chain.constructors...)
|
|
||||||
}
|
|
|
@ -1,177 +0,0 @@
|
||||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A constructor for middleware
|
|
||||||
// that writes its own "tag" into the RW and does nothing else.
|
|
||||||
// Useful in checking if a chain is behaving in the right order.
|
|
||||||
func tagMiddleware(tag string) Constructor {
|
|
||||||
return func(h http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Write([]byte(tag))
|
|
||||||
h.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer),
|
|
||||||
// but the best we can do.
|
|
||||||
func funcsEqual(f1, f2 interface{}) bool {
|
|
||||||
val1 := reflect.ValueOf(f1)
|
|
||||||
val2 := reflect.ValueOf(f2)
|
|
||||||
return val1.Pointer() == val2.Pointer()
|
|
||||||
}
|
|
||||||
|
|
||||||
var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Write([]byte("app\n"))
|
|
||||||
})
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
|
||||||
c1 := func(h http.Handler) http.Handler {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
c2 := func(h http.Handler) http.Handler {
|
|
||||||
return http.StripPrefix("potato", nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
slice := []Constructor{c1, c2}
|
|
||||||
|
|
||||||
chain := NewChain(slice...)
|
|
||||||
for k := range slice {
|
|
||||||
if !funcsEqual(chain.constructors[k], slice[k]) {
|
|
||||||
t.Error("New does not add constructors correctly")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestThenWorksWithNoMiddleware(t *testing.T) {
|
|
||||||
if !funcsEqual(NewChain().Then(testApp), testApp) {
|
|
||||||
t.Error("Then does not work with no middleware")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestThenTreatsNilAsDefaultServeMux(t *testing.T) {
|
|
||||||
if NewChain().Then(nil) != http.DefaultServeMux {
|
|
||||||
t.Error("Then does not treat nil as DefaultServeMux")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) {
|
|
||||||
if NewChain().ThenFunc(nil) != http.DefaultServeMux {
|
|
||||||
t.Error("ThenFunc does not treat nil as DefaultServeMux")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestThenFuncConstructsHandlerFunc(t *testing.T) {
|
|
||||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
chained := NewChain().ThenFunc(fn)
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
|
|
||||||
chained.ServeHTTP(rec, (*http.Request)(nil))
|
|
||||||
|
|
||||||
if reflect.TypeOf(chained) != reflect.TypeOf(http.HandlerFunc(nil)) {
|
|
||||||
t.Error("ThenFunc does not construct HandlerFunc")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestThenOrdersHandlersCorrectly(t *testing.T) {
|
|
||||||
t1 := tagMiddleware("t1\n")
|
|
||||||
t2 := tagMiddleware("t2\n")
|
|
||||||
t3 := tagMiddleware("t3\n")
|
|
||||||
|
|
||||||
chained := NewChain(t1, t2, t3).Then(testApp)
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
r, err := http.NewRequest("GET", "/", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
chained.ServeHTTP(w, r)
|
|
||||||
|
|
||||||
if w.Body.String() != "t1\nt2\nt3\napp\n" {
|
|
||||||
t.Error("Then does not order handlers correctly")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAppendAddsHandlersCorrectly(t *testing.T) {
|
|
||||||
chain := NewChain(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
|
|
||||||
newChain := chain.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
|
|
||||||
|
|
||||||
if len(chain.constructors) != 2 {
|
|
||||||
t.Error("chain should have 2 constructors")
|
|
||||||
}
|
|
||||||
if len(newChain.constructors) != 4 {
|
|
||||||
t.Error("newChain should have 4 constructors")
|
|
||||||
}
|
|
||||||
|
|
||||||
chained := newChain.Then(testApp)
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
r, err := http.NewRequest("GET", "/", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
chained.ServeHTTP(w, r)
|
|
||||||
|
|
||||||
if w.Body.String() != "t1\nt2\nt3\nt4\napp\n" {
|
|
||||||
t.Error("Append does not add handlers correctly")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAppendRespectsImmutability(t *testing.T) {
|
|
||||||
chain := NewChain(tagMiddleware(""))
|
|
||||||
newChain := chain.Append(tagMiddleware(""))
|
|
||||||
|
|
||||||
if &chain.constructors[0] == &newChain.constructors[0] {
|
|
||||||
t.Error("Apppend does not respect immutability")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtendAddsHandlersCorrectly(t *testing.T) {
|
|
||||||
chain1 := NewChain(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
|
|
||||||
chain2 := NewChain(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
|
|
||||||
newChain := chain1.Extend(chain2)
|
|
||||||
|
|
||||||
if len(chain1.constructors) != 2 {
|
|
||||||
t.Error("chain1 should contain 2 constructors")
|
|
||||||
}
|
|
||||||
if len(chain2.constructors) != 2 {
|
|
||||||
t.Error("chain2 should contain 2 constructors")
|
|
||||||
}
|
|
||||||
if len(newChain.constructors) != 4 {
|
|
||||||
t.Error("newChain should contain 4 constructors")
|
|
||||||
}
|
|
||||||
|
|
||||||
chained := newChain.Then(testApp)
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
r, err := http.NewRequest("GET", "/", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
chained.ServeHTTP(w, r)
|
|
||||||
|
|
||||||
if w.Body.String() != "t1\nt2\nt3\nt4\napp\n" {
|
|
||||||
t.Error("Extend does not add handlers in correctly")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtendRespectsImmutability(t *testing.T) {
|
|
||||||
chain := NewChain(tagMiddleware(""))
|
|
||||||
newChain := chain.Extend(NewChain(tagMiddleware("")))
|
|
||||||
|
|
||||||
if &chain.constructors[0] == &newChain.constructors[0] {
|
|
||||||
t.Error("Extend does not respect immutability")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,2 +1,2 @@
|
||||||
// Package middleware provides a standard set of middleware implementations for pomerium.
|
// Package middleware provides a standard set of middleware for pomerium.
|
||||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
|
|
@ -79,12 +79,10 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
|
||||||
func (cs *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
func (cs *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||||
domain := req.Host
|
domain := req.Host
|
||||||
|
|
||||||
if name == cs.csrfName() {
|
if cs.CookieDomain != "" {
|
||||||
domain = req.Host
|
|
||||||
} else if cs.CookieDomain != "" {
|
|
||||||
domain = cs.CookieDomain
|
domain = cs.CookieDomain
|
||||||
} else {
|
} else {
|
||||||
domain = splitDomain(domain)
|
domain = ParentSubdomain(domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
if h, _, err := net.SplitHostPort(domain); err == nil {
|
if h, _, err := net.SplitHostPort(domain); err == nil {
|
||||||
|
@ -105,19 +103,11 @@ func (cs *CookieStore) makeCookie(req *http.Request, name string, value string,
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *CookieStore) csrfName() string {
|
|
||||||
return fmt.Sprintf("%s_csrf", cs.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
|
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
|
||||||
func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||||
return cs.makeCookie(req, cs.Name, value, expiration, now)
|
return cs.makeCookie(req, cs.Name, value, expiration, now)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
|
||||||
return cs.makeCookie(req, cs.csrfName(), value, expiration, now)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
||||||
if len(cookie.String()) <= MaxChunkSize {
|
if len(cookie.String()) <= MaxChunkSize {
|
||||||
http.SetCookie(w, cookie)
|
http.SetCookie(w, cookie)
|
||||||
|
@ -134,7 +124,6 @@ func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
||||||
nc.Name = fmt.Sprintf("%s_%d", cookie.Name, i)
|
nc.Name = fmt.Sprintf("%s_%d", cookie.Name, i)
|
||||||
nc.Value = c
|
nc.Value = c
|
||||||
}
|
}
|
||||||
fmt.Println(i)
|
|
||||||
http.SetCookie(w, &nc)
|
http.SetCookie(w, &nc)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -150,25 +139,6 @@ func chunk(s string, size int) []string {
|
||||||
return ss
|
return ss
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearCSRF clears the CSRF cookie from the request
|
|
||||||
func (cs *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) {
|
|
||||||
http.SetCookie(w, cs.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
|
|
||||||
func (cs *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) {
|
|
||||||
http.SetCookie(w, cs.makeCSRFCookie(req, val, cs.CookieExpire, time.Now()))
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
|
|
||||||
func (cs *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
|
|
||||||
c, err := req.Cookie(cs.csrfName())
|
|
||||||
if err != nil {
|
|
||||||
return nil, ErrEmptyCSRF // ErrNoCookie is confusing in this context
|
|
||||||
}
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSession clears the session cookie from a request
|
// ClearSession clears the session cookie from a request
|
||||||
func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
|
func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
|
||||||
http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
|
http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
|
||||||
|
@ -235,7 +205,8 @@ func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func splitDomain(s string) string {
|
// ParentSubdomain returns the parent subdomain.
|
||||||
|
func ParentSubdomain(s string) string {
|
||||||
if strings.Count(s, ".") < 2 {
|
if strings.Count(s, ".") < 2 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package sessions
|
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
@ -38,7 +38,7 @@ func (a mockCipher) Unmarshal(s string, i interface{}) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func TestNewCookieStore(t *testing.T) {
|
func TestNewCookieStore(t *testing.T) {
|
||||||
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
|
cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -111,7 +111,7 @@ func TestNewCookieStore(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCookieStore_makeCookie(t *testing.T) {
|
func TestCookieStore_makeCookie(t *testing.T) {
|
||||||
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
|
cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -155,62 +155,13 @@ func TestCookieStore_makeCookie(t *testing.T) {
|
||||||
if diff := cmp.Diff(s.makeSessionCookie(r, tt.value, tt.expiration, now), tt.want); diff != "" {
|
if diff := cmp.Diff(s.makeSessionCookie(r, tt.value, tt.expiration, now), tt.want); diff != "" {
|
||||||
t.Errorf("CookieStore.makeSessionCookie() = \n%s", diff)
|
t.Errorf("CookieStore.makeSessionCookie() = \n%s", diff)
|
||||||
}
|
}
|
||||||
got := s.makeCSRFCookie(r, tt.value, tt.expiration, now)
|
|
||||||
tt.wantCSRF.Name = "_pomerium_csrf"
|
|
||||||
if !reflect.DeepEqual(got, tt.wantCSRF) {
|
|
||||||
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.wantCSRF)
|
|
||||||
}
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
want := "new-csrf"
|
|
||||||
s.SetCSRF(w, r, want)
|
|
||||||
found := false
|
|
||||||
for _, cookie := range w.Result().Cookies() {
|
|
||||||
if cookie.Name == s.Name+"_csrf" && cookie.Value == want {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
t.Error("SetCSRF failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
w = httptest.NewRecorder()
|
|
||||||
s.ClearCSRF(w, r)
|
|
||||||
for _, cookie := range w.Result().Cookies() {
|
|
||||||
if cookie.Name == s.Name+"_csrf" && cookie.Value == want {
|
|
||||||
t.Error("clear csrf failed")
|
|
||||||
break
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w = httptest.NewRecorder()
|
|
||||||
want = "new-session"
|
|
||||||
s.setSessionCookie(w, r, want)
|
|
||||||
found = false
|
|
||||||
for _, cookie := range w.Result().Cookies() {
|
|
||||||
if cookie.Name == s.Name && cookie.Value == want {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
t.Error("SetCSRF failed")
|
|
||||||
}
|
|
||||||
w = httptest.NewRecorder()
|
|
||||||
s.ClearSession(w, r)
|
|
||||||
for _, cookie := range w.Result().Cookies() {
|
|
||||||
if cookie.Name == s.Name && cookie.Value == want {
|
|
||||||
t.Error("clear csrf failed")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCookieStore_SaveSession(t *testing.T) {
|
func TestCookieStore_SaveSession(t *testing.T) {
|
||||||
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
|
cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -265,38 +216,6 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMockCSRFStore(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
mockCSRF *MockCSRFStore
|
|
||||||
newCSRFValue string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"basic",
|
|
||||||
&MockCSRFStore{
|
|
||||||
ResponseCSRF: "ok",
|
|
||||||
Cookie: &http.Cookie{Name: "hi"}},
|
|
||||||
"newcsrf",
|
|
||||||
false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
ms := tt.mockCSRF
|
|
||||||
ms.SetCSRF(nil, nil, tt.newCSRFValue)
|
|
||||||
ms.ClearCSRF(nil, nil)
|
|
||||||
got, err := ms.GetCSRF(nil)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.mockCSRF.Cookie) {
|
|
||||||
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Cookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMockSessionStore(t *testing.T) {
|
func TestMockSessionStore(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -341,7 +260,7 @@ func TestMockSessionStore(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_splitDomain(t *testing.T) {
|
func Test_ParentSubdomain(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
s string
|
s string
|
||||||
|
@ -354,8 +273,8 @@ func Test_splitDomain(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.s, func(t *testing.T) {
|
t.Run(tt.s, func(t *testing.T) {
|
||||||
if got := splitDomain(tt.s); got != tt.want {
|
if got := ParentSubdomain(tt.s); got != tt.want {
|
||||||
t.Errorf("splitDomain() = %v, want %v", got, tt.want)
|
t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
130
internal/sessions/middleware.go
Normal file
130
internal/sessions/middleware.go
Normal file
|
@ -0,0 +1,130 @@
|
||||||
|
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Context keys
|
||||||
|
var (
|
||||||
|
SessionCtxKey = &contextKey{"Session"}
|
||||||
|
ErrorCtxKey = &contextKey{"Error"}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Library errors
|
||||||
|
var (
|
||||||
|
ErrExpired = errors.New("internal/sessions: session is expired")
|
||||||
|
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
|
||||||
|
ErrMalformed = errors.New("internal/sessions: session is malformed")
|
||||||
|
)
|
||||||
|
|
||||||
|
// RetrieveSession http middleware handler will verify a auth session from a http request.
|
||||||
|
//
|
||||||
|
// RetrieveSession will search for a auth session in a http request, in the order:
|
||||||
|
// 1. `pomerium_session` URI query parameter
|
||||||
|
// 2. `Authorization: BEARER` request header
|
||||||
|
// 3. Cookie `_pomerium` value
|
||||||
|
func RetrieveSession(s SessionStore) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return retrieve(s, TokenFromQuery, TokenFromHeader, TokenFromCookie)(next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func retrieve(s SessionStore, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
hfn := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
token, err := retrieveFromRequest(s, r, findTokenFns...)
|
||||||
|
ctx = NewContext(ctx, token, err)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
}
|
||||||
|
return http.HandlerFunc(hfn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func retrieveFromRequest(s SessionStore, r *http.Request, findTokenFns ...func(r *http.Request) string) (*State, error) {
|
||||||
|
var tokenStr string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Extract token string from the request by calling token find functions in
|
||||||
|
// the order they where provided. Further extraction stops if a function
|
||||||
|
// returns a non-empty string.
|
||||||
|
for _, fn := range findTokenFns {
|
||||||
|
tokenStr = fn(r)
|
||||||
|
if tokenStr != "" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tokenStr == "" {
|
||||||
|
return nil, ErrNoSessionFound
|
||||||
|
}
|
||||||
|
|
||||||
|
state, err := s.LoadSession(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrMalformed
|
||||||
|
}
|
||||||
|
err = state.Valid()
|
||||||
|
if err != nil {
|
||||||
|
// a little unusual but we want to return the expired state too
|
||||||
|
return state, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valid!
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewContext sets context values for the user session state and error.
|
||||||
|
func NewContext(ctx context.Context, t *State, err error) context.Context {
|
||||||
|
ctx = context.WithValue(ctx, SessionCtxKey, t)
|
||||||
|
ctx = context.WithValue(ctx, ErrorCtxKey, err)
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromContext retrieves context values for the user session state and error.
|
||||||
|
func FromContext(ctx context.Context) (*State, error) {
|
||||||
|
state, _ := ctx.Value(SessionCtxKey).(*State)
|
||||||
|
err, _ := ctx.Value(ErrorCtxKey).(error)
|
||||||
|
return state, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenFromCookie tries to retrieve the token string from a cookie named
|
||||||
|
// "_pomerium".
|
||||||
|
func TokenFromCookie(r *http.Request) string {
|
||||||
|
cookie, err := r.Cookie("_pomerium")
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return cookie.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenFromHeader tries to retrieve the token string from the
|
||||||
|
// "Authorization" request header: "Authorization: BEARER T".
|
||||||
|
func TokenFromHeader(r *http.Request) string {
|
||||||
|
// Get token from authorization header.
|
||||||
|
bearer := r.Header.Get("Authorization")
|
||||||
|
if len(bearer) > 7 && strings.EqualFold(bearer[0:6], "BEARER") {
|
||||||
|
return bearer[7:]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenFromQuery tries to retrieve the token string from the "pomerium_session" URI
|
||||||
|
// query parameter.
|
||||||
|
// todo(bdd) : document setting session code as queryparam
|
||||||
|
func TokenFromQuery(r *http.Request) string {
|
||||||
|
// Get token from query param named "pomerium_session".
|
||||||
|
return r.URL.Query().Get("pomerium_session")
|
||||||
|
}
|
||||||
|
|
||||||
|
// contextKey is a value for use with context.WithValue. It's used as
|
||||||
|
// a pointer so it fits in an interface{} without allocation. This technique
|
||||||
|
// for defining context keys was copied from Go 1.7's new use of context in net/http.
|
||||||
|
type contextKey struct {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *contextKey) String() string {
|
||||||
|
return "SessionStore context value " + k.name
|
||||||
|
}
|
133
internal/sessions/middleware_test.go
Normal file
133
internal/sessions/middleware_test.go
Normal file
|
@ -0,0 +1,133 @@
|
||||||
|
package sessions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewContext(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ctx context.Context
|
||||||
|
t *State
|
||||||
|
err error
|
||||||
|
want context.Context
|
||||||
|
}{
|
||||||
|
{"simple", context.Background(), &State{Email: "bdd@pomerium.io"}, nil, nil},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctxOut := NewContext(tt.ctx, tt.t, tt.err)
|
||||||
|
stateOut, errOut := FromContext(ctxOut)
|
||||||
|
if diff := cmp.Diff(tt.t, stateOut); diff != "" {
|
||||||
|
t.Errorf("NewContext() = %s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(tt.err, errOut); diff != "" {
|
||||||
|
t.Errorf("NewContext() = %s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAuthorizer(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, err := FromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifier(t *testing.T) {
|
||||||
|
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
// s SessionStore
|
||||||
|
state State
|
||||||
|
|
||||||
|
cookie bool
|
||||||
|
header bool
|
||||||
|
param bool
|
||||||
|
|
||||||
|
wantBody string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{"good cookie session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, true, false, false, http.StatusText(http.StatusOK), http.StatusOK},
|
||||||
|
{"expired cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||||
|
{"malformed cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||||
|
{"good auth header session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, false, http.StatusText(http.StatusOK), http.StatusOK},
|
||||||
|
{"expired auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||||
|
{"malformed auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||||
|
{"good auth query param session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
|
||||||
|
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||||
|
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||||
|
{"no session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cipher, err := cryptutil.NewCipherFromBase64(cryptutil.NewBase64Key())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
encSession, err := MarshalSession(&tt.state, cipher)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if strings.Contains(tt.name, "malformed") {
|
||||||
|
// add some garbage to the end of the string
|
||||||
|
encSession += cryptutil.NewBase64Key()
|
||||||
|
fmt.Println(encSession)
|
||||||
|
}
|
||||||
|
|
||||||
|
cs, err := NewCookieStore(&CookieStoreOptions{
|
||||||
|
Name: "_pomerium",
|
||||||
|
CookieCipher: cipher,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
if tt.cookie {
|
||||||
|
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: encSession})
|
||||||
|
} else if tt.header {
|
||||||
|
r.Header.Set("Authorization", "Bearer "+encSession)
|
||||||
|
} else if tt.param {
|
||||||
|
q := r.URL.Query()
|
||||||
|
q.Add("pomerium_session", encSession)
|
||||||
|
r.URL.RawQuery = q.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
got := RetrieveSession(cs)(testAuthorizer((fnh)))
|
||||||
|
got.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
gotBody := w.Body.String()
|
||||||
|
gotStatus := w.Result().StatusCode
|
||||||
|
|
||||||
|
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||||
|
t.Errorf("RetrieveSession() = %v", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||||
|
t.Errorf("RetrieveSession() = %v", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,28 +4,6 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockCSRFStore is a mock implementation of the CSRF store interface
|
|
||||||
type MockCSRFStore struct {
|
|
||||||
ResponseCSRF string
|
|
||||||
Cookie *http.Cookie
|
|
||||||
GetError error
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCSRF sets the ResponseCSRF string to a val
|
|
||||||
func (ms MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
|
|
||||||
ms.ResponseCSRF = val
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearCSRF clears the ResponseCSRF string
|
|
||||||
func (ms MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
|
|
||||||
ms.ResponseCSRF = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCSRF returns the cookie and error
|
|
||||||
func (ms MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
|
|
||||||
return ms.Cookie, ms.GetError
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockSessionStore is a mock implementation of the SessionStore interface
|
// MockSessionStore is a mock implementation of the SessionStore interface
|
||||||
type MockSessionStore struct {
|
type MockSessionStore struct {
|
||||||
ResponseSession string
|
ResponseSession string
|
||||||
|
|
|
@ -10,9 +10,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrExpired is an error for a expired sessions.
|
|
||||||
var ErrExpired = fmt.Errorf("internal/sessions: expired session")
|
|
||||||
|
|
||||||
// State is our object that keeps track of a user's session state
|
// State is our object that keeps track of a user's session state
|
||||||
type State struct {
|
type State struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
|
|
|
@ -12,7 +12,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestStateSerialization(t *testing.T) {
|
func TestStateSerialization(t *testing.T) {
|
||||||
secret := cryptutil.GenerateKey()
|
secret := cryptutil.NewKey()
|
||||||
c, err := cryptutil.NewCipher(secret)
|
c, err := cryptutil.NewCipher(secret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||||
|
@ -123,7 +123,7 @@ func TestState_Impersonating(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMarshalSession(t *testing.T) {
|
func TestMarshalSession(t *testing.T) {
|
||||||
secret := cryptutil.GenerateKey()
|
secret := cryptutil.NewKey()
|
||||||
c, err := cryptutil.NewCipher(secret)
|
c, err := cryptutil.NewCipher(secret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||||
|
|
|
@ -8,16 +8,6 @@ import (
|
||||||
// ErrEmptySession is an error for an empty sessions.
|
// ErrEmptySession is an error for an empty sessions.
|
||||||
var ErrEmptySession = errors.New("internal/sessions: empty session")
|
var ErrEmptySession = errors.New("internal/sessions: empty session")
|
||||||
|
|
||||||
// ErrEmptyCSRF is an error for an empty sessions.
|
|
||||||
var ErrEmptyCSRF = errors.New("internal/sessions: empty csrf")
|
|
||||||
|
|
||||||
// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie
|
|
||||||
type CSRFStore interface {
|
|
||||||
SetCSRF(http.ResponseWriter, *http.Request, string)
|
|
||||||
GetCSRF(*http.Request) (*http.Cookie, error)
|
|
||||||
ClearCSRF(http.ResponseWriter, *http.Request)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SessionStore has the functions for setting, getting, and clearing the Session cookie
|
// SessionStore has the functions for setting, getting, and clearing the Session cookie
|
||||||
type SessionStore interface {
|
type SessionStore interface {
|
||||||
ClearSession(http.ResponseWriter, *http.Request)
|
ClearSession(http.ResponseWriter, *http.Request)
|
||||||
|
|
|
@ -306,7 +306,7 @@ func New() *template.Template {
|
||||||
</svg>
|
</svg>
|
||||||
<form method="POST" action="{{.SignoutURL}}">
|
<form method="POST" action="{{.SignoutURL}}">
|
||||||
<section>
|
<section>
|
||||||
<h2>Session</h2>
|
<h2>Current user</h2>
|
||||||
<p class="message">Your current session details.</p>
|
<p class="message">Your current session details.</p>
|
||||||
<fieldset>
|
<fieldset>
|
||||||
<label>
|
<label>
|
||||||
|
@ -334,10 +334,22 @@ func New() *template.Template {
|
||||||
</fieldset>
|
</fieldset>
|
||||||
</section>
|
</section>
|
||||||
<div class="flex">
|
<div class="flex">
|
||||||
<button class="button half" type="submit">Sign Out</button>
|
{{ .csrfField }}
|
||||||
<a href="/.pomerium/refresh" class="button half">Refresh</a>
|
<button class="button full" type="submit">Sign Out</button>
|
||||||
</div>
|
</div>
|
||||||
</form>
|
</form>
|
||||||
|
<section>
|
||||||
|
<h2>Refresh Identity</h2>
|
||||||
|
<p class="message">Pomerium will automatically refresh your user session. However, if your group memberships have recently changed and haven't taken effect yet, you can refresh your session manually.</p>
|
||||||
|
|
||||||
|
<form method="POST" action="/.pomerium/refresh">
|
||||||
|
<div class="flex">
|
||||||
|
{{ .csrfField }}
|
||||||
|
<button class="button full" type="submit">Refresh</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</section>
|
||||||
|
|
||||||
{{if .IsAdmin}}
|
{{if .IsAdmin}}
|
||||||
<form method="POST" action="/.pomerium/impersonate">
|
<form method="POST" action="/.pomerium/impersonate">
|
||||||
<section>
|
<section>
|
||||||
|
@ -355,7 +367,7 @@ func New() *template.Template {
|
||||||
</fieldset>
|
</fieldset>
|
||||||
</section>
|
</section>
|
||||||
<div class="flex">
|
<div class="flex">
|
||||||
<input name="csrf" type="hidden" value="{{.CSRF}}">
|
{{ .csrfField }}
|
||||||
<button class="button full" type="submit">Impersonate session</button>
|
<button class="button full" type="submit">Impersonate session</button>
|
||||||
</div>
|
</div>
|
||||||
</form>
|
</form>
|
||||||
|
|
|
@ -1,9 +1,14 @@
|
||||||
package urlutil // import "github.com/pomerium/pomerium/internal/urlutil"
|
package urlutil // import "github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StripPort returns a host, without any port number.
|
// StripPort returns a host, without any port number.
|
||||||
|
@ -32,18 +37,73 @@ func ParseAndValidateURL(rawurl string) (*url.URL, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if u.Scheme == "" {
|
if err := ValidateURL(u); err != nil {
|
||||||
return nil, fmt.Errorf("%s url does contain a valid scheme. Did you mean https://%s?", rawurl, rawurl)
|
return nil, err
|
||||||
}
|
|
||||||
if u.Host == "" {
|
|
||||||
return nil, fmt.Errorf("%s url does contain a valid hostname", rawurl)
|
|
||||||
}
|
}
|
||||||
return u, nil
|
return u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateURL wraps standard library's default url.Parse because
|
||||||
|
// it's much more lenient about what type of urls it accepts than pomerium.
|
||||||
|
func ValidateURL(u *url.URL) error {
|
||||||
|
if u == nil {
|
||||||
|
return fmt.Errorf("nil url")
|
||||||
|
}
|
||||||
|
if u.Scheme == "" {
|
||||||
|
return fmt.Errorf("%s url does contain a valid scheme. Did you mean https://%s?", u.String(), u.String())
|
||||||
|
}
|
||||||
|
if u.Host == "" {
|
||||||
|
return fmt.Errorf("%s url does contain a valid hostname", u.String())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func DeepCopy(u *url.URL) (*url.URL, error) {
|
func DeepCopy(u *url.URL) (*url.URL, error) {
|
||||||
if u == nil {
|
if u == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return ParseAndValidateURL(u.String())
|
return ParseAndValidateURL(u.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testTimeNow can be used in tests to set a specific int64 time
|
||||||
|
var testTimeNow int64
|
||||||
|
|
||||||
|
// timestamp returns the current timestamp, in seconds.
|
||||||
|
//
|
||||||
|
// For testing purposes, the function that generates the timestamp can be
|
||||||
|
// overridden. If not set, it will return time.Now().UTC().Unix().
|
||||||
|
func timestamp() int64 {
|
||||||
|
if testTimeNow == 0 {
|
||||||
|
return time.Now().UTC().Unix()
|
||||||
|
}
|
||||||
|
return testTimeNow
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignedRedirectURL takes a destination URL and adds redirect_uri to it's
|
||||||
|
// query params, along with a timestamp and an keyed signature.
|
||||||
|
func SignedRedirectURL(key string, destination, urlToSign *url.URL) *url.URL {
|
||||||
|
now := timestamp()
|
||||||
|
rawURL := urlToSign.String()
|
||||||
|
params, _ := url.ParseQuery(destination.RawQuery) // handled by incoming mux
|
||||||
|
params.Set("redirect_uri", rawURL)
|
||||||
|
params.Set("ts", fmt.Sprint(now))
|
||||||
|
params.Set("sig", hmacURL(key, rawURL, now))
|
||||||
|
destination.RawQuery = params.Encode()
|
||||||
|
return destination
|
||||||
|
}
|
||||||
|
|
||||||
|
// hmacURL takes a redirect url string and timestamp and returns the base64
|
||||||
|
// encoded HMAC result.
|
||||||
|
func hmacURL(key, data string, timestamp int64) string {
|
||||||
|
h := cryptutil.Hash(key, []byte(fmt.Sprint(data, timestamp)))
|
||||||
|
return base64.URLEncoding.EncodeToString(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAbsoluteURL returns the current handler's absolute url.
|
||||||
|
// https://stackoverflow.com/a/23152483
|
||||||
|
func GetAbsoluteURL(r *http.Request) *url.URL {
|
||||||
|
u := r.URL
|
||||||
|
u.Scheme = "https"
|
||||||
|
u.Host = r.Host
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package urlutil
|
package urlutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -35,7 +36,7 @@ func Test_StripPort(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseAndValidateURL(t *testing.T) {
|
func TestParseAndValidateURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
rawurl string
|
rawurl string
|
||||||
|
@ -63,7 +64,7 @@ func TestParseAndValidateURL(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeepCopy(t *testing.T) {
|
func TestDeepCopy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
u *url.URL
|
u *url.URL
|
||||||
|
@ -87,3 +88,90 @@ func TestDeepCopy(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
u *url.URL
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"good", &url.URL{Scheme: "https", Host: "some.example"}, false},
|
||||||
|
{"nil", nil, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := ValidateURL(tt.u); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("ValidateURL() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignedRedirectURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mockedTime int64
|
||||||
|
key string
|
||||||
|
destination *url.URL
|
||||||
|
urlToSign *url.URL
|
||||||
|
want *url.URL
|
||||||
|
}{
|
||||||
|
{"good", 2, "hunter42", &url.URL{Host: "pomerium.io", Scheme: "https://"}, &url.URL{Host: "pomerium.io", Scheme: "https://", Path: "/ok"}, &url.URL{Host: "pomerium.io", Scheme: "https://", RawQuery: "redirect_uri=https%3A%2F%2F%3A%2F%2Fpomerium.io%2Fok&sig=7jdo1XFcmuhjBHnpfVhll5cXflYByeMnbp5kRz87CVQ%3D&ts=2"}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
testTimeNow = tt.mockedTime
|
||||||
|
got := SignedRedirectURL(tt.key, tt.destination, tt.urlToSign)
|
||||||
|
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||||
|
t.Errorf("SignedRedirectURL() = diff %v", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_timestamp(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
dontWant int64
|
||||||
|
}{
|
||||||
|
{"if unset should never return", 0},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
testTimeNow = tt.dontWant
|
||||||
|
if got := timestamp(); got == tt.dontWant {
|
||||||
|
t.Errorf("timestamp() = %v, dontWant %v", got, tt.dontWant)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseURLHelper(s string) *url.URL {
|
||||||
|
u, _ := url.Parse(s)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAbsoluteURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
u *url.URL
|
||||||
|
want *url.URL
|
||||||
|
}{
|
||||||
|
{"add https", parseURLHelper("http://pomerium.io"), parseURLHelper("https://pomerium.io")},
|
||||||
|
{"missing scheme", parseURLHelper("https://pomerium.io"), parseURLHelper("https://pomerium.io")},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r := http.Request{URL: tt.u, Host: tt.u.Host}
|
||||||
|
got := GetAbsoluteURL(&r)
|
||||||
|
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||||
|
t.Errorf("GetAbsoluteURL() = %v", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,15 +1,15 @@
|
||||||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/csrf"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/config"
|
"github.com/pomerium/pomerium/internal/config"
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
@ -18,34 +18,55 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StateParameter holds the redirect id along with the session id.
|
|
||||||
type StateParameter struct {
|
|
||||||
SessionID string `json:"session_id"`
|
|
||||||
RedirectURI string `json:"redirect_uri"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handler returns the proxy service's ServeMux
|
// Handler returns the proxy service's ServeMux
|
||||||
func (p *Proxy) Handler() http.Handler {
|
func (p *Proxy) Handler() http.Handler {
|
||||||
// validation middleware chain
|
r := httputil.NewRouter().StrictSlash(true)
|
||||||
validate := middleware.NewChain()
|
r.Use(middleware.ValidateHost(func(host string) bool {
|
||||||
validate = validate.Append(middleware.ValidateHost(func(host string) bool {
|
|
||||||
_, ok := p.routeConfigs[host]
|
_, ok := p.routeConfigs[host]
|
||||||
return ok
|
return ok
|
||||||
}))
|
}))
|
||||||
mux := http.NewServeMux()
|
r.Use(csrf.Protect(
|
||||||
mux.HandleFunc("/robots.txt", p.RobotsTxt)
|
p.cookieSecret,
|
||||||
mux.HandleFunc("/.pomerium", p.UserDashboard)
|
csrf.Path("/"),
|
||||||
mux.HandleFunc("/.pomerium/impersonate", p.Impersonate) // POST
|
csrf.Domain(p.cookieDomain),
|
||||||
mux.HandleFunc("/.pomerium/sign_out", p.SignOut)
|
csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieName)),
|
||||||
// handlers with validation
|
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
|
||||||
mux.Handle("/.pomerium/callback", validate.ThenFunc(p.AuthenticateCallback))
|
))
|
||||||
mux.Handle("/.pomerium/refresh", validate.ThenFunc(p.ForceRefresh))
|
r.HandleFunc("/robots.txt", p.RobotsTxt)
|
||||||
mux.Handle("/", validate.ThenFunc(p.Proxy))
|
// requires authN not authZ
|
||||||
return mux
|
r.Use(sessions.RetrieveSession(p.sessionStore))
|
||||||
|
r.Use(p.VerifySession)
|
||||||
|
r.HandleFunc("/.pomerium/", p.UserDashboard).Methods(http.MethodGet)
|
||||||
|
r.HandleFunc("/.pomerium/impersonate", p.Impersonate).Methods(http.MethodPost)
|
||||||
|
r.HandleFunc("/.pomerium/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost)
|
||||||
|
r.HandleFunc("/.pomerium/refresh", p.ForceRefresh).Methods(http.MethodPost)
|
||||||
|
r.PathPrefix("/").HandlerFunc(p.Proxy)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifySession is the middleware used to enforce a valid authentication
|
||||||
|
// session state is attached to the users's request context.
|
||||||
|
func (p *Proxy) VerifySession(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, err := sessions.FromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to session state error")
|
||||||
|
p.authenticate(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := state.Valid(); err != nil {
|
||||||
|
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to invalid session")
|
||||||
|
p.authenticate(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RobotsTxt sets the User-Agent header in the response to be "Disallow"
|
// RobotsTxt sets the User-Agent header in the response to be "Disallow"
|
||||||
func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
|
func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
||||||
}
|
}
|
||||||
|
@ -55,110 +76,18 @@ func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
|
||||||
// the local session state.
|
// the local session state.
|
||||||
func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||||
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
|
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
|
||||||
switch r.Method {
|
if uri, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")); err == nil && uri.String() != "" {
|
||||||
case http.MethodPost:
|
redirectURL = uri
|
||||||
if err := r.ParseForm(); err != nil {
|
|
||||||
httputil.ErrorResponse(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
uri, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri"))
|
|
||||||
if err == nil && uri.String() != "" {
|
|
||||||
redirectURL = uri
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
uri, err := urlutil.ParseAndValidateURL(r.URL.Query().Get("redirect_uri"))
|
|
||||||
if err == nil && uri.String() != "" {
|
|
||||||
redirectURL = uri
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
http.Redirect(w, r, p.GetSignOutURL(p.authenticateURL, redirectURL).String(), http.StatusFound)
|
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL)
|
||||||
|
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuthStart begins the authenticate flow, encrypting the redirect url
|
// Authenticate begins the authenticate flow, encrypting the redirect url
|
||||||
// in a request to the provider's sign in endpoint.
|
// in a request to the provider's sign in endpoint.
|
||||||
func (p *Proxy) OAuthStart(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request) {
|
||||||
state := &StateParameter{
|
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
|
||||||
SessionID: fmt.Sprintf("%x", cryptutil.GenerateKey()),
|
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||||
RedirectURI: r.URL.String(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encrypt CSRF + redirect_uri and store in csrf session. Validated on callback.
|
|
||||||
csrfState, err := p.cipher.Marshal(state)
|
|
||||||
if err != nil {
|
|
||||||
httputil.ErrorResponse(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
p.csrfStore.SetCSRF(w, r, csrfState)
|
|
||||||
|
|
||||||
paramState, err := p.cipher.Marshal(state)
|
|
||||||
if err != nil {
|
|
||||||
httputil.ErrorResponse(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sanity check. The encrypted payload of local and remote state should
|
|
||||||
// never match as each encryption round uses a cryptographic nonce.
|
|
||||||
// if paramState == csrfState {
|
|
||||||
// httputil.ErrorResponse(w, r, httputil.Error("encrypted state should not match", http.StatusBadRequest, nil))
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
|
|
||||||
signinURL := p.GetSignInURL(p.authenticateURL, p.GetRedirectURL(r.Host), paramState)
|
|
||||||
|
|
||||||
// Redirect the user to the authenticate service along with the encrypted
|
|
||||||
// state which contains a redirect uri back to the proxy and a nonce
|
|
||||||
http.Redirect(w, r, signinURL.String(), http.StatusFound)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthenticateCallback checks the state parameter to make sure it matches the
|
|
||||||
// local csrf state then redirects the user back to the original intended route.
|
|
||||||
func (p *Proxy) AuthenticateCallback(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if err := r.ParseForm(); err != nil {
|
|
||||||
httputil.ErrorResponse(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encrypted CSRF passed from authenticate service
|
|
||||||
remoteStateEncrypted := r.Form.Get("state")
|
|
||||||
var remoteStatePlain StateParameter
|
|
||||||
if err := p.cipher.Unmarshal(remoteStateEncrypted, &remoteStatePlain); err != nil {
|
|
||||||
httputil.ErrorResponse(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c, err := p.csrfStore.GetCSRF(r)
|
|
||||||
if err != nil {
|
|
||||||
httputil.ErrorResponse(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
p.csrfStore.ClearCSRF(w, r)
|
|
||||||
|
|
||||||
localStateEncrypted := c.Value
|
|
||||||
var localStatePlain StateParameter
|
|
||||||
err = p.cipher.Unmarshal(localStateEncrypted, &localStatePlain)
|
|
||||||
if err != nil {
|
|
||||||
httputil.ErrorResponse(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// assert no nonce reuse
|
|
||||||
if remoteStateEncrypted == localStateEncrypted {
|
|
||||||
p.sessionStore.ClearSession(w, r)
|
|
||||||
httputil.ErrorResponse(w, r,
|
|
||||||
httputil.Error("local and remote state", http.StatusBadRequest,
|
|
||||||
fmt.Errorf("possible nonce-reuse / replay attack")))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decrypted remote and local state struct (inc. nonce) must match
|
|
||||||
if remoteStatePlain.SessionID != localStatePlain.SessionID {
|
|
||||||
p.sessionStore.ClearSession(w, r)
|
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("CSRF mismatch", http.StatusBadRequest, nil))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// This is the redirect back to the original requested application
|
|
||||||
http.Redirect(w, r, remoteStatePlain.RedirectURI, http.StatusFound)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// shouldSkipAuthentication contains conditions for skipping authentication.
|
// shouldSkipAuthentication contains conditions for skipping authentication.
|
||||||
|
@ -189,17 +118,6 @@ func isCORSPreflight(r *http.Request) bool {
|
||||||
r.Header.Get("Origin") != ""
|
r.Header.Get("Origin") != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) loadExistingSession(r *http.Request) (*sessions.State, error) {
|
|
||||||
s, err := p.sessionStore.LoadSession(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("proxy: invalid session: %w", err)
|
|
||||||
}
|
|
||||||
if err := s.Valid(); err != nil {
|
|
||||||
return nil, fmt.Errorf("proxy: invalid state: %w", err)
|
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Proxy authenticates a request, either proxying the request if it is authenticated,
|
// Proxy authenticates a request, either proxying the request if it is authenticated,
|
||||||
// or starting the authenticate service for validation if not.
|
// or starting the authenticate service for validation if not.
|
||||||
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -214,11 +132,10 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||||
route.ServeHTTP(w, r)
|
route.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
s, err := sessions.FromContext(r.Context())
|
||||||
s, err := p.loadExistingSession(r)
|
if err != nil || s == nil {
|
||||||
if err != nil {
|
log.Debug().Err(err).Msg("proxy: couldn't get session from context")
|
||||||
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
|
p.authenticate(w, r)
|
||||||
p.OAuthStart(w, r)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s)
|
authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s)
|
||||||
|
@ -226,7 +143,7 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||||
httputil.ErrorResponse(w, r, err)
|
httputil.ErrorResponse(w, r, err)
|
||||||
return
|
return
|
||||||
} else if !authorized {
|
} else if !authorized {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not authorized for this route", s.Email), http.StatusForbidden, nil))
|
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not authorized for this route", s.RequestEmail()), http.StatusForbidden, nil))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.Header.Set(HeaderUserID, s.User)
|
r.Header.Set(HeaderUserID, s.User)
|
||||||
|
@ -240,62 +157,41 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||||
// It also contains certain administrative actions like user impersonation.
|
// It also contains certain administrative actions like user impersonation.
|
||||||
// Nota bene: This endpoint does authentication, not authorization.
|
// Nota bene: This endpoint does authentication, not authorization.
|
||||||
func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
|
||||||
session, err := p.loadExistingSession(r)
|
session, err := sessions.FromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
|
httputil.ErrorResponse(w, r, err)
|
||||||
p.OAuthStart(w, r)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/.pomerium/sign_out"}
|
|
||||||
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
|
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(w, r, err)
|
httputil.ErrorResponse(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
//todo(bdd): make sign out redirect a configuration option so that
|
||||||
// CSRF value used to mitigate replay attacks.
|
// admins can set to whatever their corporate homepage is
|
||||||
csrf := &StateParameter{SessionID: fmt.Sprintf("%x", cryptutil.GenerateKey())}
|
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
|
||||||
csrfCookie, err := p.cipher.Marshal(csrf)
|
signoutURL := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL)
|
||||||
if err != nil {
|
templates.New().ExecuteTemplate(w, "dashboard.html", map[string]interface{}{
|
||||||
httputil.ErrorResponse(w, r, err)
|
"Email": session.Email,
|
||||||
return
|
"User": session.User,
|
||||||
}
|
"Groups": session.Groups,
|
||||||
p.csrfStore.SetCSRF(w, r, csrfCookie)
|
"RefreshDeadline": time.Until(session.RefreshDeadline).Round(time.Second).String(),
|
||||||
|
"SignoutURL": signoutURL.String(),
|
||||||
t := struct {
|
"IsAdmin": isAdmin,
|
||||||
Email string
|
"ImpersonateEmail": session.ImpersonateEmail,
|
||||||
User string
|
"ImpersonateGroup": strings.Join(session.ImpersonateGroups, ","),
|
||||||
Groups []string
|
"csrfField": csrf.TemplateField(r),
|
||||||
RefreshDeadline string
|
})
|
||||||
SignoutURL string
|
|
||||||
|
|
||||||
IsAdmin bool
|
|
||||||
ImpersonateEmail string
|
|
||||||
ImpersonateGroup string
|
|
||||||
CSRF string
|
|
||||||
}{
|
|
||||||
Email: session.Email,
|
|
||||||
User: session.User,
|
|
||||||
Groups: session.Groups,
|
|
||||||
RefreshDeadline: time.Until(session.RefreshDeadline).Round(time.Second).String(),
|
|
||||||
SignoutURL: p.GetSignOutURL(p.authenticateURL, redirectURL).String(),
|
|
||||||
IsAdmin: isAdmin,
|
|
||||||
ImpersonateEmail: session.ImpersonateEmail,
|
|
||||||
ImpersonateGroup: strings.Join(session.ImpersonateGroups, ","),
|
|
||||||
CSRF: csrf.SessionID,
|
|
||||||
}
|
|
||||||
templates.New().ExecuteTemplate(w, "dashboard.html", t)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForceRefresh redeems and extends an existing authenticated oidc session with
|
// ForceRefresh redeems and extends an existing authenticated oidc session with
|
||||||
// the underlying identity provider. All session details including groups,
|
// the underlying identity provider. All session details including groups,
|
||||||
// timeouts, will be renewed.
|
// timeouts, will be renewed.
|
||||||
func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) {
|
||||||
session, err := p.loadExistingSession(r)
|
session, err := sessions.FromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
|
httputil.ErrorResponse(w, r, err)
|
||||||
p.OAuthStart(w, r)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
iss, err := session.IssuedAt()
|
iss, err := session.IssuedAt()
|
||||||
|
@ -324,49 +220,25 @@ func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) {
|
||||||
// to the user's current user sessions state if the user is currently an
|
// to the user's current user sessions state if the user is currently an
|
||||||
// administrative user. Requests are redirected back to the user dashboard.
|
// administrative user. Requests are redirected back to the user dashboard.
|
||||||
func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method == http.MethodPost {
|
session, err := sessions.FromContext(r.Context())
|
||||||
if err := r.ParseForm(); err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(w, r, err)
|
httputil.ErrorResponse(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
|
||||||
session, err := p.loadExistingSession(r)
|
|
||||||
if err != nil {
|
|
||||||
log.Debug().Str("cause", err.Error()).Msg("proxy: bad authN session, redirecting")
|
|
||||||
p.OAuthStart(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
|
|
||||||
if err != nil || !isAdmin {
|
|
||||||
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not an administrator", session.Email), http.StatusForbidden, err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// CSRF check -- did this request originate from our form?
|
|
||||||
c, err := p.csrfStore.GetCSRF(r)
|
|
||||||
if err != nil {
|
|
||||||
httputil.ErrorResponse(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
p.csrfStore.ClearCSRF(w, r)
|
|
||||||
encryptedCSRF := c.Value
|
|
||||||
var decryptedCSRF StateParameter
|
|
||||||
if err = p.cipher.Unmarshal(encryptedCSRF, decryptedCSRF); err != nil {
|
|
||||||
httputil.ErrorResponse(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if decryptedCSRF.SessionID != r.FormValue("csrf") {
|
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("CSRF mismatch", http.StatusBadRequest, nil))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// OK to impersonation
|
|
||||||
session.ImpersonateEmail = r.FormValue("email")
|
|
||||||
session.ImpersonateGroups = strings.Split(r.FormValue("group"), ",")
|
|
||||||
|
|
||||||
if err := p.sessionStore.SaveSession(w, r, session); err != nil {
|
|
||||||
httputil.ErrorResponse(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
|
||||||
|
if err != nil || !isAdmin {
|
||||||
|
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not an administrator", session.RequestEmail()), http.StatusForbidden, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// OK to impersonation
|
||||||
|
session.ImpersonateEmail = r.FormValue("email")
|
||||||
|
session.ImpersonateGroups = strings.Split(r.FormValue("group"), ",")
|
||||||
|
|
||||||
|
if err := p.sessionStore.SaveSession(w, r, session); err != nil {
|
||||||
|
httputil.ErrorResponse(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
http.Redirect(w, r, "/.pomerium", http.StatusFound)
|
http.Redirect(w, r, "/.pomerium", http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -391,48 +263,3 @@ func (p *Proxy) policy(r *http.Request) (*config.Policy, bool) {
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRedirectURL returns the redirect url for a single reverse proxy host. HTTPS is set explicitly.
|
|
||||||
func (p *Proxy) GetRedirectURL(host string) *url.URL {
|
|
||||||
u := p.redirectURL
|
|
||||||
u.Scheme = "https"
|
|
||||||
u.Host = host
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// signRedirectURL takes a redirect url string and timestamp and returns the base64
|
|
||||||
// encoded HMAC result.
|
|
||||||
func (p *Proxy) signRedirectURL(rawRedirect string, timestamp time.Time) string {
|
|
||||||
data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix()))
|
|
||||||
h := cryptutil.Hash(p.SharedKey, data)
|
|
||||||
return base64.URLEncoding.EncodeToString(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSignInURL with typical oauth parameters
|
|
||||||
func (p *Proxy) GetSignInURL(authenticateURL, redirectURL *url.URL, state string) *url.URL {
|
|
||||||
a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_in"})
|
|
||||||
now := time.Now()
|
|
||||||
rawRedirect := redirectURL.String()
|
|
||||||
params, _ := url.ParseQuery(a.RawQuery) // handled by ServeMux
|
|
||||||
params.Set("redirect_uri", rawRedirect)
|
|
||||||
params.Set("shared_secret", p.SharedKey)
|
|
||||||
params.Set("response_type", "code")
|
|
||||||
params.Add("state", state)
|
|
||||||
params.Set("ts", fmt.Sprint(now.Unix()))
|
|
||||||
params.Set("sig", p.signRedirectURL(rawRedirect, now))
|
|
||||||
a.RawQuery = params.Encode()
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSignOutURL creates and returns the sign out URL, given a redirectURL
|
|
||||||
func (p *Proxy) GetSignOutURL(authenticateURL, redirectURL *url.URL) *url.URL {
|
|
||||||
a := authenticateURL.ResolveReference(&url.URL{Path: "/sign_out"})
|
|
||||||
now := time.Now()
|
|
||||||
rawRedirect := redirectURL.String()
|
|
||||||
params, _ := url.ParseQuery(a.RawQuery) // handled by ServeMux
|
|
||||||
params.Add("redirect_uri", rawRedirect)
|
|
||||||
params.Set("ts", fmt.Sprint(now.Unix()))
|
|
||||||
params.Set("sig", p.signRedirectURL(rawRedirect, now))
|
|
||||||
a.RawQuery = params.Encode()
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
|
@ -7,45 +7,17 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/config"
|
"github.com/pomerium/pomerium/internal/config"
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/proxy/clients"
|
"github.com/pomerium/pomerium/proxy/clients"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockCipher struct{}
|
|
||||||
|
|
||||||
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
|
|
||||||
if string(s) == "error" {
|
|
||||||
return []byte(""), errors.New("error encrypting")
|
|
||||||
}
|
|
||||||
return []byte("OK"), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
|
|
||||||
if string(s) == "error" {
|
|
||||||
return []byte(""), errors.New("error encrypting")
|
|
||||||
}
|
|
||||||
return []byte("OK"), nil
|
|
||||||
}
|
|
||||||
func (a mockCipher) Marshal(s interface{}) (string, error) {
|
|
||||||
if s == "error" {
|
|
||||||
return "", errors.New("error")
|
|
||||||
}
|
|
||||||
return "ok", nil
|
|
||||||
}
|
|
||||||
func (a mockCipher) Unmarshal(s string, i interface{}) error {
|
|
||||||
if s == "unmarshal error" || s == "error" {
|
|
||||||
return errors.New("error")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxy_RobotsTxt(t *testing.T) {
|
func TestProxy_RobotsTxt(t *testing.T) {
|
||||||
proxy := Proxy{}
|
proxy := Proxy{}
|
||||||
req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil)
|
req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil)
|
||||||
|
@ -60,94 +32,6 @@ func TestProxy_RobotsTxt(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxy_GetRedirectURL(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
host string
|
|
||||||
want *url.URL
|
|
||||||
}{
|
|
||||||
{"google", "google.com", &url.URL{Scheme: "https", Host: "google.com", Path: "/.pomerium/callback"}},
|
|
||||||
{"pomerium", "pomerium.io", &url.URL{Scheme: "https", Host: "pomerium.io", Path: "/.pomerium/callback"}},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
p := &Proxy{redirectURL: &url.URL{Path: "/.pomerium/callback"}}
|
|
||||||
if got := p.GetRedirectURL(tt.host); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("Proxy.GetRedirectURL() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxy_signRedirectURL(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
rawRedirect string
|
|
||||||
timestamp time.Time
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"pomerium", "https://pomerium.io/.pomerium/callback", fixedDate, "wq3rAjRGN96RXS8TAzH-uxQTD0XgY_8ZYEKMiOLD5P4="},
|
|
||||||
{"google", "https://google.com/.pomerium/callback", fixedDate, "7EYHZObq167CuyuPm5CqOtkU4zg5dFeUCs7W7QOrgNQ="},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
p := &Proxy{}
|
|
||||||
if got := p.signRedirectURL(tt.rawRedirect, tt.timestamp); got != tt.want {
|
|
||||||
t.Errorf("Proxy.signRedirectURL() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxy_GetSignOutURL(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
authenticate string
|
|
||||||
redirect string
|
|
||||||
wantPrefix string
|
|
||||||
}{
|
|
||||||
{"good", "https://auth.corp.pomerium.io", "https://hello.corp.pomerium.io", "https://auth.corp.pomerium.io/sign_out?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
authenticateURL, _ := url.Parse(tt.authenticate)
|
|
||||||
redirectURL, _ := url.Parse(tt.redirect)
|
|
||||||
|
|
||||||
p := &Proxy{}
|
|
||||||
// signature is ignored as it is tested above. Avoids testing time.Now
|
|
||||||
if got := p.GetSignOutURL(authenticateURL, redirectURL); !strings.HasPrefix(got.String(), tt.wantPrefix) {
|
|
||||||
t.Errorf("Proxy.GetSignOutURL() = %v, wantPrefix %v", got.String(), tt.wantPrefix)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxy_GetSignInURL(t *testing.T) {
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
authenticate string
|
|
||||||
redirect string
|
|
||||||
state string
|
|
||||||
|
|
||||||
wantPrefix string
|
|
||||||
}{
|
|
||||||
{"good", "https://auth.corp.pomerium.io", "https://hello.corp.pomerium.io", "example_state", "https://auth.corp.pomerium.io/sign_in?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io&response_type=code&shared_secret=shared-secret"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
p := &Proxy{SharedKey: "shared-secret"}
|
|
||||||
authenticateURL, _ := url.Parse(tt.authenticate)
|
|
||||||
redirectURL, _ := url.Parse(tt.redirect)
|
|
||||||
|
|
||||||
if got := p.GetSignInURL(authenticateURL, redirectURL, tt.state); !strings.HasPrefix(got.String(), tt.wantPrefix) {
|
|
||||||
t.Errorf("Proxy.GetSignOutURL() = %v, wantPrefix %v", got.String(), tt.wantPrefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxy_Signout(t *testing.T) {
|
func TestProxy_Signout(t *testing.T) {
|
||||||
opts := testOptions(t)
|
opts := testOptions(t)
|
||||||
err := ValidateOptions(opts)
|
err := ValidateOptions(opts)
|
||||||
|
@ -171,7 +55,7 @@ func TestProxy_Signout(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxy_OAuthStart(t *testing.T) {
|
func TestProxy_authenticate(t *testing.T) {
|
||||||
proxy, err := New(testOptions(t))
|
proxy, err := New(testOptions(t))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -179,18 +63,19 @@ func TestProxy_OAuthStart(t *testing.T) {
|
||||||
req := httptest.NewRequest(http.MethodGet, "/oauth-start", nil)
|
req := httptest.NewRequest(http.MethodGet, "/oauth-start", nil)
|
||||||
|
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
proxy.OAuthStart(rr, req)
|
proxy.authenticate(rr, req)
|
||||||
// expect oauth redirect
|
// expect oauth redirect
|
||||||
if status := rr.Code; status != http.StatusFound {
|
if status := rr.Code; status != http.StatusFound {
|
||||||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
|
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
|
||||||
}
|
}
|
||||||
// expected url
|
// expected url
|
||||||
expected := `<a href="https://authenticate.example/sign_in`
|
expected := `<a href="https://authenticate.example/.pomerium/sign_in`
|
||||||
body := rr.Body.String()
|
body := rr.Body.String()
|
||||||
if !strings.HasPrefix(body, expected) {
|
if !strings.HasPrefix(body, expected) {
|
||||||
t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
|
t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxy_Handler(t *testing.T) {
|
func TestProxy_Handler(t *testing.T) {
|
||||||
proxy, err := New(testOptions(t))
|
proxy, err := New(testOptions(t))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -226,7 +111,6 @@ func TestProxy_router(t *testing.T) {
|
||||||
{"good corp", "https://corp.example.com", policies, nil, true},
|
{"good corp", "https://corp.example.com", policies, nil, true},
|
||||||
{"good with slash", "https://corp.example.com/", policies, nil, true},
|
{"good with slash", "https://corp.example.com/", policies, nil, true},
|
||||||
{"good with path", "https://corp.example.com/123", policies, nil, true},
|
{"good with path", "https://corp.example.com/123", policies, nil, true},
|
||||||
// {"multiple", "https://corp.example.com/", map[string]string{"corp.unrelated.com": "unrelated.com", "corp.example.com": "example.com"}, nil, true},
|
|
||||||
{"no policies", "https://notcorp.example.com/123", []config.Policy{}, nil, false},
|
{"no policies", "https://notcorp.example.com/123", []config.Policy{}, nil, false},
|
||||||
{"bad corp", "https://notcorp.example.com/123", policies, nil, false},
|
{"bad corp", "https://notcorp.example.com/123", policies, nil, false},
|
||||||
{"bad sub-sub", "https://notcorp.corp.example.com/123", policies, nil, false},
|
{"bad sub-sub", "https://notcorp.corp.example.com/123", policies, nil, false},
|
||||||
|
@ -254,11 +138,15 @@ func TestProxy_Proxy(t *testing.T) {
|
||||||
goodSession := &sessions.State{
|
goodSession := &sessions.State{
|
||||||
AccessToken: "AccessToken",
|
AccessToken: "AccessToken",
|
||||||
RefreshToken: "RefreshToken",
|
RefreshToken: "RefreshToken",
|
||||||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
RefreshDeadline: time.Now().Add(20 * time.Second),
|
||||||
}
|
}
|
||||||
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||||
fmt.Fprintln(w, "RVSI FILIVS CAISAR")
|
fmt.Fprintln(w, "RVSI FILIVS CAISAR")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
|
@ -285,25 +173,24 @@ func TestProxy_Proxy(t *testing.T) {
|
||||||
authorizer clients.Authorizer
|
authorizer clients.Authorizer
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
{"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(20 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||||
{"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
|
{"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
|
||||||
{"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
{"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||||
{"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
{"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||||
// same request as above, but with cors_allow_preflight=false in the policy
|
// same request as above, but with cors_allow_preflight=false in the policy
|
||||||
{"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
{"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||||
// cors allowed, but the request is missing proper headers
|
// cors allowed, but the request is missing proper headers
|
||||||
{"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
{"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||||
// redirect to start auth process
|
// redirect to start auth process
|
||||||
{"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
|
{"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
|
||||||
{"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
{"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||||
{"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError},
|
{"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError},
|
||||||
// authenticate errors
|
// authenticate errors
|
||||||
{"session error, redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: errors.New("weird"), Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
{"session expired,redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: sessions.ErrExpired}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
||||||
{"session expired,redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: sessions.ErrExpired}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
{"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
|
||||||
{"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
|
|
||||||
{"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound},
|
{"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound},
|
||||||
{"no http found (no session),redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
{"no http found (no session),redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
||||||
{"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
|
{"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -319,18 +206,17 @@ func TestProxy_Proxy(t *testing.T) {
|
||||||
p.cipher = &cryptutil.MockCipher{MarshalResponse: "foo"}
|
p.cipher = &cryptutil.MockCipher{MarshalResponse: "foo"}
|
||||||
p.sessionStore = tt.session
|
p.sessionStore = tt.session
|
||||||
p.AuthorizeClient = tt.authorizer
|
p.AuthorizeClient = tt.authorizer
|
||||||
|
|
||||||
r := httptest.NewRequest(tt.method, tt.host, nil)
|
r := httptest.NewRequest(tt.method, tt.host, nil)
|
||||||
r.Header = tt.header
|
r.Header = tt.header
|
||||||
r.Header.Set("Accept", "application/json")
|
r.Header.Set("Accept", "application/json")
|
||||||
|
state, _ := tt.session.LoadSession(r)
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = sessions.NewContext(ctx, state, nil)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
p.Proxy(w, r)
|
p.Proxy(w, r)
|
||||||
if status := w.Code; status != tt.wantStatus {
|
if status := w.Code; status != tt.wantStatus {
|
||||||
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantStatus)
|
|
||||||
t.Errorf("\n%+v", w.Body.String())
|
|
||||||
t.Errorf("\n%+v", opts)
|
|
||||||
t.Errorf("\n%+v", ts.URL)
|
|
||||||
|
|
||||||
t.Errorf("handler returned wrong status code: got %v want %v \n body %s", status, tt.wantStatus, w.Body.String())
|
t.Errorf("handler returned wrong status code: got %v want %v \n body %s", status, tt.wantStatus, w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -342,6 +228,7 @@ func TestProxy_UserDashboard(t *testing.T) {
|
||||||
opts := testOptions(t)
|
opts := testOptions(t)
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
ctxError error
|
||||||
options config.Options
|
options config.Options
|
||||||
method string
|
method string
|
||||||
cipher cryptutil.Cipher
|
cipher cryptutil.Cipher
|
||||||
|
@ -351,11 +238,10 @@ func TestProxy_UserDashboard(t *testing.T) {
|
||||||
wantAdminForm bool
|
wantAdminForm bool
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusOK},
|
{"good", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusOK},
|
||||||
{"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthorize{}, false, http.StatusFound},
|
{"session context error", errors.New("error"), opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
|
||||||
{"can't save csrf", opts, http.MethodGet, &cryptutil.MockCipher{MarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
|
{"want admin form good admin authorization", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
|
||||||
{"want admin form good admin authorization", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
|
{"is admin but authorization fails", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
|
||||||
{"is admin but authorization fails", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -369,6 +255,11 @@ func TestProxy_UserDashboard(t *testing.T) {
|
||||||
p.AuthorizeClient = tt.authorizer
|
p.AuthorizeClient = tt.authorizer
|
||||||
|
|
||||||
r := httptest.NewRequest(tt.method, "/", nil)
|
r := httptest.NewRequest(tt.method, "/", nil)
|
||||||
|
state, _ := tt.session.LoadSession(r)
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
r.Header.Set("Accept", "application/json")
|
r.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
@ -395,6 +286,7 @@ func TestProxy_ForceRefresh(t *testing.T) {
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
ctxError error
|
||||||
options config.Options
|
options config.Options
|
||||||
method string
|
method string
|
||||||
cipher cryptutil.Cipher
|
cipher cryptutil.Cipher
|
||||||
|
@ -402,12 +294,12 @@ func TestProxy_ForceRefresh(t *testing.T) {
|
||||||
authorizer clients.Authorizer
|
authorizer clients.Authorizer
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
|
{"good", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
|
||||||
{"cannot load session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("load error")}, clients.MockAuthorize{}, http.StatusFound},
|
{"cannot load session", errors.New("error"), opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
||||||
{"bad id token", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
{"bad id token", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
||||||
{"issue date too soon", timeSinceError, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusBadRequest},
|
{"issue date too soon", nil, timeSinceError, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusBadRequest},
|
||||||
{"refresh failure", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
|
{"refresh failure", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
|
||||||
{"can't save refreshed session", opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
|
{"can't save refreshed session", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -420,6 +312,12 @@ func TestProxy_ForceRefresh(t *testing.T) {
|
||||||
p.AuthorizeClient = tt.authorizer
|
p.AuthorizeClient = tt.authorizer
|
||||||
|
|
||||||
r := httptest.NewRequest(tt.method, "/", nil)
|
r := httptest.NewRequest(tt.method, "/", nil)
|
||||||
|
state, _ := tt.session.LoadSession(r)
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
p.ForceRefresh(w, r)
|
p.ForceRefresh(w, r)
|
||||||
if status := w.Code; status != tt.wantStatus {
|
if status := w.Code; status != tt.wantStatus {
|
||||||
|
@ -431,32 +329,29 @@ func TestProxy_ForceRefresh(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxy_Impersonate(t *testing.T) {
|
func TestProxy_Impersonate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
opts := testOptions(t)
|
opts := testOptions(t)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
malformed bool
|
malformed bool
|
||||||
options config.Options
|
options config.Options
|
||||||
|
ctxError error
|
||||||
method string
|
method string
|
||||||
email string
|
email string
|
||||||
groups string
|
groups string
|
||||||
csrf string
|
csrf string
|
||||||
cipher cryptutil.Cipher
|
cipher cryptutil.Cipher
|
||||||
sessionStore sessions.SessionStore
|
sessionStore sessions.SessionStore
|
||||||
csrfStore sessions.CSRFStore
|
|
||||||
authorizer clients.Authorizer
|
authorizer clients.Authorizer
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||||
{"session load error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||||
// {"non admin users rejected", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
|
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||||
{"non admin users rejected on error", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
|
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
|
||||||
{"csrf from store retrieve failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}, GetError: errors.New("err")}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
|
||||||
{"can't decrypt csrf value", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{UnmarshalError: errors.New("err")}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
{"save session failure", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||||
{"decrypted csrf mismatch", false, opts, http.MethodPost, "user@blah.com", "", "CSRF!", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusBadRequest},
|
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||||
{"save session failure", false, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
|
||||||
{"malformed", true, opts, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
|
||||||
{"groups", false, opts, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, &sessions.MockCSRFStore{Cookie: &http.Cookie{Value: "csrf"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -466,17 +361,18 @@ func TestProxy_Impersonate(t *testing.T) {
|
||||||
}
|
}
|
||||||
p.cipher = tt.cipher
|
p.cipher = tt.cipher
|
||||||
p.sessionStore = tt.sessionStore
|
p.sessionStore = tt.sessionStore
|
||||||
p.csrfStore = tt.csrfStore
|
|
||||||
p.AuthorizeClient = tt.authorizer
|
p.AuthorizeClient = tt.authorizer
|
||||||
postForm := url.Values{}
|
postForm := url.Values{}
|
||||||
postForm.Add("email", tt.email)
|
postForm.Add("email", tt.email)
|
||||||
postForm.Add("group", tt.groups)
|
postForm.Add("group", tt.groups)
|
||||||
postForm.Set("csrf", tt.csrf)
|
postForm.Set("csrf", tt.csrf)
|
||||||
uri := &url.URL{Path: "/"}
|
uri := &url.URL{Path: "/"}
|
||||||
if tt.malformed {
|
|
||||||
uri.RawQuery = "email=%zzzzz"
|
|
||||||
}
|
|
||||||
r := httptest.NewRequest(tt.method, uri.String(), bytes.NewBufferString(postForm.Encode()))
|
r := httptest.NewRequest(tt.method, uri.String(), bytes.NewBufferString(postForm.Encode()))
|
||||||
|
state, _ := tt.sessionStore.LoadSession(r)
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
@ -489,50 +385,8 @@ func TestProxy_Impersonate(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxy_OAuthCallback(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
csrf sessions.MockCSRFStore
|
|
||||||
session sessions.MockSessionStore
|
|
||||||
params map[string]string
|
|
||||||
wantCode int
|
|
||||||
}{
|
|
||||||
{"good", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusFound},
|
|
||||||
{"state err", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "error"}, http.StatusInternalServerError},
|
|
||||||
{"csrf err", sessions.MockCSRFStore{GetError: errors.New("error")}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
|
|
||||||
{"unmarshal err", sessions.MockCSRFStore{Cookie: &http.Cookie{Name: "something_csrf", Value: "unmarshal error"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
|
|
||||||
{"malformed", sessions.MockCSRFStore{ResponseCSRF: "ok", GetError: nil, Cookie: &http.Cookie{Name: "something_csrf", Value: "csrf_state"}}, sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
proxy, err := New(testOptions(t))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
proxy.sessionStore = &tt.session
|
|
||||||
proxy.csrfStore = tt.csrf
|
|
||||||
proxy.cipher = mockCipher{}
|
|
||||||
// proxy.Csrf
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/.pomerium/callback", nil)
|
|
||||||
q := req.URL.Query()
|
|
||||||
for k, v := range tt.params {
|
|
||||||
q.Add(k, v)
|
|
||||||
}
|
|
||||||
req.URL.RawQuery = q.Encode()
|
|
||||||
if tt.name == "malformed" {
|
|
||||||
req.URL.RawQuery = "email=%zzzzz"
|
|
||||||
}
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
proxy.AuthenticateCallback(w, req)
|
|
||||||
if status := w.Code; status != tt.wantCode {
|
|
||||||
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func TestProxy_SignOut(t *testing.T) {
|
func TestProxy_SignOut(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
verb string
|
verb string
|
||||||
|
@ -542,7 +396,6 @@ func TestProxy_SignOut(t *testing.T) {
|
||||||
{"good post", http.MethodPost, "https://test.example", http.StatusFound},
|
{"good post", http.MethodPost, "https://test.example", http.StatusFound},
|
||||||
{"good get", http.MethodGet, "https://test.example", http.StatusFound},
|
{"good get", http.MethodGet, "https://test.example", http.StatusFound},
|
||||||
{"good empty default", http.MethodGet, "", http.StatusFound},
|
{"good empty default", http.MethodGet, "", http.StatusFound},
|
||||||
{"malformed", http.MethodPost, "", http.StatusInternalServerError},
|
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -554,9 +407,6 @@ func TestProxy_SignOut(t *testing.T) {
|
||||||
postForm := url.Values{}
|
postForm := url.Values{}
|
||||||
postForm.Add("redirect_uri", tt.redirectURL)
|
postForm.Add("redirect_uri", tt.redirectURL)
|
||||||
uri := &url.URL{Path: "/"}
|
uri := &url.URL{Path: "/"}
|
||||||
if tt.name == "malformed" {
|
|
||||||
uri.RawQuery = "redirect_uri=%zzzzz"
|
|
||||||
}
|
|
||||||
|
|
||||||
query, _ := url.ParseQuery(uri.RawQuery)
|
query, _ := url.ParseQuery(uri.RawQuery)
|
||||||
if tt.verb == http.MethodGet {
|
if tt.verb == http.MethodGet {
|
||||||
|
@ -576,3 +426,56 @@ func TestProxy_SignOut(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func uriParseHelper(s string) *url.URL {
|
||||||
|
uri, _ := url.Parse(s)
|
||||||
|
return uri
|
||||||
|
}
|
||||||
|
func TestProxy_VerifySession(t *testing.T) {
|
||||||
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||||
|
fmt.Fprintln(w, "RVSI FILIVS CAISAR")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
session sessions.SessionStore
|
||||||
|
ctxError error
|
||||||
|
provider identity.Authenticator
|
||||||
|
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK},
|
||||||
|
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
||||||
|
{"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusFound},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
a := Proxy{
|
||||||
|
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||||
|
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
||||||
|
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
|
||||||
|
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
|
||||||
|
sessionStore: tt.session,
|
||||||
|
}
|
||||||
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
state, _ := tt.session.LoadSession(r)
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
got := a.VerifySession(fn)
|
||||||
|
got.ServeHTTP(w, r)
|
||||||
|
if status := w.Code; status != tt.wantStatus {
|
||||||
|
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||||
|
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
102
proxy/proxy.go
102
proxy/proxy.go
|
@ -2,6 +2,7 @@ package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
stdlog "log"
|
stdlog "log"
|
||||||
|
@ -12,11 +13,11 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/config"
|
"github.com/pomerium/pomerium/internal/config"
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
pom_httputil "github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/internal/templates"
|
"github.com/pomerium/pomerium/internal/templates"
|
||||||
"github.com/pomerium/pomerium/internal/tripper"
|
"github.com/pomerium/pomerium/internal/tripper"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
@ -32,6 +33,11 @@ const (
|
||||||
HeaderEmail = "x-pomerium-authenticated-user-email"
|
HeaderEmail = "x-pomerium-authenticated-user-email"
|
||||||
// HeaderGroups is the header key containing the user's groups.
|
// HeaderGroups is the header key containing the user's groups.
|
||||||
HeaderGroups = "x-pomerium-authenticated-user-groups"
|
HeaderGroups = "x-pomerium-authenticated-user-groups"
|
||||||
|
|
||||||
|
// signinURL is the path to authenticate's sign in endpoint
|
||||||
|
signinURL = "/.pomerium/sign_in"
|
||||||
|
// signoutURL is the path to authenticate's sign out endpoint
|
||||||
|
signoutURL = "/.pomerium/sign_out"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ValidateOptions checks that proper configuration settings are set to create
|
// ValidateOptions checks that proper configuration settings are set to create
|
||||||
|
@ -40,23 +46,21 @@ func ValidateOptions(o config.Options) error {
|
||||||
if _, err := cryptutil.NewCipherFromBase64(o.SharedKey); err != nil {
|
if _, err := cryptutil.NewCipherFromBase64(o.SharedKey); err != nil {
|
||||||
return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %v", err)
|
return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil {
|
if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil {
|
||||||
return fmt.Errorf("proxy: invalid 'COOKIE_SECRET': %v", err)
|
return fmt.Errorf("proxy: invalid 'COOKIE_SECRET': %v", err)
|
||||||
}
|
}
|
||||||
if o.AuthenticateURL == nil {
|
|
||||||
return fmt.Errorf("proxy: missing 'AUTHENTICATE_SERVICE_URL'")
|
if err := urlutil.ValidateURL(o.AuthenticateURL); err != nil {
|
||||||
}
|
|
||||||
if _, err := urlutil.ParseAndValidateURL(o.AuthenticateURL.String()); err != nil {
|
|
||||||
return fmt.Errorf("proxy: invalid 'AUTHENTICATE_SERVICE_URL': %v", err)
|
return fmt.Errorf("proxy: invalid 'AUTHENTICATE_SERVICE_URL': %v", err)
|
||||||
}
|
}
|
||||||
if o.AuthorizeURL == nil {
|
|
||||||
return fmt.Errorf("proxy: missing 'AUTHORIZE_SERVICE_URL'")
|
if err := urlutil.ValidateURL(o.AuthorizeURL); err != nil {
|
||||||
}
|
|
||||||
if _, err := urlutil.ParseAndValidateURL(o.AuthorizeURL.String()); err != nil {
|
|
||||||
return fmt.Errorf("proxy: invalid 'AUTHORIZE_SERVICE_URL': %v", err)
|
return fmt.Errorf("proxy: invalid 'AUTHORIZE_SERVICE_URL': %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(o.SigningKey) != 0 {
|
if len(o.SigningKey) != 0 {
|
||||||
if _, err := cryptutil.NewES256Signer(o.SigningKey, "localhost"); err != nil {
|
if _, err := cryptutil.NewES256Signer(o.SigningKey, ""); err != nil {
|
||||||
return fmt.Errorf("proxy: invalid 'SIGNING_KEY': %v", err)
|
return fmt.Errorf("proxy: invalid 'SIGNING_KEY': %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -66,17 +70,19 @@ func ValidateOptions(o config.Options) error {
|
||||||
// Proxy stores all the information associated with proxying a request.
|
// Proxy stores all the information associated with proxying a request.
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
// SharedKey used to mutually authenticate service communication
|
// SharedKey used to mutually authenticate service communication
|
||||||
SharedKey string
|
SharedKey string
|
||||||
authenticateURL *url.URL
|
authenticateURL *url.URL
|
||||||
authorizeURL *url.URL
|
authenticateSigninURL *url.URL
|
||||||
|
authenticateSignoutURL *url.URL
|
||||||
|
authorizeURL *url.URL
|
||||||
|
|
||||||
AuthorizeClient clients.Authorizer
|
AuthorizeClient clients.Authorizer
|
||||||
|
|
||||||
cipher cryptutil.Cipher
|
cipher cryptutil.Cipher
|
||||||
cookieName string
|
cookieName string
|
||||||
csrfStore sessions.CSRFStore
|
cookieDomain string
|
||||||
|
cookieSecret []byte
|
||||||
defaultUpstreamTimeout time.Duration
|
defaultUpstreamTimeout time.Duration
|
||||||
redirectURL *url.URL
|
|
||||||
refreshCooldown time.Duration
|
refreshCooldown time.Duration
|
||||||
routeConfigs map[string]*routeConfig
|
routeConfigs map[string]*routeConfig
|
||||||
sessionStore sessions.SessionStore
|
sessionStore sessions.SessionStore
|
||||||
|
@ -95,10 +101,17 @@ func New(opts config.Options) (*Proxy, error) {
|
||||||
if err := ValidateOptions(opts); err != nil {
|
if err := ValidateOptions(opts); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
decodedCookieSecret, err := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
cipher, err := cryptutil.NewCipherFromBase64(opts.CookieSecret)
|
cipher, err := cryptutil.NewCipherFromBase64(opts.CookieSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if opts.CookieDomain == "" {
|
||||||
|
opts.CookieDomain = sessions.ParentSubdomain(opts.AuthenticateURL.String())
|
||||||
|
}
|
||||||
|
|
||||||
cookieStore, err := sessions.NewCookieStore(
|
cookieStore, err := sessions.NewCookieStore(
|
||||||
&sessions.CookieStoreOptions{
|
&sessions.CookieStoreOptions{
|
||||||
|
@ -116,13 +129,12 @@ func New(opts config.Options) (*Proxy, error) {
|
||||||
p := &Proxy{
|
p := &Proxy{
|
||||||
SharedKey: opts.SharedKey,
|
SharedKey: opts.SharedKey,
|
||||||
|
|
||||||
routeConfigs: make(map[string]*routeConfig),
|
routeConfigs: make(map[string]*routeConfig),
|
||||||
|
|
||||||
cipher: cipher,
|
cipher: cipher,
|
||||||
|
cookieSecret: decodedCookieSecret,
|
||||||
|
cookieDomain: opts.CookieDomain,
|
||||||
cookieName: opts.CookieName,
|
cookieName: opts.CookieName,
|
||||||
csrfStore: cookieStore,
|
|
||||||
defaultUpstreamTimeout: opts.DefaultUpstreamTimeout,
|
defaultUpstreamTimeout: opts.DefaultUpstreamTimeout,
|
||||||
redirectURL: &url.URL{Path: "/.pomerium/callback"},
|
|
||||||
refreshCooldown: opts.RefreshCooldown,
|
refreshCooldown: opts.RefreshCooldown,
|
||||||
sessionStore: cookieStore,
|
sessionStore: cookieStore,
|
||||||
signingKey: opts.SigningKey,
|
signingKey: opts.SigningKey,
|
||||||
|
@ -130,6 +142,9 @@ func New(opts config.Options) (*Proxy, error) {
|
||||||
}
|
}
|
||||||
// DeepCopy urls to avoid accidental mutation, err checked in validate func
|
// DeepCopy urls to avoid accidental mutation, err checked in validate func
|
||||||
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
|
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
|
||||||
|
p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
|
||||||
|
p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL})
|
||||||
|
|
||||||
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
|
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
|
||||||
|
|
||||||
if err := p.UpdatePolicies(&opts); err != nil {
|
if err := p.UpdatePolicies(&opts); err != nil {
|
||||||
|
@ -172,24 +187,24 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error {
|
||||||
if policy.TLSSkipVerify {
|
if policy.TLSSkipVerify {
|
||||||
tlsClientConfig.InsecureSkipVerify = true
|
tlsClientConfig.InsecureSkipVerify = true
|
||||||
isCustomClientConfig = true
|
isCustomClientConfig = true
|
||||||
log.Warn().Str("to", policy.Source.String()).Msg("proxy: tls skip verify")
|
log.Warn().Str("policy", policy.String()).Msg("proxy: tls skip verify")
|
||||||
}
|
}
|
||||||
if policy.RootCAs != nil {
|
if policy.RootCAs != nil {
|
||||||
tlsClientConfig.RootCAs = policy.RootCAs
|
tlsClientConfig.RootCAs = policy.RootCAs
|
||||||
isCustomClientConfig = true
|
isCustomClientConfig = true
|
||||||
log.Debug().Str("to", policy.Source.String()).Msg("proxy: custom root ca")
|
log.Debug().Str("policy", policy.String()).Msg("proxy: custom root ca")
|
||||||
}
|
}
|
||||||
|
|
||||||
if policy.ClientCertificate != nil {
|
if policy.ClientCertificate != nil {
|
||||||
tlsClientConfig.Certificates = []tls.Certificate{*policy.ClientCertificate}
|
tlsClientConfig.Certificates = []tls.Certificate{*policy.ClientCertificate}
|
||||||
isCustomClientConfig = true
|
isCustomClientConfig = true
|
||||||
log.Debug().Str("to", policy.Source.String()).Msg("proxy: client certs enabled")
|
log.Debug().Str("policy", policy.String()).Msg("proxy: client certs enabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
if policy.TLSServerName != "" {
|
if policy.TLSServerName != "" {
|
||||||
tlsClientConfig.ServerName = policy.TLSServerName
|
tlsClientConfig.ServerName = policy.TLSServerName
|
||||||
isCustomClientConfig = true
|
isCustomClientConfig = true
|
||||||
log.Debug().Str("to", policy.Source.String()).Msgf("proxy: tls hostname override to: %s", policy.TLSServerName)
|
log.Debug().Str("policy", policy.String()).Msgf("proxy: tls hostname override to: %s", policy.TLSServerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// We avoid setting a custom client config unless we have to as
|
// We avoid setting a custom client config unless we have to as
|
||||||
|
@ -212,19 +227,6 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpstreamProxy stores information for proxying the request to the upstream.
|
|
||||||
type UpstreamProxy struct {
|
|
||||||
name string
|
|
||||||
handler http.Handler
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeHTTP handles the second (reverse-proxying) leg of pomerium's request flow
|
|
||||||
func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx, span := trace.StartSpan(r.Context(), fmt.Sprintf("%s%s", r.Host, r.URL.Path))
|
|
||||||
defer span.End()
|
|
||||||
u.handler.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewReverseProxy returns a new ReverseProxy that routes URLs to the scheme, host, and
|
// NewReverseProxy returns a new ReverseProxy that routes URLs to the scheme, host, and
|
||||||
// base path provided in target. NewReverseProxy rewrites the Host header.
|
// base path provided in target. NewReverseProxy rewrites the Host header.
|
||||||
func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
|
func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
|
||||||
|
@ -242,22 +244,17 @@ func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
|
||||||
return proxy
|
return proxy
|
||||||
}
|
}
|
||||||
|
|
||||||
// newReverseProxyHandler applies handler specific options to a given route.
|
// each route has a custom set of middleware applied to the reverse proxy
|
||||||
func (p *Proxy) newReverseProxyHandler(rp *httputil.ReverseProxy, route *config.Policy) (handler http.Handler, err error) {
|
func (p *Proxy) newReverseProxyHandler(rp http.Handler, route *config.Policy) (http.Handler, error) {
|
||||||
handler = &UpstreamProxy{
|
r := pom_httputil.NewRouter()
|
||||||
name: route.Destination.Host,
|
r.Use(middleware.StripPomeriumCookie(p.cookieName))
|
||||||
handler: rp,
|
|
||||||
}
|
|
||||||
c := middleware.NewChain()
|
|
||||||
c = c.Append(middleware.StripPomeriumCookie(p.cookieName))
|
|
||||||
|
|
||||||
// if signing key is set, add signer to middleware
|
// if signing key is set, add signer to middleware
|
||||||
if len(p.signingKey) != 0 {
|
if len(p.signingKey) != 0 {
|
||||||
signer, err := cryptutil.NewES256Signer(p.signingKey, route.Source.Host)
|
signer, err := cryptutil.NewES256Signer(p.signingKey, route.Source.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c = c.Append(middleware.SignRequest(signer, HeaderUserID, HeaderEmail, HeaderGroups, HeaderJWT))
|
r.Use(middleware.SignRequest(signer, HeaderUserID, HeaderEmail, HeaderGroups, HeaderJWT))
|
||||||
}
|
}
|
||||||
// websockets cannot use the non-hijackable timeout-handler
|
// websockets cannot use the non-hijackable timeout-handler
|
||||||
if !route.AllowWebsockets {
|
if !route.AllowWebsockets {
|
||||||
|
@ -265,11 +262,16 @@ func (p *Proxy) newReverseProxyHandler(rp *httputil.ReverseProxy, route *config.
|
||||||
if route.UpstreamTimeout != 0 {
|
if route.UpstreamTimeout != 0 {
|
||||||
timeout = route.UpstreamTimeout
|
timeout = route.UpstreamTimeout
|
||||||
}
|
}
|
||||||
timeoutMsg := fmt.Sprintf("%s failed to respond within the %s timeout period", route.Destination.Host, timeout)
|
timeoutMsg := fmt.Sprintf("%s timed out in %s", route.Destination.Host, timeout)
|
||||||
handler = http.TimeoutHandler(handler, timeout, timeoutMsg)
|
rp = http.TimeoutHandler(rp, timeout, timeoutMsg)
|
||||||
}
|
}
|
||||||
|
// todo(bdd) : fix cors
|
||||||
return c.Then(handler), nil
|
// if route.CORSAllowPreflight {
|
||||||
|
// r.Use(cors.Default().Handler)
|
||||||
|
// }
|
||||||
|
r.Host(route.Destination.Host)
|
||||||
|
r.PathPrefix("/").Handler(rp)
|
||||||
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateOptions updates internal structures based on config.Options
|
// UpdateOptions updates internal structures based on config.Options
|
||||||
|
|
|
@ -12,8 +12,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/config"
|
"github.com/pomerium/pomerium/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
var fixedDate = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
|
|
||||||
|
|
||||||
func newTestOptions(t *testing.T) *config.Options {
|
func newTestOptions(t *testing.T) *config.Options {
|
||||||
opts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
|
opts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -285,10 +283,11 @@ func Test_UpdateOptions(t *testing.T) {
|
||||||
goodClientCertPolicies := testOptions(t)
|
goodClientCertPolicies := testOptions(t)
|
||||||
goodClientCertPolicies.Policies = []config.Policy{
|
goodClientCertPolicies.Policies = []config.Policy{
|
||||||
{To: "http://foo.example", From: "http://bar.example",
|
{To: "http://foo.example", From: "http://bar.example",
|
||||||
TLSClientKey: "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=", TLSClientCert: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=",
|
TLSClientKey: "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=", TLSClientCert: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="},
|
||||||
},
|
|
||||||
}
|
}
|
||||||
goodClientCertPolicies.Validate()
|
goodClientCertPolicies.Validate()
|
||||||
|
customServerName := testOptions(t)
|
||||||
|
customServerName.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", TLSServerName: "test"}}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
originalOptions config.Options
|
originalOptions config.Options
|
||||||
|
@ -301,14 +300,13 @@ func Test_UpdateOptions(t *testing.T) {
|
||||||
{"good no change", good, good, "", "https://corp.example.example", false, true},
|
{"good no change", good, good, "", "https://corp.example.example", false, true},
|
||||||
{"changed", good, newPolicies, "", "https://bar.example", false, true},
|
{"changed", good, newPolicies, "", "https://bar.example", false, true},
|
||||||
{"changed and missing", good, newPolicies, "", "https://corp.example.example", false, false},
|
{"changed and missing", good, newPolicies, "", "https://corp.example.example", false, false},
|
||||||
// todo(bdd): not sure what intent of this test is?
|
|
||||||
{"bad signing key", good, newPolicies, "^bad base 64", "https://corp.example.example", true, false},
|
{"bad signing key", good, newPolicies, "^bad base 64", "https://corp.example.example", true, false},
|
||||||
{"bad change bad policy url", good, badNewPolicy, "", "https://bar.example", true, false},
|
{"bad change bad policy url", good, badNewPolicy, "", "https://bar.example", true, false},
|
||||||
// todo: stand up a test server using self signed certificates
|
|
||||||
{"disable tls verification", good, disableTLSPolicies, "", "https://bar.example", false, true},
|
{"disable tls verification", good, disableTLSPolicies, "", "https://bar.example", false, true},
|
||||||
{"custom root ca", good, customCAPolicies, "", "https://bar.example", false, true},
|
{"custom root ca", good, customCAPolicies, "", "https://bar.example", false, true},
|
||||||
{"bad custom root ca base64", good, badCustomCAPolicies, "", "https://bar.example", true, false},
|
{"bad custom root ca base64", good, badCustomCAPolicies, "", "https://bar.example", true, false},
|
||||||
{"good client certs", good, goodClientCertPolicies, "", "https://bar.example", false, true},
|
{"good client certs", good, goodClientCertPolicies, "", "https://bar.example", false, true},
|
||||||
|
{"custom server name", customServerName, customServerName, "", "https://bar.example", false, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue