mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 23:57:34 +02:00
all: support route scoped sessions
Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
83342112bb
commit
d3d60d1055
53 changed files with 2092 additions and 2416 deletions
|
@ -7,9 +7,12 @@ import (
|
|||
"fmt"
|
||||
"html/template"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/config"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/templates"
|
||||
|
@ -18,6 +21,10 @@ import (
|
|||
|
||||
const callbackPath = "/oauth2/callback"
|
||||
|
||||
// DefaultSessionDuration is the default time a managed route session is
|
||||
// valid for.
|
||||
var DefaultSessionDuration = time.Minute * 10
|
||||
|
||||
// ValidateOptions checks that configuration are complete and valid.
|
||||
// Returns on first error found.
|
||||
func ValidateOptions(o config.Options) error {
|
||||
|
@ -41,18 +48,34 @@ func ValidateOptions(o config.Options) error {
|
|||
|
||||
// Authenticate contains data required to run the authenticate service.
|
||||
type Authenticate struct {
|
||||
SharedKey string
|
||||
// RedirectURL is the authenticate service's externally accessible
|
||||
// url that the identity provider (IdP) will callback to following
|
||||
// authentication flow
|
||||
RedirectURL *url.URL
|
||||
|
||||
cookieName string
|
||||
cookieSecure bool
|
||||
cookieDomain string
|
||||
// sharedKey is used to encrypt and authenticate data between services
|
||||
sharedKey string
|
||||
// sharedCipher is used to encrypt data for use between services
|
||||
sharedCipher cipher.AEAD
|
||||
// sharedEncoder is the encoder to use to serialize data to be consumed
|
||||
// by other services
|
||||
sharedEncoder sessions.Encoder
|
||||
|
||||
// data related to this service only
|
||||
cookieOptions *sessions.CookieOptions
|
||||
// cookieSecret is the secret to encrypt and authenticate data for this service
|
||||
cookieSecret []byte
|
||||
templates *template.Template
|
||||
// is the cipher to use to encrypt data for this service
|
||||
cookieCipher cipher.AEAD
|
||||
sessionStore sessions.SessionStore
|
||||
cipher cipher.AEAD
|
||||
encoder cryptutil.SecureEncoder
|
||||
encryptedEncoder sessions.Encoder
|
||||
sessionStores []sessions.SessionStore
|
||||
sessionLoaders []sessions.SessionLoader
|
||||
|
||||
// provider is the interface to interacting with the identity provider (IdP)
|
||||
provider identity.Authenticator
|
||||
|
||||
templates *template.Template
|
||||
}
|
||||
|
||||
// New validates and creates a new authenticate service from a set of Options.
|
||||
|
@ -60,29 +83,37 @@ func New(opts config.Options) (*Authenticate, error) {
|
|||
if err := ValidateOptions(opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// shared state encoder setup
|
||||
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(opts.SharedKey)
|
||||
signedEncoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.AuthenticateURL.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// private state encoder setup
|
||||
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||
cipher, err := cryptutil.NewAEADCipher(decodedCookieSecret)
|
||||
encoder := cryptutil.NewSecureJSONEncoder(cipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if opts.CookieDomain == "" {
|
||||
opts.CookieDomain = sessions.ParentSubdomain(opts.AuthenticateURL.String())
|
||||
}
|
||||
cookieStore, err := sessions.NewCookieStore(
|
||||
&sessions.CookieStoreOptions{
|
||||
cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret)
|
||||
encryptedEncoder := ecjson.New(cookieCipher)
|
||||
|
||||
cookieOptions := &sessions.CookieOptions{
|
||||
Name: opts.CookieName,
|
||||
CookieDomain: opts.CookieDomain,
|
||||
CookieSecure: opts.CookieSecure,
|
||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||
CookieExpire: opts.CookieExpire,
|
||||
Encoder: encoder,
|
||||
})
|
||||
Domain: opts.CookieDomain,
|
||||
Secure: opts.CookieSecure,
|
||||
HTTPOnly: opts.CookieHTTPOnly,
|
||||
Expire: opts.CookieExpire,
|
||||
}
|
||||
|
||||
cookieStore, err := sessions.NewCookieStore(cookieOptions, encryptedEncoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qpStore := sessions.NewQueryParamStore(encryptedEncoder, "pomerium_programmatic_token")
|
||||
headerStore := sessions.NewHeaderStore(encryptedEncoder, "Pomerium")
|
||||
|
||||
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
|
||||
redirectURL.Path = callbackPath
|
||||
// configure our identity provider
|
||||
provider, err := identity.New(
|
||||
opts.Provider,
|
||||
&identity.Provider{
|
||||
|
@ -99,16 +130,22 @@ func New(opts config.Options) (*Authenticate, error) {
|
|||
}
|
||||
|
||||
return &Authenticate{
|
||||
SharedKey: opts.SharedKey,
|
||||
RedirectURL: redirectURL,
|
||||
templates: templates.New(),
|
||||
sessionStore: cookieStore,
|
||||
cipher: cipher,
|
||||
encoder: encoder,
|
||||
provider: provider,
|
||||
// shared state
|
||||
sharedKey: opts.SharedKey,
|
||||
sharedCipher: sharedCipher,
|
||||
sharedEncoder: signedEncoder,
|
||||
// private state
|
||||
cookieSecret: decodedCookieSecret,
|
||||
cookieName: opts.CookieName,
|
||||
cookieDomain: opts.CookieDomain,
|
||||
cookieSecure: opts.CookieSecure,
|
||||
cookieCipher: cookieCipher,
|
||||
cookieOptions: cookieOptions,
|
||||
sessionStore: cookieStore,
|
||||
encryptedEncoder: encryptedEncoder,
|
||||
sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore},
|
||||
sessionStores: []sessions.SessionStore{cookieStore, qpStore},
|
||||
// IdP
|
||||
provider: provider,
|
||||
|
||||
templates: templates.New(),
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -31,35 +31,37 @@ var CSPHeaders = map[string]string{
|
|||
"Referrer-Policy": "Same-origin",
|
||||
}
|
||||
|
||||
// Handler returns the authenticate service's HTTP multiplexer, and routes.
|
||||
// Handler returns the authenticate service's handler chain.
|
||||
func (a *Authenticate) Handler() http.Handler {
|
||||
r := httputil.NewRouter()
|
||||
r.Use(middleware.SetHeaders(CSPHeaders))
|
||||
r.Use(csrf.Protect(
|
||||
a.cookieSecret,
|
||||
csrf.Secure(a.cookieSecure),
|
||||
csrf.Secure(a.cookieOptions.Secure),
|
||||
csrf.Path("/"),
|
||||
csrf.Domain(a.cookieDomain),
|
||||
csrf.UnsafePaths([]string{callbackPath}), // enforce CSRF on "safe" handler
|
||||
csrf.FormValueName("state"), // rfc6749 section-10.12
|
||||
csrf.CookieName(fmt.Sprintf("%s_csrf", a.cookieName)),
|
||||
csrf.CookieName(fmt.Sprintf("%s_csrf", a.cookieOptions.Name)),
|
||||
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
|
||||
))
|
||||
|
||||
r.HandleFunc("/robots.txt", a.RobotsTxt).Methods(http.MethodGet)
|
||||
// Identity Provider (IdP) endpoints
|
||||
r.HandleFunc("/oauth2/callback", a.OAuthCallback).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/v1/token", a.ExchangeToken)
|
||||
|
||||
// Proxy service endpoints
|
||||
v := r.PathPrefix("/.pomerium").Subrouter()
|
||||
v.Use(middleware.ValidateSignature(a.SharedKey))
|
||||
v.Use(middleware.ValidateRedirectURI(a.RedirectURL))
|
||||
v.Use(sessions.RetrieveSession(a.sessionStore))
|
||||
v.Use(middleware.ValidateSignature(a.sharedKey))
|
||||
v.Use(sessions.RetrieveSession(a.sessionLoaders...))
|
||||
v.Use(a.VerifySession)
|
||||
|
||||
v.HandleFunc("/sign_in", a.SignIn)
|
||||
v.HandleFunc("/sign_out", a.SignOut)
|
||||
|
||||
// programmatic access api endpoint
|
||||
api := r.PathPrefix("/api").Subrouter()
|
||||
api.Use(sessions.RetrieveSession(a.sessionLoaders...))
|
||||
api.HandleFunc("/v1/refresh", a.RefreshAPI)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
|
@ -71,14 +73,14 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
|||
if errors.Is(err, sessions.ErrExpired) {
|
||||
if err := a.refresh(w, r, state); err != nil {
|
||||
log.FromRequest(r).Debug().Str("cause", err.Error()).Msg("authenticate: couldn't refresh session")
|
||||
a.sessionStore.ClearSession(w, r)
|
||||
a.redirectToIdentityProvider(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// redirect to restart middleware-chain following refresh
|
||||
http.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound)
|
||||
return
|
||||
} else if err != nil {
|
||||
log.FromRequest(r).Err(err).Msg("authenticate: unexpected session state")
|
||||
a.sessionStore.ClearSession(w, r)
|
||||
log.FromRequest(r).Err(err).Msg("authenticate: malformed session")
|
||||
a.redirectToIdentityProvider(w, r)
|
||||
return
|
||||
}
|
||||
|
@ -95,7 +97,6 @@ func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessio
|
|||
return fmt.Errorf("authenticate: refresh save failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// RobotsTxt handles the /robots.txt route.
|
||||
|
@ -108,19 +109,64 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// SignIn handles to authenticating a user.
|
||||
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
||||
// grab and parse our redirect_uri
|
||||
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
// Add query param to let downstream apps (or auth endpoints) know
|
||||
// this request followed authentication. Useful for auth-forward-endpoint
|
||||
// redirecting
|
||||
// create a clone of the redirect URI, unless this is a programmatic request
|
||||
// in which case we will redirect back to proxy's callback endpoint
|
||||
callbackURL, _ := urlutil.DeepCopy(redirectURL)
|
||||
callbackURL.Path = "/.pomerium/callback"
|
||||
|
||||
q := redirectURL.Query()
|
||||
q.Add("pomerium-auth-callback", "true")
|
||||
|
||||
if q.Get("pomerium_programmatic_destination_url") != "" {
|
||||
callbackURL, err = urlutil.ParseAndValidateURL(q.Get("pomerium_programmatic_destination_url"))
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
}
|
||||
s, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
s.SetImpersonation(q.Get("impersonate_email"), q.Get("impersonate_group"))
|
||||
|
||||
newSession := s.NewSession(a.RedirectURL.Host, []string{a.RedirectURL.Host, callbackURL.Host})
|
||||
if q.Get("pomerium_programmatic_destination_url") != "" {
|
||||
newSession.Programmatic = true
|
||||
encSession, err := a.encryptedEncoder.Marshal(newSession)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
q.Set("pomerium_refresh_token", string(encSession))
|
||||
}
|
||||
|
||||
// sign the route session, as a JWT
|
||||
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession(DefaultSessionDuration))
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
// encrypt our route-based token JWT avoiding any accidental logging
|
||||
encryptedJWT := cryptutil.Encrypt(a.sharedCipher, signedJWT, nil)
|
||||
// base64 our encrypted payload for URL-friendlyness
|
||||
encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT)
|
||||
|
||||
// add our encoded and encrypted route-session JWT to a query param
|
||||
q.Set("pomerium_jwt", encodedJWT)
|
||||
|
||||
redirectURL.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
// build our hmac-d redirect URL with our session, pointing back to the
|
||||
// proxy's callback URL which is responsible for setting our new route-session
|
||||
uri := urlutil.SignedRedirectURL(a.sharedKey, callbackURL, redirectURL)
|
||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// SignOut signs the user out and attempts to revoke the user's identity session
|
||||
|
@ -132,7 +178,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
a.sessionStore.ClearSession(w, r)
|
||||
err = a.provider.Revoke(session.AccessToken)
|
||||
err = a.provider.Revoke(r.Context(), session.AccessToken)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("could not revoke user session", http.StatusBadRequest, err))
|
||||
return
|
||||
|
@ -152,11 +198,12 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
|||
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest
|
||||
// https://tools.ietf.org/html/rfc6749#section-4.2.1
|
||||
func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
a.sessionStore.ClearSession(w, r)
|
||||
redirectURL := a.RedirectURL.ResolveReference(r.URL)
|
||||
nonce := csrf.Token(r)
|
||||
now := time.Now().Unix()
|
||||
b := []byte(fmt.Sprintf("%s|%d|", nonce, now))
|
||||
enc := cryptutil.Encrypt(a.cipher, []byte(redirectURL.String()), b)
|
||||
enc := cryptutil.Encrypt(a.cookieCipher, []byte(redirectURL.String()), b)
|
||||
b = append(b, enc...)
|
||||
encodedState := base64.URLEncoding.EncodeToString(b)
|
||||
http.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound)
|
||||
|
@ -201,7 +248,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
return nil, httputil.Error("malformed state", http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
// split state into its it's components, e.g.
|
||||
// split state into concat'd components
|
||||
// (nonce|timestamp|redirect_url|encrypted_data(redirect_url)+mac(nonce,ts))
|
||||
statePayload := strings.SplitN(string(bytes), "|", 3)
|
||||
if len(statePayload) != 3 {
|
||||
|
@ -209,16 +256,16 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
fmt.Errorf("state malformed, size: %d", len(statePayload)))
|
||||
}
|
||||
|
||||
// verify that the returned timestamp is valid (replay attack)
|
||||
// verify that the returned timestamp is valid
|
||||
if err := cryptutil.ValidTimestamp(statePayload[1]); err != nil {
|
||||
return nil, httputil.Error(err.Error(), http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
// Use our AEAD construct to enforce secrecy and authenticity,:
|
||||
// Use our AEAD construct to enforce secrecy and authenticity:
|
||||
// mac: to validate the nonce again, and above timestamp
|
||||
// decrypt: to prevent leaking 'redirect_uri' to IdP or logs)
|
||||
// decrypt: to prevent leaking 'redirect_uri' to IdP or logs
|
||||
b := []byte(fmt.Sprint(statePayload[0], "|", statePayload[1], "|"))
|
||||
redirectString, err := cryptutil.Decrypt(a.cipher, []byte(statePayload[2]), b)
|
||||
redirectString, err := cryptutil.Decrypt(a.cookieCipher, []byte(statePayload[2]), b)
|
||||
if err != nil {
|
||||
return nil, httputil.Error("'state' has invalid hmac", http.StatusBadRequest, err)
|
||||
}
|
||||
|
@ -235,38 +282,45 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
return redirectURL, nil
|
||||
}
|
||||
|
||||
// ExchangeToken takes an identity provider issued JWT as input ('id_token)
|
||||
// and exchanges that token for a pomerium session. The provided token's
|
||||
// audience ('aud') attribute must match Pomerium's client_id.
|
||||
func (a *Authenticate) ExchangeToken(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.FormValue("id_token")
|
||||
if code == "" {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("missing id token", http.StatusBadRequest, nil))
|
||||
// RefreshAPI loads a global state, and attempts to refresh the session's access
|
||||
// tokens and state with the identity provider. If successful, a new signed JWT
|
||||
// and refresh token (`refresh_token`) are returned as JSON
|
||||
func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) {
|
||||
s, err := sessions.FromContext(r.Context())
|
||||
if err != nil && !errors.Is(err, sessions.ErrExpired) {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
session, err := a.provider.IDTokenToSession(r.Context(), code)
|
||||
newSession, err := a.provider.Refresh(r.Context(), s)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusInternalServerError, err))
|
||||
return
|
||||
}
|
||||
encToken, err := sessions.MarshalSession(session, a.encoder)
|
||||
newSession = newSession.NewSession(s.Issuer, s.Audience)
|
||||
|
||||
encSession, err := a.encryptedEncoder.Marshal(newSession)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusInternalServerError, err))
|
||||
return
|
||||
}
|
||||
restSession := struct {
|
||||
Token string
|
||||
Expiry time.Time `json:",omitempty"`
|
||||
}{
|
||||
Token: encToken,
|
||||
Expiry: session.RefreshDeadline,
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(restSession)
|
||||
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession(DefaultSessionDuration))
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusInternalServerError, err))
|
||||
return
|
||||
}
|
||||
var response struct {
|
||||
JWT string `json:"jwt"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
response.RefreshToken = string(encSession)
|
||||
response.JWT = string(signedJWT)
|
||||
|
||||
jsonResponse, err := json.Marshal(&response)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(jsonBytes)
|
||||
w.Write(jsonResponse)
|
||||
}
|
||||
|
|
|
@ -7,22 +7,27 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/templates"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func testAuthenticate() *Authenticate {
|
||||
var auth Authenticate
|
||||
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
|
||||
auth.SharedKey = "IzY7MOZwzfOkmELXgozHDKTxoT3nOYhwkcmUVINsRww="
|
||||
auth.cookieSecret = []byte(auth.SharedKey)
|
||||
auth.sharedKey = cryptutil.NewBase64Key()
|
||||
auth.cookieSecret = cryptutil.NewKey()
|
||||
auth.cookieOptions = &sessions.CookieOptions{Name: "name"}
|
||||
auth.templates = templates.New()
|
||||
return &auth
|
||||
}
|
||||
|
@ -67,29 +72,30 @@ func TestAuthenticate_Handler(t *testing.T) {
|
|||
|
||||
func TestAuthenticate_SignIn(t *testing.T) {
|
||||
t.Parallel()
|
||||
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
state string
|
||||
redirectURI string
|
||||
|
||||
scheme string
|
||||
host string
|
||||
qp map[string]string
|
||||
|
||||
session sessions.SessionStore
|
||||
restStore sessions.SessionStore
|
||||
provider identity.MockProvider
|
||||
encoder cryptutil.SecureEncoder
|
||||
encoder sessions.Encoder
|
||||
wantCode int
|
||||
}{
|
||||
{"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusFound},
|
||||
{"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockEncoder{}, http.StatusFound},
|
||||
{"session expired good refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockEncoder{}, http.StatusFound},
|
||||
{"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockEncoder{}, http.StatusFound}, // mocking hmac is meh
|
||||
{"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockEncoder{}, http.StatusFound},
|
||||
|
||||
// {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusInternalServerError},
|
||||
{"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusFound},
|
||||
{"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusFound},
|
||||
{"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusBadRequest},
|
||||
// actually caught by go's handler, but we should keep the test.
|
||||
{"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusBadRequest},
|
||||
{"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{MarshalError: errors.New("error")}, http.StatusFound},
|
||||
{"good", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &encoding.MockEncoder{}, http.StatusFound},
|
||||
{"session not valid", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &encoding.MockEncoder{}, http.StatusFound},
|
||||
{"bad redirect uri query", "", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &encoding.MockEncoder{}, http.StatusBadRequest},
|
||||
{"bad marshal", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &encoding.MockEncoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||
{"session error", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{LoadError: errors.New("error")}, identity.MockProvider{}, &encoding.MockEncoder{}, http.StatusBadRequest},
|
||||
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{"state": "example", "pomerium_programmatic_destination_url": "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &encoding.MockEncoder{}, http.StatusFound},
|
||||
{"encrypted encoder error", "https", "corp.example.example", map[string]string{"state": "example", "pomerium_programmatic_destination_url": "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &encoding.MockEncoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{"state": "example", "pomerium_programmatic_destination_url": "some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &encoding.MockEncoder{}, http.StatusBadRequest},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -97,16 +103,28 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
sessionStore: tt.session,
|
||||
provider: tt.provider,
|
||||
RedirectURL: uriParseHelper("https://some.example"),
|
||||
SharedKey: "secret",
|
||||
encoder: tt.encoder,
|
||||
sharedKey: "secret",
|
||||
sharedEncoder: tt.encoder,
|
||||
encryptedEncoder: tt.encoder,
|
||||
sharedCipher: aead,
|
||||
cookieOptions: &sessions.CookieOptions{
|
||||
Name: "cookie",
|
||||
Domain: "foo",
|
||||
},
|
||||
}
|
||||
uri := &url.URL{Host: "corp.some.example", Scheme: "https", Path: "/"}
|
||||
uri.RawQuery = fmt.Sprintf("%s&redirect_uri=%s", tt.state, tt.redirectURI)
|
||||
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
|
||||
uri := &url.URL{Scheme: tt.scheme, Host: tt.host}
|
||||
|
||||
queryString := uri.Query()
|
||||
for k, v := range tt.qp {
|
||||
queryString.Set(k, v)
|
||||
}
|
||||
uri.RawQuery = queryString.Encode()
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/?redirect_uri="+uri.String(), nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
state, _ := tt.session.LoadSession(r)
|
||||
state, err := tt.session.LoadSession(r)
|
||||
ctx := r.Context()
|
||||
ctx = sessions.NewContext(ctx, state, nil)
|
||||
ctx = sessions.NewContext(ctx, state, err)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
@ -141,10 +159,10 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
wantCode int
|
||||
wantBody string
|
||||
}{
|
||||
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""},
|
||||
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"},
|
||||
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, ""},
|
||||
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, ""},
|
||||
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""},
|
||||
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"error\":\"could not revoke user session\"}\n"},
|
||||
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"error\":\"Bad Request\"}\n"},
|
||||
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"error\":\"malformed redirect_uri\"}\n"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -164,15 +182,19 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
ctx := r.Context()
|
||||
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||
r = r.WithContext(ctx)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
a.SignOut(w, r)
|
||||
if status := w.Code; status != tt.wantCode {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
|
||||
}
|
||||
if body := w.Body.String(); !strings.Contains(body, tt.wantBody) {
|
||||
t.Errorf("handler returned wrong body Body: got \n%s \n%s", body, tt.wantBody)
|
||||
body := w.Body.String()
|
||||
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
|
||||
t.Errorf("handler returned wrong body Body: %s", diff)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -199,19 +221,19 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
|||
want string
|
||||
wantCode int
|
||||
}{
|
||||
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusFound},
|
||||
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound},
|
||||
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
|
||||
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusInternalServerError},
|
||||
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
|
||||
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
|
||||
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
|
||||
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError},
|
||||
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
||||
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
||||
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
||||
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -224,7 +246,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
|||
RedirectURL: authURL,
|
||||
sessionStore: tt.session,
|
||||
provider: tt.provider,
|
||||
cipher: aead,
|
||||
cookieCipher: aead,
|
||||
}
|
||||
u, _ := url.Parse("/oauthGet")
|
||||
params, _ := url.ParseQuery(u.RawQuery)
|
||||
|
@ -235,7 +257,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
|||
// (nonce|timestamp|redirect_url|encrypt(redirect_url),mac(nonce,ts))
|
||||
b := []byte(fmt.Sprintf("%s|%d|%s", nonce, tt.ts, tt.extraMac))
|
||||
|
||||
enc := cryptutil.Encrypt(a.cipher, []byte(tt.redirectURI), b)
|
||||
enc := cryptutil.Encrypt(a.cookieCipher, []byte(tt.redirectURI), b)
|
||||
b = append(b, enc...)
|
||||
encodedState := base64.URLEncoding.EncodeToString(b)
|
||||
if tt.extraState != "" {
|
||||
|
@ -261,59 +283,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAuthenticate_ExchangeToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
idToken string
|
||||
restStore sessions.SessionStore
|
||||
encoder cryptutil.SecureEncoder
|
||||
provider identity.MockProvider
|
||||
want string
|
||||
}{
|
||||
{"good", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""},
|
||||
{"could not exchange identity for session", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionError: errors.New("error")}, ""},
|
||||
{"missing token", http.MethodPost, "", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "missing id token"},
|
||||
{"malformed form", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""},
|
||||
{"can't marshal token", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{MarshalError: errors.New("can't marshal token")}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "can't marshal token"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a := &Authenticate{
|
||||
encoder: tt.encoder,
|
||||
provider: tt.provider,
|
||||
sessionStore: tt.restStore,
|
||||
cipher: aead,
|
||||
}
|
||||
form := url.Values{}
|
||||
if tt.idToken != "" {
|
||||
form.Add("id_token", tt.idToken)
|
||||
}
|
||||
rawForm := form.Encode()
|
||||
|
||||
if tt.name == "malformed form" {
|
||||
rawForm = "example=%zzzzz"
|
||||
}
|
||||
r := httptest.NewRequest(tt.method, "/", strings.NewReader(rawForm))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
r.Header.Set("Accept", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
a.ExchangeToken(w, r)
|
||||
got := w.Body.String()
|
||||
if !strings.Contains(got, tt.want) {
|
||||
t.Errorf("Authenticate.ExchangeToken() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||
t.Parallel()
|
||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -331,11 +300,11 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK},
|
||||
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
||||
{"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusOK},
|
||||
{"expired,refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
|
||||
{"expired,save error", &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusFound},
|
||||
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
|
||||
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
||||
{"good refresh expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
|
||||
{"expired,refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
|
||||
{"expired,save error", &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -344,12 +313,12 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
a := Authenticate{
|
||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
||||
sharedKey: cryptutil.NewBase64Key(),
|
||||
cookieSecret: cryptutil.NewKey(),
|
||||
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
|
||||
sessionStore: tt.session,
|
||||
provider: tt.provider,
|
||||
cipher: aead,
|
||||
cookieCipher: aead,
|
||||
}
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
state, _ := tt.session.LoadSession(r)
|
||||
|
@ -370,3 +339,57 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticate_RefreshAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
session sessions.SessionStore
|
||||
ctxError error
|
||||
|
||||
provider identity.Authenticator
|
||||
secretEncoder sessions.Encoder
|
||||
sharedEncoder sessions.Encoder
|
||||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, http.StatusOK},
|
||||
{"refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
||||
{"session is not refreshable error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
|
||||
{"secret encoder failed", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, encoding.MockEncoder{MarshalError: errors.New("error")}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
||||
{"shared encoder failed", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, encoding.MockEncoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a := Authenticate{
|
||||
sharedKey: cryptutil.NewBase64Key(),
|
||||
cookieSecret: cryptutil.NewKey(),
|
||||
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
|
||||
encryptedEncoder: tt.secretEncoder,
|
||||
sharedEncoder: tt.sharedEncoder,
|
||||
sessionStore: tt.session,
|
||||
provider: tt.provider,
|
||||
cookieCipher: aead,
|
||||
}
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
state, _ := tt.session.LoadSession(r)
|
||||
ctx := r.Context()
|
||||
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
r.Header.Set("Accept", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
a.RefreshAPI(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,30 @@
|
|||
# Changelog
|
||||
|
||||
## vUnreleased
|
||||
|
||||
### New
|
||||
|
||||
- Session state is now route-scoped. Each managed route uses a transparent, signed JSON Web Token (JWT) to assert identity.
|
||||
- Managed routes no longer need to be under the same subdomain! Access can be delegated to any route, on any domain.
|
||||
- Programmatic access now also uses JWT tokens. Access tokens are now generated via a standard oauth2 token flow, and credentials can be refreshed for as long as is permitted by the underlying identity provider.
|
||||
- User dashboard now pulls in additional user context fields (where supported) like the profile picture, first and last name, and so on.
|
||||
|
||||
### Security
|
||||
|
||||
- Some identity providers (Okta and Azure) previously used mutable signifiers to set and assert group membership. Group membership for all providers now use globally unique and immutable identifiers when available.
|
||||
|
||||
### Changed
|
||||
|
||||
- Azure AD identity provider now uses globally unique and immutable `ID` for [group membership](https://docs.microsoft.com/en-us/graph/api/group-get?view=graph-rest-1.0&tabs=http).
|
||||
- Okta no longer uses tokens to retrieve group membership. Group membership is now fetched using Okta's HTTP API. [Group membership](https://developer.okta.com/docs/reference/api/groups/) is now determined by the globally unique and immutable `ID` field.
|
||||
- Okta now requires an additional set of credentials to be used to query for group membership set as a [service account](https://www.pomerium.io/docs/reference/reference.html#identity-provider-service-account).
|
||||
- URLs are no longer validated to be on the same domain-tree as the authenticate service. Managed routes can live on any domain.
|
||||
|
||||
### Removed
|
||||
|
||||
- Force refresh has been removed from the dashboard.
|
||||
- Previous programmatic authentication endpoints (`/api/v1/token`) has been removed and is no longer supported.
|
||||
|
||||
## v0.4.2
|
||||
|
||||
### Security
|
||||
|
|
|
@ -11,6 +11,30 @@ description: >-
|
|||
|
||||
### Breaking
|
||||
|
||||
#### Subdomain requirement dropped
|
||||
|
||||
- Pomerium services and managed routes are no longer required to be on the same domain-tree. Access can be delegated to any route, on any domain (that you have access to, of course).
|
||||
|
||||
#### Azure AD
|
||||
|
||||
- The Azure AD provider now uses the globally unique and immutable`ID` instead of `group name` to attest a user's [group membership](https://docs.microsoft.com/en-us/graph/api/group-get?view=graph-rest-1.0&tabs=http). Please update your policies to use Group `ID`s instead of group names.
|
||||
|
||||
#### Okta
|
||||
|
||||
- Okta no longer uses tokens to retrieve group membership. [Group membership](https://developer.okta.com/docs/reference/api/groups/) is now fetched using Okta's API. Please update your policies to use Group `ID`s instead of group names.
|
||||
- Okta's group membership is now determined by the globally unique and immutable ID field.
|
||||
- Okta now requires an additional set of credentials to be used to query for group membership set as a [service account](https://www.pomerium.io/docs/reference/reference.html#identity-provider-service-account).
|
||||
|
||||
#### Force Refresh Removed
|
||||
|
||||
Force refresh has been removed from the dashboard. Logging out and back in again should have the equivalent desired effect.
|
||||
|
||||
#### Programmatic Access API changed
|
||||
|
||||
Previous programmatic authentication endpoints (`/api/v1/token`) has been removed and has been replaced by a per-route, oauth2 based auth flow. Please see updated [programmatic documentation](https://www.pomerium.io/docs/reference/programmatic-access.html) how to use the new programmatic access api.
|
||||
|
||||
#### Forward-auth route change
|
||||
|
||||
Previously, routes were verified by taking the downstream applications hostname in the form of a path `(e.g. ${fwdauth}/.pomerium/verify/httpbin.some.example`) variable. The new method for verifying a route using forward authentication is to pass the entire requested url in the form of a query string `(e.g. ${fwdauth}/.pomerium/verify?url=https://httpbin.some.example)` where the routed domain is the value of the `uri` key.
|
||||
|
||||
Note that the verification URL is no longer nested under the `.pomerium` endpoint.
|
||||
|
|
16
go.mod
16
go.mod
|
@ -3,7 +3,7 @@ module github.com/pomerium/pomerium
|
|||
go 1.12
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.40.0 // indirect
|
||||
cloud.google.com/go v0.47.0 // indirect
|
||||
contrib.go.opencensus.io/exporter/jaeger v0.1.0
|
||||
contrib.go.opencensus.io/exporter/prometheus v0.1.0
|
||||
github.com/fsnotify/fsnotify v1.4.7
|
||||
|
@ -25,14 +25,14 @@ require (
|
|||
github.com/spf13/viper v1.4.0
|
||||
github.com/stretchr/testify v1.4.0 // indirect
|
||||
go.opencensus.io v0.22.0
|
||||
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8
|
||||
golang.org/x/net v0.0.0-20191003171128-d98b1b443823
|
||||
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4
|
||||
golang.org/x/net v0.0.0-20191028085509-fe3aa8a45271
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
|
||||
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9 // indirect
|
||||
google.golang.org/api v0.6.0
|
||||
google.golang.org/appengine v1.6.1 // indirect
|
||||
google.golang.org/genproto v0.0.0-20191002211648-c459b9ce5143 // indirect
|
||||
golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c // indirect
|
||||
google.golang.org/api v0.13.0
|
||||
google.golang.org/appengine v1.6.5 // indirect
|
||||
google.golang.org/genproto v0.0.0-20191028173616-919d9bdd9fe6 // indirect
|
||||
google.golang.org/grpc v1.24.0
|
||||
gopkg.in/square/go-jose.v2 v2.3.1
|
||||
gopkg.in/square/go-jose.v2 v2.4.0
|
||||
gopkg.in/yaml.v2 v2.2.3
|
||||
)
|
||||
|
|
85
go.sum
85
go.sum
|
@ -2,14 +2,24 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT
|
|||
cloud.google.com/go v0.34.0 h1:eOI3/cP2VTU6uZLDYAoic+eyzzB9YyGmJ7eIjl8rOPg=
|
||||
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU=
|
||||
cloud.google.com/go v0.40.0 h1:FjSY7bOj+WzJe6TZRVtXI2b9kAYvtNg4lMbcH2+MUkk=
|
||||
cloud.google.com/go v0.40.0/go.mod h1:Tk58MuI9rbLMKlAjeO/bDnteAx7tX2gJIXw4T5Jwlro=
|
||||
cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU=
|
||||
cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY=
|
||||
cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc=
|
||||
cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0=
|
||||
cloud.google.com/go v0.47.0 h1:1JUtpcY9E7+eTospEwWS2QXP3DEn7poB3E2j0jN74mM=
|
||||
cloud.google.com/go v0.47.0/go.mod h1:5p3Ky/7f3N10VBkhuR5LFtddroTiMyjZV/Kj5qOQFxU=
|
||||
cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o=
|
||||
cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
|
||||
cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I=
|
||||
cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw=
|
||||
contrib.go.opencensus.io/exporter/jaeger v0.1.0 h1:WNc9HbA38xEQmsI40Tjd/MNU/g8byN2Of7lwIjv0Jdc=
|
||||
contrib.go.opencensus.io/exporter/jaeger v0.1.0/go.mod h1:VYianECmuFPwU37O699Vc1GOcy+y8kOsfaxHRImmjbA=
|
||||
contrib.go.opencensus.io/exporter/prometheus v0.1.0 h1:SByaIoWwNgMdPSgl5sMqM2KDE5H/ukPWBRo314xiDvg=
|
||||
contrib.go.opencensus.io/exporter/prometheus v0.1.0/go.mod h1:cGFniUXGZlKRjzOyuZJ6mgB+PgBcCIa79kEKR8YCW+A=
|
||||
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
||||
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
||||
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
|
||||
github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
|
||||
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
|
||||
|
@ -40,6 +50,7 @@ github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFP
|
|||
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
|
||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
|
||||
github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
|
||||
|
@ -71,7 +82,11 @@ github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg=
|
|||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
||||
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||
github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM=
|
||||
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
|
||||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
||||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
||||
github.com/gorilla/mux v1.6.2 h1:Pgr17XVTNXAk3q/r4CpKzC5xBM/qW1uVLV+IhRZpIIk=
|
||||
|
@ -157,6 +172,7 @@ github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7z
|
|||
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
|
||||
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
||||
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
|
||||
github.com/rs/zerolog v1.14.3 h1:4EGfSkR2hJDB0s3oFfrlPqjU1e4WLncergLil3nEKW0=
|
||||
github.com/rs/zerolog v1.14.3/go.mod h1:3WXPzbXEEliJ+a6UFE4vhIxV8qR1EML6ngzP9ug4eYg=
|
||||
|
@ -199,16 +215,30 @@ go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/
|
|||
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmVhcMLU6v5fEb/ok4wyqtNU=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8 h1:1wopBVtVdWnn03fZelqdXTqk7U7zPQCb+T4rbU9ZEoU=
|
||||
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191029031824-8986dd9e96cf h1:fnPsqIDRbCSgumaMCRpoIoF2s4qxv0xSSS0BVZUE/ss=
|
||||
golang.org/x/crypto v0.0.0-20191029031824-8986dd9e96cf/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4 h1:PDpCLFAH/YIX0QpHPf2eO7L4rC2OOirBrKtXTLLiNTY=
|
||||
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek=
|
||||
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||
golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
|
||||
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
|
||||
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
|
||||
golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
|
@ -225,8 +255,9 @@ golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn
|
|||
golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65 h1:+rhAzEzT3f4JtomfC371qB+0Ola2caSKcY69NUBZrRQ=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20191003171128-d98b1b443823 h1:Ypyv6BNJh07T1pUSrehkLemqPKXhus2MkfktJ91kRh4=
|
||||
golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191028085509-fe3aa8a45271 h1:N66aaryRB3Ax92gH0v3hp1QYZ3zWWCCUR/j8Ifh45Ss=
|
||||
golang.org/x/net v0.0.0-20191028085509-fe3aa8a45271/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421 h1:Wo7BWFiOk0QRFMLYMqJGFMd9CgUAcGx7V+qEg/h5IBI=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
|
@ -247,12 +278,14 @@ golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5h
|
|||
golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9 h1:L2auWcuQIvxz9xSEqzESnV/QN/gNRXNApHi3fYwl2w0=
|
||||
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c h1:S/FtSvpNLtFBgjTqcKsRpsa6aVsI6iztaz1bQd9BJwE=
|
||||
golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
|
@ -266,35 +299,51 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
|
|||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||
golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||
golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||
golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20191010171213-8abd42400456/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk=
|
||||
google.golang.org/api v0.3.2/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk=
|
||||
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
|
||||
google.golang.org/api v0.6.0 h1:2tJEkRfnZL5g1GeBUlITh/rqT5HG3sFcoVCUUxmgJ2g=
|
||||
google.golang.org/api v0.6.0/go.mod h1:btoxGiFvQNVUZQ8W08zLtrVS08CNpINPEfxXxgJL1Q4=
|
||||
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
|
||||
google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
|
||||
google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
|
||||
google.golang.org/api v0.13.0 h1:Q3Ui3V3/CVinFWFiW39Iw0kMuVrRzYX0wN6OPFp0lTA=
|
||||
google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/appengine v1.6.1 h1:QzqyMA1tlu6CgqCDUtU9V+ZKhLFT2dkJuANu5QaxI3I=
|
||||
google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0=
|
||||
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
|
||||
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||
google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
|
||||
google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
|
||||
google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
|
||||
google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
|
||||
google.golang.org/genproto v0.0.0-20190530194941-fb225487d101/go.mod h1:z3L6/3dTEVtUr6QSP8miRzeRqwQOioJ9I66odjN4I7s=
|
||||
google.golang.org/genproto v0.0.0-20191002211648-c459b9ce5143 h1:tikhlQEJeezbnu0Zcblj7g5vm/L7xt6g1vnfq8mRCS4=
|
||||
google.golang.org/genproto v0.0.0-20191002211648-c459b9ce5143/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
|
||||
google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
|
||||
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
|
||||
google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8=
|
||||
google.golang.org/genproto v0.0.0-20191009194640-548a555dbc03/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
|
||||
google.golang.org/genproto v0.0.0-20191028173616-919d9bdd9fe6 h1:UXl+Zk3jqqcbEVV7ace5lrt4YdA4tXiz3f/KbmD29Vo=
|
||||
google.golang.org/genproto v0.0.0-20191028173616-919d9bdd9fe6/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
|
||||
google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
||||
google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
|
||||
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
|
||||
google.golang.org/grpc v1.24.0 h1:vb/1TCsVn3DcJlQ0Gs1yB1pKI6Do2/QNwxdKqmc/b0s=
|
||||
google.golang.org/grpc v1.24.0/go.mod h1:XDChyiUovWa60DnaeDeZmSW86xtLtjtZbwvSiRnRtcA=
|
||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||
|
@ -302,10 +351,11 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+
|
|||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
|
||||
gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
|
||||
gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4=
|
||||
gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/square/go-jose.v2 v2.4.0 h1:0kXPskUMGAXXWJlP05ktEMOV0vmzFQUWw6d+aZJQU8A=
|
||||
gopkg.in/square/go-jose.v2 v2.4.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
@ -319,4 +369,5 @@ honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWh
|
|||
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
|
||||
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
@ -49,3 +52,80 @@ func bytesToCertPool(b []byte) (*x509.CertPool, error) {
|
|||
}
|
||||
return certPool, nil
|
||||
}
|
||||
|
||||
// DecodePublicKey decodes a PEM-encoded ECDSA public key.
|
||||
func DecodePublicKey(encodedKey []byte) (*ecdsa.PublicKey, error) {
|
||||
block, _ := pem.Decode(encodedKey)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("cryptutil: decoded nil PEM block")
|
||||
}
|
||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ecdsaPub, ok := pub.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, errors.New("cryptutil: data was not an ECDSA public key")
|
||||
}
|
||||
|
||||
return ecdsaPub, nil
|
||||
}
|
||||
|
||||
// EncodePublicKey encodes an ECDSA public key to PEM format.
|
||||
func EncodePublicKey(key *ecdsa.PublicKey) ([]byte, error) {
|
||||
derBytes, err := x509.MarshalPKIXPublicKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block := &pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: derBytes,
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(block), nil
|
||||
}
|
||||
|
||||
// DecodePrivateKey decodes a PEM-encoded ECDSA private key.
|
||||
func DecodePrivateKey(encodedKey []byte) (*ecdsa.PrivateKey, error) {
|
||||
var skippedTypes []string
|
||||
var block *pem.Block
|
||||
|
||||
for {
|
||||
block, encodedKey = pem.Decode(encodedKey)
|
||||
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("cryptutil: failed to find EC PRIVATE KEY in PEM data after skipping types %v", skippedTypes)
|
||||
}
|
||||
|
||||
if block.Type == "EC PRIVATE KEY" {
|
||||
break
|
||||
} else {
|
||||
skippedTypes = append(skippedTypes, block.Type)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
privKey, err := x509.ParseECPrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return privKey, nil
|
||||
}
|
||||
|
||||
// EncodePrivateKey encodes an ECDSA private key to PEM format.
|
||||
func EncodePrivateKey(key *ecdsa.PrivateKey) ([]byte, error) {
|
||||
derKey, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyBlock := &pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: derKey,
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(keyBlock), nil
|
||||
}
|
||||
|
|
|
@ -1,10 +1,36 @@
|
|||
package cryptutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// A keypair for NIST P-256 / secp256r1
|
||||
// Generated using:
|
||||
// openssl ecparam -genkey -name prime256v1 -outform PEM
|
||||
var pemECPrivateKeyP256 = `-----BEGIN EC PARAMETERS-----
|
||||
BggqhkjOPQMBBw==
|
||||
-----END EC PARAMETERS-----
|
||||
-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIOI+EZsjyN3jvWJI/KDihFmqTuDpUe/if6f/pgGTBta/oAoGCCqGSM49
|
||||
AwEHoUQDQgAEhhObKJ1r1PcUw+3REd/TbmSZnDvXnFUSTwqQFo5gbfIlP+gvEYba
|
||||
+Rxj2hhqjfzqxIleRK40IRyEi3fJM/8Qhg==
|
||||
-----END EC PRIVATE KEY-----
|
||||
`
|
||||
|
||||
var pemECPublicKeyP256 = `-----BEGIN PUBLIC KEY-----
|
||||
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEhhObKJ1r1PcUw+3REd/TbmSZnDvX
|
||||
nFUSTwqQFo5gbfIlP+gvEYba+Rxj2hhqjfzqxIleRK40IRyEi3fJM/8Qhg==
|
||||
-----END PUBLIC KEY-----
|
||||
`
|
||||
|
||||
var garbagePEM = `-----BEGIN GARBAGE-----
|
||||
TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQ=
|
||||
-----END GARBAGE-----
|
||||
`
|
||||
|
||||
func TestCertifcateFromBase64(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
|
@ -91,3 +117,39 @@ func TestCertificateFromFile(t *testing.T) {
|
|||
}
|
||||
_ = listener
|
||||
}
|
||||
|
||||
func TestPublicKeyMarshaling(t *testing.T) {
|
||||
ecKey, err := DecodePublicKey([]byte(pemECPublicKeyP256))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = DecodePublicKey(nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
pemBytes, _ := EncodePublicKey(ecKey)
|
||||
if !bytes.Equal(pemBytes, []byte(pemECPublicKeyP256)) {
|
||||
t.Fatal("public key encoding did not match")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestPrivateKeyBadDecode(t *testing.T) {
|
||||
_, err := DecodePrivateKey([]byte(garbagePEM))
|
||||
if err == nil {
|
||||
t.Fatal("decoded garbage data without complaint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivateKeyMarshaling(t *testing.T) {
|
||||
ecKey, err := DecodePrivateKey([]byte(pemECPrivateKeyP256))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pemBytes, _ := EncodePrivateKey(ecKey)
|
||||
if !strings.HasSuffix(pemECPrivateKeyP256, string(pemBytes)) {
|
||||
t.Fatal("private key encoding did not match")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,9 @@
|
|||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
@ -30,106 +26,6 @@ func NewAEADCipherFromBase64(s string) (cipher.AEAD, error) {
|
|||
return NewAEADCipher(decoded)
|
||||
}
|
||||
|
||||
// SecureEncoder provides and interface for to encrypt and decrypting structures .
|
||||
type SecureEncoder interface {
|
||||
Marshal(interface{}) (string, error)
|
||||
Unmarshal(string, interface{}) error
|
||||
}
|
||||
|
||||
// SecureJSONEncoder implements SecureEncoder for JSON using an AEAD cipher.
|
||||
//
|
||||
// See https://en.wikipedia.org/wiki/Authenticated_encryption
|
||||
type SecureJSONEncoder struct {
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
// NewSecureJSONEncoder takes a base64 encoded secret key and returns a new XChacha20poly1305 cipher.
|
||||
func NewSecureJSONEncoder(aead cipher.AEAD) SecureEncoder {
|
||||
return &SecureJSONEncoder{aead: aead}
|
||||
}
|
||||
|
||||
// Marshal marshals the interface state as JSON, encrypts the JSON using the cipher
|
||||
// and base64 encodes the binary value as a string and returns the result
|
||||
//
|
||||
// can panic if source of random entropy is exhausted generating a nonce.
|
||||
func (c *SecureJSONEncoder) Marshal(s interface{}) (string, error) {
|
||||
// encode json value
|
||||
plaintext, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// compress the plaintext bytes
|
||||
compressed, err := compress(plaintext)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// encrypt the compressed JSON bytes
|
||||
ciphertext := Encrypt(c.aead, compressed, nil)
|
||||
|
||||
// base64-encode the result
|
||||
encoded := base64.RawURLEncoding.EncodeToString(ciphertext)
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
||||
// byte slice the passed cipher, and unmarshals the resulting JSON into the struct pointer passed
|
||||
func (c *SecureJSONEncoder) Unmarshal(value string, s interface{}) error {
|
||||
// convert base64 string value to bytes
|
||||
ciphertext, err := base64.RawURLEncoding.DecodeString(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// decrypt the bytes
|
||||
compressed, err := Decrypt(c.aead, ciphertext, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// decompress the unencrypted bytes
|
||||
plaintext, err := decompress(compressed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// unmarshal the unencrypted bytes
|
||||
err = json.Unmarshal(plaintext, s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// compress gzips a set of bytes
|
||||
func compress(data []byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
writer, err := gzip.NewWriterLevel(&buf, gzip.DefaultCompression)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cryptutil: failed to create a gzip writer: %q", err)
|
||||
}
|
||||
if writer == nil {
|
||||
return nil, fmt.Errorf("cryptutil: failed to create a gzip writer")
|
||||
}
|
||||
if _, err = writer.Write(data); err != nil {
|
||||
return nil, fmt.Errorf("cryptutil: failed to compress data with err: %q", err)
|
||||
}
|
||||
if err = writer.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// decompress un-gzips a set of bytes
|
||||
func decompress(data []byte) ([]byte, error) {
|
||||
reader, err := gzip.NewReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cryptutil: failed to create a gzip reader: %q", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
var buf bytes.Buffer
|
||||
if _, err = io.Copy(&buf, reader); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts a value with optional associated data
|
||||
//
|
||||
// Panics if source of randomness fails.
|
||||
|
|
|
@ -39,106 +39,6 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestMarshalAndUnmarshalStruct(t *testing.T) {
|
||||
key := NewKey()
|
||||
|
||||
a, err := NewAEADCipher(key)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
c := SecureJSONEncoder{aead: a}
|
||||
|
||||
type TC struct {
|
||||
Field string `json:"field"`
|
||||
}
|
||||
|
||||
tc := &TC{
|
||||
Field: "my plain text value",
|
||||
}
|
||||
|
||||
value1, err := c.Marshal(tc)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
value2, err := c.Marshal(tc)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
|
||||
if value1 == value2 {
|
||||
t.Fatalf("expected marshaled values to not be equal %v != %v", value1, value2)
|
||||
}
|
||||
|
||||
got1 := &TC{}
|
||||
err = c.Unmarshal(value1, got1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got1, tc) {
|
||||
t.Logf("want: %#v", tc)
|
||||
t.Logf(" got: %#v", got1)
|
||||
t.Fatalf("expected structs to be equal")
|
||||
}
|
||||
|
||||
got2 := &TC{}
|
||||
err = c.Unmarshal(value2, got2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got1, got2) {
|
||||
t.Logf("got2: %#v", got2)
|
||||
t.Logf("got1: %#v", got1)
|
||||
t.Fatalf("expected structs to be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureJSONEncoder_Marshal(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
s interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{"unsupported type",
|
||||
struct {
|
||||
Animal string `json:"animal"`
|
||||
Func func() `json:"sound"`
|
||||
}{
|
||||
Animal: "cat",
|
||||
Func: func() {},
|
||||
},
|
||||
true},
|
||||
{"simple",
|
||||
struct {
|
||||
Animal string `json:"animal"`
|
||||
Sound string `json:"sound"`
|
||||
}{
|
||||
Animal: "cat",
|
||||
Sound: "meow",
|
||||
},
|
||||
false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
c, err := NewAEADCipher(NewKey())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
e := SecureJSONEncoder{aead: c}
|
||||
|
||||
_, err = e.Marshal(tt.s)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("SecureJSONEncoder.Marshal() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAEADCipher(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
|
|
|
@ -7,6 +7,11 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultLeeway defines the default leeway for matching NotBefore/Expiry claims.
|
||||
DefaultLeeway = 1.0 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
errTimestampMalformed = errors.New("internal/cryptutil: timestamp malformed")
|
||||
errTimestampExpired = errors.New("internal/cryptutil: timestamp expired")
|
||||
|
@ -31,7 +36,6 @@ func CheckHMAC(data, suppliedMAC []byte, key string) bool {
|
|||
// ValidTimestamp is a helper function often used in conjunction with an HMAC
|
||||
// function to verify that the timestamp (in unix seconds) is within leeway
|
||||
// period.
|
||||
// todo(bdd) : should leeway be configurable?
|
||||
func ValidTimestamp(ts string) error {
|
||||
var timeStamp int64
|
||||
var err error
|
||||
|
|
|
@ -1,100 +0,0 @@
|
|||
// Package cryptutil provides encoding and decoding routines for various cryptographic structures.
|
||||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// DecodePublicKey decodes a PEM-encoded ECDSA public key.
|
||||
func DecodePublicKey(encodedKey []byte) (*ecdsa.PublicKey, error) {
|
||||
block, _ := pem.Decode(encodedKey)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("cryptutil: decoded nil PEM block")
|
||||
}
|
||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ecdsaPub, ok := pub.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, errors.New("cryptutil: data was not an ECDSA public key")
|
||||
}
|
||||
|
||||
return ecdsaPub, nil
|
||||
}
|
||||
|
||||
// EncodePublicKey encodes an ECDSA public key to PEM format.
|
||||
func EncodePublicKey(key *ecdsa.PublicKey) ([]byte, error) {
|
||||
derBytes, err := x509.MarshalPKIXPublicKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block := &pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: derBytes,
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(block), nil
|
||||
}
|
||||
|
||||
// DecodePrivateKey decodes a PEM-encoded ECDSA private key.
|
||||
func DecodePrivateKey(encodedKey []byte) (*ecdsa.PrivateKey, error) {
|
||||
var skippedTypes []string
|
||||
var block *pem.Block
|
||||
|
||||
for {
|
||||
block, encodedKey = pem.Decode(encodedKey)
|
||||
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("cryptutil: failed to find EC PRIVATE KEY in PEM data after skipping types %v", skippedTypes)
|
||||
}
|
||||
|
||||
if block.Type == "EC PRIVATE KEY" {
|
||||
break
|
||||
} else {
|
||||
skippedTypes = append(skippedTypes, block.Type)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
privKey, err := x509.ParseECPrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return privKey, nil
|
||||
}
|
||||
|
||||
// EncodePrivateKey encodes an ECDSA private key to PEM format.
|
||||
func EncodePrivateKey(key *ecdsa.PrivateKey) ([]byte, error) {
|
||||
derKey, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyBlock := &pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: derKey,
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(keyBlock), nil
|
||||
}
|
||||
|
||||
// EncodeSignatureJWT encodes an ECDSA signature according to
|
||||
// https://tools.ietf.org/html/rfc7515#appendix-A.3.1
|
||||
func EncodeSignatureJWT(sig []byte) string {
|
||||
return base64.RawURLEncoding.EncodeToString(sig)
|
||||
}
|
||||
|
||||
// DecodeSignatureJWT decodes an ECDSA signature according to
|
||||
// https://tools.ietf.org/html/rfc7515#appendix-A.3.1
|
||||
func DecodeSignatureJWT(b64sig string) ([]byte, error) {
|
||||
return base64.RawURLEncoding.DecodeString(b64sig)
|
||||
}
|
|
@ -1,105 +0,0 @@
|
|||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// A keypair for NIST P-256 / secp256r1
|
||||
// Generated using:
|
||||
// openssl ecparam -genkey -name prime256v1 -outform PEM
|
||||
var pemECPrivateKeyP256 = `-----BEGIN EC PARAMETERS-----
|
||||
BggqhkjOPQMBBw==
|
||||
-----END EC PARAMETERS-----
|
||||
-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIOI+EZsjyN3jvWJI/KDihFmqTuDpUe/if6f/pgGTBta/oAoGCCqGSM49
|
||||
AwEHoUQDQgAEhhObKJ1r1PcUw+3REd/TbmSZnDvXnFUSTwqQFo5gbfIlP+gvEYba
|
||||
+Rxj2hhqjfzqxIleRK40IRyEi3fJM/8Qhg==
|
||||
-----END EC PRIVATE KEY-----
|
||||
`
|
||||
|
||||
var pemECPublicKeyP256 = `-----BEGIN PUBLIC KEY-----
|
||||
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEhhObKJ1r1PcUw+3REd/TbmSZnDvX
|
||||
nFUSTwqQFo5gbfIlP+gvEYba+Rxj2hhqjfzqxIleRK40IRyEi3fJM/8Qhg==
|
||||
-----END PUBLIC KEY-----
|
||||
`
|
||||
|
||||
var garbagePEM = `-----BEGIN GARBAGE-----
|
||||
TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQ=
|
||||
-----END GARBAGE-----
|
||||
`
|
||||
|
||||
func TestPublicKeyMarshaling(t *testing.T) {
|
||||
ecKey, err := DecodePublicKey([]byte(pemECPublicKeyP256))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = DecodePublicKey(nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
pemBytes, _ := EncodePublicKey(ecKey)
|
||||
if !bytes.Equal(pemBytes, []byte(pemECPublicKeyP256)) {
|
||||
t.Fatal("public key encoding did not match")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestPrivateKeyBadDecode(t *testing.T) {
|
||||
_, err := DecodePrivateKey([]byte(garbagePEM))
|
||||
if err == nil {
|
||||
t.Fatal("decoded garbage data without complaint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivateKeyMarshaling(t *testing.T) {
|
||||
ecKey, err := DecodePrivateKey([]byte(pemECPrivateKeyP256))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pemBytes, _ := EncodePrivateKey(ecKey)
|
||||
if !strings.HasSuffix(pemECPrivateKeyP256, string(pemBytes)) {
|
||||
t.Fatal("private key encoding did not match")
|
||||
}
|
||||
}
|
||||
|
||||
// Test vector from https://tools.ietf.org/html/rfc7515#appendix-A.3.1
|
||||
var jwtTest = []struct {
|
||||
sigBytes []byte
|
||||
b64sig string
|
||||
}{
|
||||
{
|
||||
sigBytes: []byte{14, 209, 33, 83, 121, 99, 108, 72, 60, 47, 127, 21,
|
||||
88, 7, 212, 2, 163, 178, 40, 3, 58, 249, 124, 126, 23, 129, 154, 195, 22, 158,
|
||||
166, 101, 197, 10, 7, 211, 140, 60, 112, 229, 216, 241, 45, 175,
|
||||
8, 74, 84, 128, 166, 101, 144, 197, 242, 147, 80, 154, 143, 63, 127, 138, 131,
|
||||
163, 84, 213},
|
||||
b64sig: "DtEhU3ljbEg8L38VWAfUAqOyKAM6-Xx-F4GawxaepmXFCgfTjDxw5djxLa8ISlSApmWQxfKTUJqPP3-Kg6NU1Q",
|
||||
},
|
||||
}
|
||||
|
||||
func TestJWTEncoding(t *testing.T) {
|
||||
for _, tt := range jwtTest {
|
||||
result := EncodeSignatureJWT(tt.sigBytes)
|
||||
|
||||
if strings.Compare(result, tt.b64sig) != 0 {
|
||||
t.Fatalf("expected %s, got %s\n", tt.b64sig, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTDecoding(t *testing.T) {
|
||||
for _, tt := range jwtTest {
|
||||
resultSig, err := DecodeSignatureJWT(tt.b64sig)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(resultSig, tt.sigBytes) {
|
||||
t.Fatalf("decoded signature was incorrect")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,107 +0,0 @@
|
|||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultLeeway defines the default leeway for matching NotBefore/Expiry claims.
|
||||
DefaultLeeway = 5.0 * time.Minute
|
||||
)
|
||||
|
||||
// JWTSigner implements JWT signing according to JSON Web Token (JWT) RFC7519
|
||||
// https://tools.ietf.org/html/rfc7519
|
||||
type JWTSigner interface {
|
||||
SignJWT(string, string, string) (string, error)
|
||||
}
|
||||
|
||||
// ES256Signer is struct containing the required fields to create a ES256 signed JSON Web Tokens
|
||||
type ES256Signer struct {
|
||||
signer jose.Signer
|
||||
|
||||
mu sync.Mutex
|
||||
// User (sub) is unique, stable identifier for the user.
|
||||
// Use in place of the x-pomerium-authenticated-user-id header.
|
||||
User string `json:"sub,omitempty"`
|
||||
|
||||
// Email (email) is a **custom** claim name identifier for the user email address.
|
||||
// Use in place of the x-pomerium-authenticated-user-email header.
|
||||
Email string `json:"email,omitempty"`
|
||||
|
||||
// Groups (groups) is a **custom** claim name identifier for the user's groups.
|
||||
// Use in place of the x-pomerium-authenticated-user-groups header.
|
||||
Groups string `json:"groups,omitempty"`
|
||||
|
||||
// Audience (aud) must be the destination of the upstream proxy locations.
|
||||
// e.g. `helloworld.corp.example.com`
|
||||
Audience jwt.Audience `json:"aud,omitempty"`
|
||||
// Issuer (iss) is the URL of the proxy.
|
||||
// e.g. `proxy.corp.example.com`
|
||||
Issuer string `json:"iss,omitempty"`
|
||||
// Expiry (exp) is the expiration time in seconds since the UNIX epoch.
|
||||
// Allow 1 minute for skew. The maximum lifetime of a token is 10 minutes + 2 * skew.
|
||||
Expiry jwt.NumericDate `json:"exp,omitempty"`
|
||||
// IssuedAt (iat) is the time is measured in seconds since the UNIX epoch.
|
||||
// Allow 1 minute for skew.
|
||||
IssuedAt jwt.NumericDate `json:"iat,omitempty"`
|
||||
// IssuedAt (nbf) is the time is measured in seconds since the UNIX epoch.
|
||||
// Allow 1 minute for skew.
|
||||
NotBefore jwt.NumericDate `json:"nbf,omitempty"`
|
||||
}
|
||||
|
||||
// NewES256Signer creates a NIST P-256 (aka secp256r1 aka prime256v1) JWT signer
|
||||
// from a base64 encoded private key.
|
||||
//
|
||||
// RSA is not supported due to performance considerations of needing to sign each request.
|
||||
// Go's P-256 is constant-time and SHA-256 is faster on 64-bit machines and immune
|
||||
// to length extension attacks.
|
||||
// See also:
|
||||
// - https://cloud.google.com/iot/docs/how-tos/credentials/keys
|
||||
func NewES256Signer(privKey, audience string) (*ES256Signer, error) {
|
||||
decodedSigningKey, err := base64.StdEncoding.DecodeString(privKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key, err := DecodePrivateKey(decodedSigningKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cryptutil: parsing key failed %v", err)
|
||||
}
|
||||
signer, err := jose.NewSigner(
|
||||
jose.SigningKey{
|
||||
Algorithm: jose.ES256, // ECDSA using P-256 and SHA-256
|
||||
Key: key,
|
||||
},
|
||||
(&jose.SignerOptions{}).WithType("JWT"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cryptutil: new signer failed %v", err)
|
||||
}
|
||||
return &ES256Signer{
|
||||
Issuer: "pomerium-proxy",
|
||||
Audience: jwt.Audience{audience},
|
||||
signer: signer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SignJWT creates a signed JWT containing claims for the logged in
|
||||
// user id (`sub`), email (`email`) and groups (`groups`).
|
||||
func (s *ES256Signer) SignJWT(user, email, groups string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.User = user
|
||||
s.Email = email
|
||||
s.Groups = groups
|
||||
now := time.Now()
|
||||
s.IssuedAt = *jwt.NewNumericDate(now)
|
||||
s.Expiry = *jwt.NewNumericDate(now.Add(DefaultLeeway))
|
||||
s.NotBefore = *jwt.NewNumericDate(now.Add(-1 * DefaultLeeway))
|
||||
rawJWT, err := jwt.Signed(s.signer).Claims(s).CompactSerialize()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cryptutil: sign failed %v", err)
|
||||
}
|
||||
return rawJWT, nil
|
||||
}
|
|
@ -1,46 +0,0 @@
|
|||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestES256Signer(t *testing.T) {
|
||||
signer, err := NewES256Signer(base64.StdEncoding.EncodeToString([]byte(pemECPrivateKeyP256)), "destination-url")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if signer == nil {
|
||||
t.Fatal("signer should not be nil")
|
||||
}
|
||||
rawJwt, err := signer.SignJWT("joe-user", "joe-user@example.com", "group1,group2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if rawJwt == "" {
|
||||
t.Fatal("jwt should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewES256Signer(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
privKey string
|
||||
audience string
|
||||
wantErr bool
|
||||
}{
|
||||
{"working example", base64.StdEncoding.EncodeToString([]byte(pemECPrivateKeyP256)), "some-domain.com", false},
|
||||
{"bad private key", base64.StdEncoding.EncodeToString([]byte(garbagePEM)), "some-domain.com", true},
|
||||
{"bad base64 key", garbagePEM, "some-domain.com", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewES256Signer(tt.privKey, tt.audience)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewES256Signer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
110
internal/encoding/ecjson/ecjson.go
Normal file
110
internal/encoding/ecjson/ecjson.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
// Package ecjson represents encrypted and compressed content using JSON-based
|
||||
package ecjson // import "github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
// EncryptedCompressedJSON implements SecureEncoder for JSON using an AEAD cipher.
|
||||
//
|
||||
// See https://en.wikipedia.org/wiki/Authenticated_encryption
|
||||
type EncryptedCompressedJSON struct {
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
// New takes a base64 encoded secret key and returns a new XChacha20poly1305 cipher.
|
||||
func New(aead cipher.AEAD) *EncryptedCompressedJSON {
|
||||
return &EncryptedCompressedJSON{aead: aead}
|
||||
}
|
||||
|
||||
// Marshal marshals the interface state as JSON, encrypts the JSON using the cipher
|
||||
// and base64 encodes the binary value as a string and returns the result
|
||||
//
|
||||
// can panic if source of random entropy is exhausted generating a nonce.
|
||||
func (c *EncryptedCompressedJSON) Marshal(s interface{}) ([]byte, error) {
|
||||
// encode json value
|
||||
plaintext, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// compress the plaintext bytes
|
||||
compressed, err := compress(plaintext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// encrypt the compressed JSON bytes
|
||||
ciphertext := cryptutil.Encrypt(c.aead, compressed, nil)
|
||||
|
||||
// base64-encode the result
|
||||
encoded := base64.RawURLEncoding.EncodeToString(ciphertext)
|
||||
return []byte(encoded), nil
|
||||
}
|
||||
|
||||
// Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
||||
// byte slice the passed cipher, and unmarshals the resulting JSON into the struct pointer passed
|
||||
func (c *EncryptedCompressedJSON) Unmarshal(data []byte, s interface{}) error {
|
||||
// convert base64 string value to bytes
|
||||
ciphertext, err := base64.RawURLEncoding.DecodeString(string(data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// decrypt the bytes
|
||||
compressed, err := cryptutil.Decrypt(c.aead, ciphertext, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// decompress the unencrypted bytes
|
||||
plaintext, err := decompress(compressed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// unmarshal the unencrypted bytes
|
||||
err = json.Unmarshal(plaintext, s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// compress gzips a set of bytes
|
||||
func compress(data []byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
writer, err := gzip.NewWriterLevel(&buf, gzip.DefaultCompression)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cryptutil: failed to create a gzip writer: %q", err)
|
||||
}
|
||||
if writer == nil {
|
||||
return nil, fmt.Errorf("cryptutil: failed to create a gzip writer")
|
||||
}
|
||||
if _, err = writer.Write(data); err != nil {
|
||||
return nil, fmt.Errorf("cryptutil: failed to compress data with err: %q", err)
|
||||
}
|
||||
if err = writer.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// decompress un-gzips a set of bytes
|
||||
func decompress(data []byte) ([]byte, error) {
|
||||
reader, err := gzip.NewReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cryptutil: failed to create a gzip reader: %q", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
var buf bytes.Buffer
|
||||
if _, err = io.Copy(&buf, reader); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
70
internal/encoding/jws/jws.go
Normal file
70
internal/encoding/jws/jws.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
// Package jws represents content secured with digitalsignatures
|
||||
// using JSON-based data structures as specified by rfc7515
|
||||
package jws // import "github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
// JSONWebSigner is the struct representing a signed JWT.
|
||||
// https://tools.ietf.org/html/rfc7519
|
||||
type JSONWebSigner struct {
|
||||
Signer jose.Signer
|
||||
Issuer string
|
||||
|
||||
key interface{}
|
||||
}
|
||||
|
||||
// NewHS256Signer creates a SHA256 JWT signer from a 32 byte key.
|
||||
func NewHS256Signer(key []byte, issuer string) (*JSONWebSigner, error) {
|
||||
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key},
|
||||
(&jose.SignerOptions{}).WithType("JWT"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &JSONWebSigner{Signer: sig, key: key, Issuer: issuer}, nil
|
||||
}
|
||||
|
||||
// NewES256Signer creates a NIST P-256 (aka secp256r1 aka prime256v1) JWT signer
|
||||
// from a base64 encoded private key.
|
||||
//
|
||||
// RSA is not supported due to performance considerations of needing to sign each request.
|
||||
// Go's P-256 is constant-time and SHA-256 is faster on 64-bit machines and immune
|
||||
// to length extension attacks.
|
||||
// See : https://cloud.google.com/iot/docs/how-tos/credentials/keys
|
||||
func NewES256Signer(privKey, issuer string) (*JSONWebSigner, error) {
|
||||
decodedSigningKey, err := base64.StdEncoding.DecodeString(privKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key, err := cryptutil.DecodePrivateKey(decodedSigningKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key},
|
||||
(&jose.SignerOptions{}).WithType("JWT"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &JSONWebSigner{Signer: sig, key: key, Issuer: issuer}, nil
|
||||
}
|
||||
|
||||
// Marshal signs, and serializes a JWT.
|
||||
func (c *JSONWebSigner) Marshal(x interface{}) ([]byte, error) {
|
||||
s, err := jwt.Signed(c.Signer).Claims(x).CompactSerialize()
|
||||
return []byte(s), err
|
||||
}
|
||||
|
||||
// Unmarshal parses and validates a signed JWT.
|
||||
func (c *JSONWebSigner) Unmarshal(value []byte, s interface{}) error {
|
||||
tok, err := jwt.ParseSigned(string(value))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return tok.Claims(c.key, s)
|
||||
}
|
|
@ -1,18 +1,18 @@
|
|||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
package encoding // import "github.com/pomerium/pomerium/internal/encoding"
|
||||
|
||||
// MockEncoder MockCSRFStore is a mock implementation of Cipher.
|
||||
type MockEncoder struct {
|
||||
MarshalResponse string
|
||||
MarshalResponse []byte
|
||||
MarshalError error
|
||||
UnmarshalError error
|
||||
}
|
||||
|
||||
// Marshal is a mock implementation of MockEncoder.
|
||||
func (mc MockEncoder) Marshal(i interface{}) (string, error) {
|
||||
func (mc MockEncoder) Marshal(i interface{}) ([]byte, error) {
|
||||
return mc.MarshalResponse, mc.MarshalError
|
||||
}
|
||||
|
||||
// Unmarshal is a mock implementation of MockEncoder.
|
||||
func (mc MockEncoder) Unmarshal(s string, i interface{}) error {
|
||||
func (mc MockEncoder) Unmarshal(s []byte, i interface{}) error {
|
||||
return mc.UnmarshalError
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||
package encoding // import "github.com/pomerium/pomerium/internal/encoding"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
@ -8,7 +8,7 @@ import (
|
|||
func TestMockEncoder(t *testing.T) {
|
||||
e := errors.New("err")
|
||||
mc := MockEncoder{
|
||||
MarshalResponse: "MarshalResponse",
|
||||
MarshalResponse: []byte("MarshalResponse"),
|
||||
MarshalError: e,
|
||||
UnmarshalError: e,
|
||||
}
|
||||
|
@ -16,10 +16,10 @@ func TestMockEncoder(t *testing.T) {
|
|||
if err != e {
|
||||
t.Error("unexpected Marshal error")
|
||||
}
|
||||
if s != "MarshalResponse" {
|
||||
if string(s) != "MarshalResponse" {
|
||||
t.Error("unexpected MarshalResponse error")
|
||||
}
|
||||
err = mc.Unmarshal("s", "s")
|
||||
err = mc.Unmarshal([]byte("s"), "s")
|
||||
if err != e {
|
||||
t.Error("unexpected Unmarshal error")
|
||||
}
|
|
@ -2,6 +2,7 @@ package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -27,7 +28,7 @@ var httpClient = &http.Client{
|
|||
}
|
||||
|
||||
// Client provides a simple helper interface to make HTTP requests
|
||||
func Client(method, endpoint, userAgent string, headers map[string]string, params url.Values, response interface{}) error {
|
||||
func Client(ctx context.Context, method, endpoint, userAgent string, headers map[string]string, params url.Values, response interface{}) error {
|
||||
var body io.Reader
|
||||
switch method {
|
||||
case http.MethodPost:
|
||||
|
@ -41,7 +42,7 @@ func Client(method, endpoint, userAgent string, headers map[string]string, param
|
|||
default:
|
||||
return fmt.Errorf(http.StatusText(http.StatusBadRequest))
|
||||
}
|
||||
req, err := http.NewRequest(method, endpoint, body)
|
||||
req, err := http.NewRequestWithContext(ctx, method, endpoint, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -9,9 +9,9 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// HeaderForwardHost is the header key the identifies the originating
|
||||
// IP addresses of a client connecting to a web server through an HTTP proxy
|
||||
// or a load balancer.
|
||||
// HeaderForwardHost is the header key that identifies the original host requested
|
||||
// by the client in the Host HTTP request header.
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-Host
|
||||
const HeaderForwardHost = "X-Forwarded-Host"
|
||||
|
||||
// NewReverseProxy returns a new ReverseProxy that routes
|
||||
|
|
|
@ -4,14 +4,13 @@ import (
|
|||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/jwt"
|
||||
"golang.org/x/oauth2/google"
|
||||
admin "google.golang.org/api/admin/directory/v1"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
|
@ -22,14 +21,12 @@ import (
|
|||
|
||||
const defaultGoogleProviderURL = "https://accounts.google.com"
|
||||
|
||||
// JWTTokenURL is Google's OAuth 2.0 token URL to use with the JWT flow.
|
||||
const JWTTokenURL = "https://accounts.google.com/o/oauth2/token"
|
||||
|
||||
// GoogleProvider is an implementation of the Provider interface.
|
||||
type GoogleProvider struct {
|
||||
*Provider
|
||||
// non-standard oidc fields
|
||||
RevokeURL *url.URL
|
||||
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
|
||||
apiClient *admin.Service
|
||||
}
|
||||
|
||||
|
@ -61,13 +58,9 @@ func NewGoogleProvider(p *Provider) (*GoogleProvider, error) {
|
|||
gp := &GoogleProvider{
|
||||
Provider: p,
|
||||
}
|
||||
// google supports a revocation endpoint
|
||||
var claims struct {
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
}
|
||||
|
||||
// build api client to make group membership api calls
|
||||
if err := p.provider.Claims(&claims); err != nil {
|
||||
if err := p.provider.Claims(&gp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// if service account set, configure admin sdk calls
|
||||
|
@ -78,34 +71,37 @@ func NewGoogleProvider(p *Provider) (*GoogleProvider, error) {
|
|||
}
|
||||
// Required scopes for groups api
|
||||
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
|
||||
conf, err := JWTConfigFromJSON(apiCreds, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope)
|
||||
conf, err := google.JWTConfigFromJSON(apiCreds, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/google: failed making jwt config from json %v", err)
|
||||
}
|
||||
var credentialsFile struct {
|
||||
ImpersonateUser string `json:"impersonate_user"`
|
||||
}
|
||||
if err := json.Unmarshal(apiCreds, &credentialsFile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf.Subject = credentialsFile.ImpersonateUser
|
||||
client := conf.Client(context.TODO())
|
||||
gp.apiClient, err = admin.New(client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/google: failed creating admin service %v", err)
|
||||
}
|
||||
gp.UserGroupFn = gp.UserGroups
|
||||
} else {
|
||||
log.Warn().Msg("identity/google: no service account, cannot retrieve groups")
|
||||
}
|
||||
|
||||
gp.RevokeURL, err = url.Parse(claims.RevokeURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return gp, nil
|
||||
}
|
||||
|
||||
// Revoke revokes the access token a given session state.
|
||||
//
|
||||
// https://developers.google.com/identity/protocols/OAuth2WebServer#tokenrevoke
|
||||
func (p *GoogleProvider) Revoke(accessToken string) error {
|
||||
func (p *GoogleProvider) Revoke(ctx context.Context, token *oauth2.Token) error {
|
||||
params := url.Values{}
|
||||
params.Add("token", accessToken)
|
||||
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), nil, params, nil)
|
||||
params.Add("token", token.AccessToken)
|
||||
err := httputil.Client(ctx, http.MethodPost, p.RevokeURL, version.UserAgent(), nil, params, nil)
|
||||
if err != nil && err != httputil.ErrTokenRevoked {
|
||||
return err
|
||||
}
|
||||
|
@ -127,95 +123,14 @@ func (p *GoogleProvider) GetSignInURL(state string) string {
|
|||
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "select_account consent"))
|
||||
}
|
||||
|
||||
// Authenticate creates an identity session with google from a authorization code, and follows up
|
||||
// call to the admin/group api to check what groups the user is in.
|
||||
func (p *GoogleProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
|
||||
oauth2Token, err := p.oauth.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/google: token exchange failed %v", err)
|
||||
}
|
||||
|
||||
// id_token is a JWT that contains identity information about the user
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("identity/google: response did not contain an id_token")
|
||||
}
|
||||
session, err := p.IDTokenToSession(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session.AccessToken = oauth2Token.AccessToken
|
||||
session.RefreshToken = oauth2Token.RefreshToken
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Refresh renews a user's session using an oidc refresh token withoutreprompting the user.
|
||||
// Group membership is also refreshed.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
|
||||
func (p *GoogleProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
if s.RefreshToken == "" {
|
||||
return nil, errors.New("identity: missing refresh token")
|
||||
}
|
||||
t := oauth2.Token{RefreshToken: s.RefreshToken}
|
||||
newToken, err := p.oauth.TokenSource(ctx, &t).Token()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("identity: refresh failed")
|
||||
return nil, err
|
||||
}
|
||||
// id_token contains claims about the authenticated user
|
||||
rawIDToken, ok := newToken.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("identity/google: response did not contain an id_token")
|
||||
}
|
||||
newSession, err := p.IDTokenToSession(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSession.AccessToken = newToken.AccessToken
|
||||
newSession.RefreshToken = s.RefreshToken
|
||||
return newSession, nil
|
||||
}
|
||||
|
||||
// IDTokenToSession takes an identity provider issued JWT as input ('id_token')
|
||||
// and returns a session state. The provided token's audience ('aud') must
|
||||
// match Pomerium's client_id.
|
||||
func (p *GoogleProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.State, error) {
|
||||
idToken, err := p.verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/google: could not verify id_token %v", err)
|
||||
}
|
||||
var claims struct {
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
}
|
||||
// parse claims from the raw, encoded jwt token
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return nil, fmt.Errorf("identity/google: failed to parse id_token claims %v", err)
|
||||
}
|
||||
|
||||
// google requires additional call to retrieve groups.
|
||||
groups, err := p.UserGroups(ctx, claims.Email)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/google: could not retrieve groups %v", err)
|
||||
}
|
||||
|
||||
return &sessions.State{
|
||||
IDToken: rawIDToken,
|
||||
RefreshDeadline: idToken.Expiry.Truncate(time.Second),
|
||||
Email: claims.Email,
|
||||
User: idToken.Subject,
|
||||
Groups: groups,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UserGroups returns a slice of group names a given user is in
|
||||
// NOTE: groups via Directory API is limited to 1 QPS!
|
||||
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
|
||||
// https://developers.google.com/admin-sdk/directory/v1/limits
|
||||
func (p *GoogleProvider) UserGroups(ctx context.Context, user string) ([]string, error) {
|
||||
func (p *GoogleProvider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
|
||||
var groups []string
|
||||
if p.apiClient != nil {
|
||||
req := p.apiClient.Groups.List().UserKey(user).MaxResults(100)
|
||||
req := p.apiClient.Groups.List().UserKey(s.Subject).MaxResults(100)
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/google: group api request failed %v", err)
|
||||
|
@ -226,57 +141,3 @@ func (p *GoogleProvider) UserGroups(ctx context.Context, user string) ([]string,
|
|||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// JWTConfigFromJSON uses a Google Developers service account JSON key file to read
|
||||
// the credentials that authorize and authenticate the requests.
|
||||
// Create a service account on "Credentials" for your project at
|
||||
// https://console.developers.google.com to download a JSON key file.
|
||||
func JWTConfigFromJSON(jsonKey []byte, scope ...string) (*jwt.Config, error) {
|
||||
var f credentialsFile
|
||||
if err := json.Unmarshal(jsonKey, &f); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if f.Type != "service_account" {
|
||||
return nil, fmt.Errorf("identity/google: 'type' field is %q (expected %q)", f.Type, "service_account")
|
||||
}
|
||||
// Service account must impersonate a user : https://stackoverflow.com/a/48601364
|
||||
if f.ImpersonateUser == "" {
|
||||
return nil, errors.New("identity/google: impersonate_user not found in json config")
|
||||
}
|
||||
scope = append([]string(nil), scope...) // copy
|
||||
return f.jwtConfig(scope), nil
|
||||
}
|
||||
|
||||
// credentialsFile is the unmarshalled representation of a credentials file.
|
||||
type credentialsFile struct {
|
||||
Type string `json:"type"` // serviceAccountKey or userCredentialsKey
|
||||
|
||||
// Service account must impersonate a user
|
||||
ImpersonateUser string `json:"impersonate_user"`
|
||||
// Service Account fields
|
||||
ClientEmail string `json:"client_email"`
|
||||
PrivateKeyID string `json:"private_key_id"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
TokenURL string `json:"token_uri"`
|
||||
ProjectID string `json:"project_id"`
|
||||
|
||||
// User Credential fields
|
||||
ClientSecret string `json:"client_secret"`
|
||||
ClientID string `json:"client_id"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
func (f *credentialsFile) jwtConfig(scopes []string) *jwt.Config {
|
||||
cfg := &jwt.Config{
|
||||
Subject: f.ImpersonateUser,
|
||||
Email: f.ClientEmail,
|
||||
PrivateKey: []byte(f.PrivateKey),
|
||||
PrivateKeyID: f.PrivateKeyID,
|
||||
Scopes: scopes,
|
||||
TokenURL: f.TokenURL,
|
||||
}
|
||||
if cfg.TokenURL == "" {
|
||||
cfg.TokenURL = JWTTokenURL
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ const defaultAzureGroupURL = "https://graph.microsoft.com/v1.0/me/memberOf"
|
|||
type AzureProvider struct {
|
||||
*Provider
|
||||
// non-standard oidc fields
|
||||
RevokeURL *url.URL
|
||||
RevokeURL string `json:"end_session_endpoint"`
|
||||
}
|
||||
|
||||
// NewAzureProvider returns a new AzureProvider and sets the provider url endpoints.
|
||||
|
@ -54,84 +54,22 @@ func NewAzureProvider(p *Provider) (*AzureProvider, error) {
|
|||
Scopes: p.Scopes,
|
||||
}
|
||||
|
||||
azureProvider := &AzureProvider{
|
||||
Provider: p,
|
||||
}
|
||||
// azure has a "end session endpoint"
|
||||
var claims struct {
|
||||
RevokeURL string `json:"end_session_endpoint"`
|
||||
}
|
||||
if err := p.provider.Claims(&claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
azureProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
|
||||
if err != nil {
|
||||
azureProvider := &AzureProvider{Provider: p}
|
||||
if err := p.provider.Claims(&azureProvider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.UserGroupFn = azureProvider.UserGroups
|
||||
|
||||
return azureProvider, nil
|
||||
}
|
||||
|
||||
// Authenticate creates an identity session with azure from a authorization code, and follows up
|
||||
// call to the groups api to check what groups the user is in.
|
||||
func (p *AzureProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
|
||||
// convert authorization code into a token
|
||||
oauth2Token, err := p.oauth.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/microsoft: token exchange failed %v", err)
|
||||
}
|
||||
|
||||
// id_token contains claims about the authenticated user
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("identity/microsoft: response did not contain an id_token")
|
||||
}
|
||||
// Parse and verify ID Token payload.
|
||||
session, err := p.IDTokenToSession(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/microsoft: could not verify id_token %v", err)
|
||||
}
|
||||
|
||||
session.AccessToken = oauth2Token.AccessToken
|
||||
session.RefreshToken = oauth2Token.RefreshToken
|
||||
session.Groups, err = p.UserGroups(ctx, session.AccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/microsoft: could not retrieve groups %v", err)
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// IDTokenToSession takes an identity provider issued JWT as input ('id_token')
|
||||
// and returns a session state. The provided token's audience ('aud') must
|
||||
// match Pomerium's client_id.
|
||||
func (p *AzureProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.State, error) {
|
||||
idToken, err := p.verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/microsoft: could not verify id_token %v", err)
|
||||
}
|
||||
var claims struct {
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
}
|
||||
// parse claims from the raw, encoded jwt token
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return nil, fmt.Errorf("identity/microsoft: failed to parse id_token claims %v", err)
|
||||
}
|
||||
|
||||
return &sessions.State{
|
||||
IDToken: rawIDToken,
|
||||
RefreshDeadline: idToken.Expiry.Truncate(time.Second),
|
||||
Email: claims.Email,
|
||||
User: idToken.Subject,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Revoke revokes the access token a given session state.
|
||||
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-protocols-oidc#send-a-sign-out-request
|
||||
func (p *AzureProvider) Revoke(token string) error {
|
||||
func (p *AzureProvider) Revoke(ctx context.Context, token *oauth2.Token) error {
|
||||
params := url.Values{}
|
||||
params.Add("token", token)
|
||||
err := httputil.Client(http.MethodPost, p.RevokeURL.String(), version.UserAgent(), nil, params, nil)
|
||||
params.Add("token", token.AccessToken)
|
||||
err := httputil.Client(ctx, http.MethodPost, p.RevokeURL, version.UserAgent(), nil, params, nil)
|
||||
if err != nil && err != httputil.ErrTokenRevoked {
|
||||
return err
|
||||
}
|
||||
|
@ -143,34 +81,14 @@ func (p *AzureProvider) GetSignInURL(state string) string {
|
|||
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "select_account"))
|
||||
}
|
||||
|
||||
// Refresh renews a user's session using an oid refresh token without reprompting the user.
|
||||
// Group membership is also refreshed.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
|
||||
func (p *AzureProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
if s.RefreshToken == "" {
|
||||
return nil, errors.New("identity/microsoft: missing refresh token")
|
||||
}
|
||||
t := oauth2.Token{RefreshToken: s.RefreshToken}
|
||||
newToken, err := p.oauth.TokenSource(ctx, &t).Token()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("identity/microsoft: refresh failed")
|
||||
return nil, err
|
||||
}
|
||||
s.AccessToken = newToken.AccessToken
|
||||
s.RefreshDeadline = newToken.Expiry.Truncate(time.Second)
|
||||
s.Groups, err = p.UserGroups(ctx, s.AccessToken)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("identity/microsoft: refresh failed")
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// UserGroups returns a slice of group names a given user is in.
|
||||
// `Directory.Read.All` is required.
|
||||
// https://docs.microsoft.com/en-us/graph/api/resources/directoryobject?view=graph-rest-1.0
|
||||
// https://docs.microsoft.com/en-us/graph/api/user-list-memberof?view=graph-rest-1.0
|
||||
func (p *AzureProvider) UserGroups(ctx context.Context, accessToken string) ([]string, error) {
|
||||
func (p *AzureProvider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
|
||||
if s == nil || s.AccessToken == nil {
|
||||
return nil, errors.New("identity/azure: session cannot be nil")
|
||||
}
|
||||
var response struct {
|
||||
Groups []struct {
|
||||
ID string `json:"id"`
|
||||
|
@ -180,15 +98,15 @@ func (p *AzureProvider) UserGroups(ctx context.Context, accessToken string) ([]s
|
|||
GroupTypes []string `json:"groupTypes,omitempty"`
|
||||
} `json:"value"`
|
||||
}
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", accessToken)}
|
||||
err := httputil.Client(http.MethodGet, defaultAzureGroupURL, version.UserAgent(), headers, nil, &response)
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)}
|
||||
err := httputil.Client(ctx, http.MethodGet, defaultAzureGroupURL, version.UserAgent(), headers, nil, &response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var groups []string
|
||||
for _, group := range response.Groups {
|
||||
log.Debug().Str("DisplayName", group.DisplayName).Str("ID", group.ID).Msg("identity/microsoft: group")
|
||||
groups = append(groups, group.DisplayName)
|
||||
groups = append(groups, group.ID)
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
|
|
@ -3,6 +3,8 @@ package identity // import "github.com/pomerium/pomerium/internal/identity"
|
|||
import (
|
||||
"context"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
|
@ -10,11 +12,7 @@ import (
|
|||
type MockProvider struct {
|
||||
AuthenticateResponse sessions.State
|
||||
AuthenticateError error
|
||||
IDTokenToSessionResponse sessions.State
|
||||
IDTokenToSessionError error
|
||||
ValidateResponse bool
|
||||
ValidateError error
|
||||
RefreshResponse *sessions.State
|
||||
RefreshResponse sessions.State
|
||||
RefreshError error
|
||||
RevokeError error
|
||||
GetSignInURLResponse string
|
||||
|
@ -25,23 +23,13 @@ func (mp MockProvider) Authenticate(ctx context.Context, code string) (*sessions
|
|||
return &mp.AuthenticateResponse, mp.AuthenticateError
|
||||
}
|
||||
|
||||
// IDTokenToSession is a mocked providers function.
|
||||
func (mp MockProvider) IDTokenToSession(ctx context.Context, code string) (*sessions.State, error) {
|
||||
return &mp.IDTokenToSessionResponse, mp.IDTokenToSessionError
|
||||
}
|
||||
|
||||
// Validate is a mocked providers function.
|
||||
func (mp MockProvider) Validate(ctx context.Context, s string) (bool, error) {
|
||||
return mp.ValidateResponse, mp.ValidateError
|
||||
}
|
||||
|
||||
// Refresh is a mocked providers function.
|
||||
func (mp MockProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
return mp.RefreshResponse, mp.RefreshError
|
||||
return &mp.RefreshResponse, mp.RefreshError
|
||||
}
|
||||
|
||||
// Revoke is a mocked providers function.
|
||||
func (mp MockProvider) Revoke(s string) error {
|
||||
func (mp MockProvider) Revoke(ctx context.Context, s *oauth2.Token) error {
|
||||
return mp.RevokeError
|
||||
}
|
||||
|
||||
|
|
|
@ -2,14 +2,9 @@ package identity // import "github.com/pomerium/pomerium/internal/identity"
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
@ -17,6 +12,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
|
@ -26,7 +22,8 @@ import (
|
|||
type OktaProvider struct {
|
||||
*Provider
|
||||
|
||||
RevokeURL *url.URL
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
userAPI *url.URL
|
||||
}
|
||||
|
||||
// NewOktaProvider creates a new instance of Okta as an identity provider.
|
||||
|
@ -53,80 +50,62 @@ func NewOktaProvider(p *Provider) (*OktaProvider, error) {
|
|||
}
|
||||
|
||||
// okta supports a revocation endpoint
|
||||
var claims struct {
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
}
|
||||
if err := p.provider.Claims(&claims); err != nil {
|
||||
oktaProvider := OktaProvider{Provider: p}
|
||||
if err := p.provider.Claims(&oktaProvider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
oktaProvider := OktaProvider{Provider: p}
|
||||
|
||||
oktaProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
|
||||
if p.ServiceAccount != "" {
|
||||
p.UserGroupFn = oktaProvider.UserGroups
|
||||
userAPI, err := urlutil.ParseAndValidateURL(p.ProviderURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userAPI.Path = "/api/v1/users/"
|
||||
oktaProvider.userAPI = userAPI
|
||||
|
||||
} else {
|
||||
log.Warn().Msg("identity/okta: api token provided, cannot retrieve groups")
|
||||
}
|
||||
|
||||
return &oktaProvider, nil
|
||||
}
|
||||
|
||||
// Revoke revokes the access token a given session state.
|
||||
// https://developer.okta.com/docs/api/resources/oidc#revoke
|
||||
func (p *OktaProvider) Revoke(token string) error {
|
||||
func (p *OktaProvider) Revoke(ctx context.Context, token *oauth2.Token) error {
|
||||
params := url.Values{}
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("client_secret", p.ClientSecret)
|
||||
params.Add("token", token)
|
||||
params.Add("token", token.AccessToken)
|
||||
params.Add("token_type_hint", "refresh_token")
|
||||
err := httputil.Client(http.MethodPost, p.RevokeURL.String(), version.UserAgent(), nil, params, nil)
|
||||
err := httputil.Client(ctx, http.MethodPost, p.RevokeURL, version.UserAgent(), nil, params, nil)
|
||||
if err != nil && err != httputil.ErrTokenRevoked {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type accessToken struct {
|
||||
Subject string `json:"sub"`
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
|
||||
// Refresh renews a user's session using an oid refresh token without reprompting the user.
|
||||
// Group membership is also refreshed. If configured properly, Okta is we can configure the access token
|
||||
// to include group membership claims which allows us to avoid a follow up oauth2 call.
|
||||
func (p *OktaProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
if s.RefreshToken == "" {
|
||||
return nil, errors.New("identity/okta: missing refresh token")
|
||||
// UserGroups fetches the groups of which the user is a member
|
||||
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
|
||||
func (p *OktaProvider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
|
||||
var response []struct {
|
||||
ID string `json:"id"`
|
||||
Profile struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"profile"`
|
||||
}
|
||||
t := oauth2.Token{RefreshToken: s.RefreshToken}
|
||||
newToken, err := p.oauth.TokenSource(ctx, &t).Token()
|
||||
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("SSWS %s", p.ServiceAccount)}
|
||||
err := httputil.Client(ctx, http.MethodGet, fmt.Sprintf("%s/%s/groups", p.userAPI.String(), s.Subject), version.UserAgent(), headers, nil, &response)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("identity/okta: refresh failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload, err := parseJWT(newToken.AccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/okta: malformed access token jwt: %v", err)
|
||||
var groups []string
|
||||
for _, group := range response {
|
||||
log.Debug().Interface("group", group).Msg("identity/okta: group")
|
||||
groups = append(groups, group.ID)
|
||||
}
|
||||
var token accessToken
|
||||
if err := json.Unmarshal(payload, &token); err != nil {
|
||||
return nil, fmt.Errorf("identity/okta: failed to unmarshal access token claims: %v", err)
|
||||
}
|
||||
if len(token.Groups) != 0 {
|
||||
s.Groups = token.Groups
|
||||
}
|
||||
|
||||
s.AccessToken = newToken.AccessToken
|
||||
s.RefreshDeadline = newToken.Expiry.Truncate(time.Second)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func parseJWT(p string) ([]byte, error) {
|
||||
parts := strings.Split(p, ".")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
|
||||
}
|
||||
return payload, nil
|
||||
return groups, nil
|
||||
}
|
||||
|
|
|
@ -12,23 +12,22 @@ import (
|
|||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
const defaultOneLoginProviderURL = "https://openid-connect.onelogin.com/oidc"
|
||||
const defaultOneloginGroupURL = "https://openid-connect.onelogin.com/oidc/me"
|
||||
|
||||
// OneLoginProvider provides a standard, OpenID Connect implementation
|
||||
// of an authorization identity provider.
|
||||
type OneLoginProvider struct {
|
||||
*Provider
|
||||
|
||||
// non-standard oidc fields
|
||||
RevokeURL *url.URL
|
||||
AdminCreds *credentialsFile
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
}
|
||||
|
||||
const defaultOneLoginProviderURL = "https://openid-connect.onelogin.com/oidc"
|
||||
|
||||
// NewOneLoginProvider creates a new instance of an OpenID Connect provider.
|
||||
func NewOneLoginProvider(p *Provider) (*OneLoginProvider, error) {
|
||||
ctx := context.Background()
|
||||
|
@ -52,72 +51,38 @@ func NewOneLoginProvider(p *Provider) (*OneLoginProvider, error) {
|
|||
Scopes: p.Scopes,
|
||||
}
|
||||
|
||||
// okta supports a revocation endpoint
|
||||
var claims struct {
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
}
|
||||
if err := p.provider.Claims(&claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
OneLoginProvider := OneLoginProvider{Provider: p}
|
||||
olProvider := OneLoginProvider{Provider: p}
|
||||
|
||||
OneLoginProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
|
||||
if err != nil {
|
||||
if err := p.provider.Claims(&olProvider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &OneLoginProvider, nil
|
||||
p.UserGroupFn = olProvider.UserGroups
|
||||
|
||||
return &olProvider, nil
|
||||
}
|
||||
|
||||
// Revoke revokes the access token a given session state.
|
||||
// https://developers.onelogin.com/openid-connect/api/revoke-session
|
||||
func (p *OneLoginProvider) Revoke(token string) error {
|
||||
func (p *OneLoginProvider) Revoke(ctx context.Context, token *oauth2.Token) error {
|
||||
params := url.Values{}
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("client_secret", p.ClientSecret)
|
||||
params.Add("token", token)
|
||||
params.Add("token", token.AccessToken)
|
||||
params.Add("token_type_hint", "access_token")
|
||||
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), nil, params, nil)
|
||||
err := httputil.Client(ctx, http.MethodPost, p.RevokeURL, version.UserAgent(), nil, params, nil)
|
||||
if err != nil && err != httputil.ErrTokenRevoked {
|
||||
log.Error().Err(err).Msg("authenticate/providers: failed to revoke session")
|
||||
return err
|
||||
return fmt.Errorf("identity/onelogin: revocation error %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSignInURL returns the sign in url with typical oauth parameters
|
||||
func (p *OneLoginProvider) GetSignInURL(state string) string {
|
||||
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
||||
}
|
||||
|
||||
// Refresh renews a user's session using an oid refresh token without reprompting the user.
|
||||
// Group membership is also refreshed.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
|
||||
func (p *OneLoginProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
if s.RefreshToken == "" {
|
||||
return nil, errors.New("identity/microsoft: missing refresh token")
|
||||
}
|
||||
t := oauth2.Token{RefreshToken: s.RefreshToken}
|
||||
newToken, err := p.oauth.TokenSource(ctx, &t).Token()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("identity/microsoft: refresh failed")
|
||||
return nil, err
|
||||
}
|
||||
s.AccessToken = newToken.AccessToken
|
||||
s.RefreshDeadline = newToken.Expiry.Truncate(time.Second)
|
||||
s.Groups, err = p.UserGroups(ctx, s.AccessToken)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("identity/microsoft: refresh failed")
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
const defaultOneloginGroupURL = "https://openid-connect.onelogin.com/oidc/me"
|
||||
|
||||
// UserGroups returns a slice of group names a given user is in.
|
||||
// https://developers.onelogin.com/openid-connect/api/user-info
|
||||
func (p *OneLoginProvider) UserGroups(ctx context.Context, accessToken string) ([]string, error) {
|
||||
func (p *OneLoginProvider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
|
||||
if s == nil || s.AccessToken == nil {
|
||||
return nil, errors.New("identity/onelogin: session cannot be nil")
|
||||
}
|
||||
var response struct {
|
||||
User string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
|
@ -128,15 +93,10 @@ func (p *OneLoginProvider) UserGroups(ctx context.Context, accessToken string) (
|
|||
FamilyName string `json:"family_name"`
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", accessToken)}
|
||||
err := httputil.Client(http.MethodGet, defaultOneloginGroupURL, version.UserAgent(), headers, nil, &response)
|
||||
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)}
|
||||
err := httputil.Client(ctx, http.MethodGet, defaultOneloginGroupURL, version.UserAgent(), headers, nil, &response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var groups []string
|
||||
for _, group := range response.Groups {
|
||||
log.Debug().Str("ID", group).Msg("identity/onelogin: group")
|
||||
groups = append(groups, group)
|
||||
}
|
||||
return groups, nil
|
||||
return response.Groups, nil
|
||||
}
|
||||
|
|
|
@ -7,11 +7,8 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
@ -34,22 +31,13 @@ const (
|
|||
|
||||
// ErrMissingProviderURL is returned when an identity provider requires a provider url
|
||||
// does not receive one.
|
||||
var ErrMissingProviderURL = errors.New("identity: missing provider url")
|
||||
|
||||
// UserGrouper is an interface representing the ability to retrieve group membership information
|
||||
// from an identity provider
|
||||
type UserGrouper interface {
|
||||
// UserGroups returns a slice of group names a given user is in
|
||||
UserGroups(context.Context, string) ([]string, error)
|
||||
}
|
||||
var ErrMissingProviderURL = errors.New("internal/identity: missing provider url")
|
||||
|
||||
// Authenticator is an interface representing the ability to authenticate with an identity provider.
|
||||
type Authenticator interface {
|
||||
Authenticate(context.Context, string) (*sessions.State, error)
|
||||
IDTokenToSession(context.Context, string) (*sessions.State, error)
|
||||
Validate(context.Context, string) (bool, error)
|
||||
Refresh(context.Context, *sessions.State) (*sessions.State, error)
|
||||
Revoke(string) error
|
||||
Revoke(context.Context, *oauth2.Token) error
|
||||
GetSignInURL(state string) string
|
||||
}
|
||||
|
||||
|
@ -59,8 +47,8 @@ func New(providerName string, p *Provider) (a Authenticator, err error) {
|
|||
switch providerName {
|
||||
case AzureProviderName:
|
||||
a, err = NewAzureProvider(p)
|
||||
case GitlabProviderName:
|
||||
return nil, fmt.Errorf("identity: %s currently not supported", providerName)
|
||||
// case GitlabProviderName:
|
||||
// return nil, fmt.Errorf("internal/identity: %s currently not supported", providerName)
|
||||
case GoogleProviderName:
|
||||
a, err = NewGoogleProvider(p)
|
||||
case OIDCProviderName:
|
||||
|
@ -70,7 +58,7 @@ func New(providerName string, p *Provider) (a Authenticator, err error) {
|
|||
case OneLoginProviderName:
|
||||
a, err = NewOneLoginProvider(p)
|
||||
default:
|
||||
return nil, fmt.Errorf("identity: %s provider not known", providerName)
|
||||
return nil, fmt.Errorf("internal/identity: %s provider not known", providerName)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -85,13 +73,16 @@ type Provider struct {
|
|||
ProviderName string
|
||||
|
||||
RedirectURL *url.URL
|
||||
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
ProviderURL string
|
||||
Scopes []string
|
||||
|
||||
// Some providers, such as google, require additional remote api calls to retrieve
|
||||
// user details like groups. Provider is responsible for parsing.
|
||||
UserGroupFn func(context.Context, *sessions.State) ([]string, error)
|
||||
|
||||
// ServiceAccount can be set for those providers that require additional
|
||||
// credentials or tokens to do follow up API calls (e.g. Google)
|
||||
ServiceAccount string
|
||||
|
||||
provider *oidc.Provider
|
||||
|
@ -110,94 +101,73 @@ func (p *Provider) GetSignInURL(state string) string {
|
|||
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
||||
}
|
||||
|
||||
// Validate validates a given session's from it's JWT token
|
||||
// The function verifies it's been signed by the provider, preforms
|
||||
// any additional checks depending on the Config, and returns the payload.
|
||||
//
|
||||
// Validate does NOT do nonce validation.
|
||||
// Validate does NOT check if revoked.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func (p *Provider) Validate(ctx context.Context, idToken string) (bool, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "identity.provider.Validate")
|
||||
defer span.End()
|
||||
_, err := p.verifier.Verify(ctx, idToken)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("identity: failed to verify session state")
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// IDTokenToSession takes an identity provider issued JWT as input ('id_token')
|
||||
// and returns a session state. The provided token's audience ('aud') must
|
||||
// match Pomerium's client_id.
|
||||
func (p *Provider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.State, error) {
|
||||
idToken, err := p.verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity: could not verify id_token: %v", err)
|
||||
}
|
||||
// extract additional, non-oidc standard claims
|
||||
var claims struct {
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return nil, fmt.Errorf("identity: failed to parse id_token claims: %v", err)
|
||||
}
|
||||
|
||||
return &sessions.State{
|
||||
IDToken: rawIDToken,
|
||||
User: idToken.Subject,
|
||||
RefreshDeadline: idToken.Expiry.Truncate(time.Second),
|
||||
Email: claims.Email,
|
||||
Groups: claims.Groups,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
// Authenticate creates a session with an identity provider from a authorization code
|
||||
// Authenticate creates an identity session with google from a authorization code, and follows up
|
||||
// call to the admin/group api to check what groups the user is in.
|
||||
func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
|
||||
// exchange authorization for a oidc token
|
||||
oauth2Token, err := p.oauth.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity: failed token exchange: %v", err)
|
||||
return nil, fmt.Errorf("internal/identity: token exchange failed: %w", err)
|
||||
}
|
||||
//id_token contains claims about the authenticated user
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
||||
}
|
||||
session, err := p.IDTokenToSession(ctx, rawIDToken)
|
||||
idToken, err := p.IdentityFromToken(ctx, oauth2Token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity: could not verify id_token: %v", err)
|
||||
}
|
||||
session.AccessToken = oauth2Token.AccessToken
|
||||
session.RefreshToken = oauth2Token.RefreshToken
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Refresh renews a user's session using therefresh_token without reprompting
|
||||
// the user. If supported, group membership is also refreshed.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
|
||||
func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
if s.RefreshToken == "" {
|
||||
return nil, errors.New("identity: missing refresh token")
|
||||
}
|
||||
t := oauth2.Token{RefreshToken: s.RefreshToken}
|
||||
newToken, err := p.oauth.TokenSource(ctx, &t).Token()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("identity: refresh failed")
|
||||
return nil, err
|
||||
}
|
||||
s.AccessToken = newToken.AccessToken
|
||||
s.RefreshDeadline = newToken.Expiry.Truncate(time.Second)
|
||||
|
||||
s, err := sessions.NewStateFromTokens(idToken, oauth2Token, p.RedirectURL.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p.UserGroupFn != nil {
|
||||
s.Groups, err = p.UserGroupFn(ctx, s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("internal/identity: could not retrieve groups %w", err)
|
||||
}
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Refresh renews a user's session using an oidc refresh token withoutreprompting the user.
|
||||
// Group membership is also refreshed.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
|
||||
func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
|
||||
if s.AccessToken == nil || s.AccessToken.RefreshToken == "" {
|
||||
return nil, errors.New("internal/identity: missing refresh token")
|
||||
}
|
||||
|
||||
t := oauth2.Token{RefreshToken: s.AccessToken.RefreshToken}
|
||||
oauthToken, err := p.oauth.TokenSource(ctx, &t).Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("internal/identity: refresh failed %w", err)
|
||||
}
|
||||
idToken, err := p.IdentityFromToken(ctx, oauthToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.UpdateState(idToken, oauthToken); err != nil {
|
||||
return nil, fmt.Errorf("internal/identity: state update failed %w", err)
|
||||
}
|
||||
if p.UserGroupFn != nil {
|
||||
s.Groups, err = p.UserGroupFn(ctx, s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("internal/identity: could not retrieve groups %w", err)
|
||||
}
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// IdentityFromToken takes an identity provider issued JWT as input ('id_token')
|
||||
// and returns a session state. The provided token's audience ('aud') must
|
||||
// match Pomerium's client_id.
|
||||
func (p *Provider) IdentityFromToken(ctx context.Context, t *oauth2.Token) (*oidc.IDToken, error) {
|
||||
rawIDToken, ok := t.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("internal/identity: id_token not found")
|
||||
}
|
||||
return p.verifier.Verify(ctx, rawIDToken)
|
||||
}
|
||||
|
||||
// Revoke enables a user to revoke her token. If the identity provider supports revocation
|
||||
// the endpoint is available, otherwise an error is thrown.
|
||||
func (p *Provider) Revoke(token string) error {
|
||||
return fmt.Errorf("identity: revoke not implemented by %s", p.ProviderName)
|
||||
func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error {
|
||||
return fmt.Errorf("internal/identity: revoke not implemented by %s", p.ProviderName)
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -12,8 +11,6 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
|
||||
"golang.org/x/net/publicsuffix"
|
||||
)
|
||||
|
||||
// SetHeaders sets a map of response headers.
|
||||
|
@ -30,72 +27,6 @@ func SetHeaders(headers map[string]string) func(next http.Handler) http.Handler
|
|||
}
|
||||
}
|
||||
|
||||
// ValidateClientSecret checks the request header for the client secret and returns
|
||||
// an error if it does not match the proxy client secret
|
||||
func ValidateClientSecret(sharedSecret string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateClientSecret")
|
||||
defer span.End()
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("couldn't parse form", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
clientSecret := r.Form.Get("shared_secret")
|
||||
// check the request header for the client secret
|
||||
if clientSecret == "" {
|
||||
clientSecret = r.Header.Get("X-Client-Secret")
|
||||
}
|
||||
|
||||
if clientSecret != sharedSecret {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("client secret mismatch", http.StatusBadRequest, nil))
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateRedirectURI checks the redirect uri in the query parameters and ensures that
|
||||
// the its domain is in the list of proxy root domains.
|
||||
func ValidateRedirectURI(rootDomain *url.URL) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateRedirectURI")
|
||||
defer span.End()
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("couldn't parse form", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
redirectURI, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri"))
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("bad redirect_uri", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
if !SameDomain(redirectURI, rootDomain) {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("redirect uri and root domain differ", http.StatusBadRequest, nil))
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// SameDomain checks to see if two URLs share the top level domain (TLD Plus One).
|
||||
func SameDomain(u, j *url.URL) bool {
|
||||
a, err := publicsuffix.EffectiveTLDPlusOne(u.Hostname())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
b, err := publicsuffix.EffectiveTLDPlusOne(j.Hostname())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return a == b
|
||||
}
|
||||
|
||||
// ValidateSignature ensures the request is valid and has been signed with
|
||||
// the correspdoning client secret key
|
||||
func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler {
|
||||
|
|
|
@ -17,36 +17,6 @@ func hmacHelperFunc(rawRedirect string, timestamp time.Time, secret string) []by
|
|||
return cryptutil.GenerateHMAC(data, secret)
|
||||
}
|
||||
|
||||
func Test_SameDomain(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
uri string
|
||||
rootDomains string
|
||||
want bool
|
||||
}{
|
||||
{"good url redirect", "https://example.com/redirect", "https://example.com", true},
|
||||
{"good multilevel", "https://httpbin.a.corp.example.com", "https://auth.b.corp.example.com", true},
|
||||
{"good complex tld", "https://httpbin.a.corp.example.co.uk", "https://auth.b.corp.example.co.uk", true},
|
||||
{"bad complex tld", "https://httpbin.a.corp.notexample.co.uk", "https://auth.b.corp.example.co.uk", false},
|
||||
{"simple sub", "https://auth.example.com", "https://test.example.com", true},
|
||||
{"bad domain", "https://auth.example.com/redirect", "https://test.notexample.com", false},
|
||||
{"malformed url", "^example.com/redirect", "https://notexample.com", false},
|
||||
{"empty domain list", "https://example.com/redirect", ".com", false},
|
||||
{"empty domain", "https://example.com/redirect", "", false},
|
||||
{"empty url", "", "example.com", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
u, _ := url.Parse(tt.uri)
|
||||
j, _ := url.Parse(tt.rootDomains)
|
||||
if got := SameDomain(u, j); got != tt.want {
|
||||
t.Errorf("SameDomain() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ValidSignature(t *testing.T) {
|
||||
t.Parallel()
|
||||
goodURL := "https://example.com/redirect"
|
||||
|
@ -109,87 +79,6 @@ func TestSetHeaders(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestValidateRedirectURI(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
rootDomain string
|
||||
redirectURI string
|
||||
status int
|
||||
}{
|
||||
{"simple", "https://auth.google.com", "redirect_uri=https://b.google.com", http.StatusOK},
|
||||
{"deep ok", "https://a.some.really.deep.sub.domain.google.com", "redirect_uri=https://b.some.really.deep.sub.domain.google.com", http.StatusOK},
|
||||
{"bad match", "https://auth.aol.com", "redirect_uri=https://test.google.com", http.StatusBadRequest},
|
||||
{"bad simple", "https://auth.corp.aol.com", "redirect_uri=https://test.corp.google.com", http.StatusBadRequest},
|
||||
{"deep bad", "https://a.some.really.deep.sub.domain.scroogle.com", "redirect_uri=https://b.some.really.deep.sub.domain.google.com", http.StatusBadRequest},
|
||||
{"with cname", "https://auth.google.com", "redirect_uri=https://www.google.com", http.StatusOK},
|
||||
{"with path", "https://auth.google.com", "redirect_uri=https://www.google.com/path", http.StatusOK},
|
||||
{"http mistmatch", "https://auth.google.com", "redirect_uri=http://www.google.com/path", http.StatusOK},
|
||||
{"http", "http://auth.google.com", "redirect_uri=http://www.google.com/path", http.StatusOK},
|
||||
{"ip", "http://1.1.1.1", "redirect_uri=http://8.8.8.8", http.StatusBadRequest},
|
||||
{"redirect get param not set", "https://auth.google.com", "not_redirect_uri!=https://b.google.com", http.StatusBadRequest},
|
||||
{"malformed, invalid get params", "https://auth.google.com", "redirect_uri=https://%zzzzz", http.StatusBadRequest},
|
||||
{"malformed, invalid url", "https://auth.google.com", "redirect_uri=https://accounts.google.^", http.StatusBadRequest},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{RawQuery: tt.redirectURI},
|
||||
}
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Hi"))
|
||||
})
|
||||
rr := httptest.NewRecorder()
|
||||
u, _ := url.Parse(tt.rootDomain)
|
||||
handler := ValidateRedirectURI(u)(testHandler)
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != tt.status {
|
||||
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
|
||||
t.Errorf("%s", rr.Body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateClientSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
sharedSecret string
|
||||
clientGetValue string
|
||||
clientHeaderValue string
|
||||
status int
|
||||
}{
|
||||
{"simple", "secret", "secret", "secret", http.StatusOK},
|
||||
{"missing get param, valid header", "secret", "", "secret", http.StatusOK},
|
||||
{"missing both", "secret", "", "", http.StatusBadRequest},
|
||||
{"simple bad", "bad-secret", "secret", "", http.StatusBadRequest},
|
||||
{"malformed, invalid hex digits", "secret", "%zzzzz", "", http.StatusBadRequest},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := &http.Request{
|
||||
Method: http.MethodGet,
|
||||
Header: http.Header{"X-Client-Secret": []string{tt.clientHeaderValue}},
|
||||
URL: &url.URL{RawQuery: fmt.Sprintf("shared_secret=%s", tt.clientGetValue)},
|
||||
}
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Hi"))
|
||||
})
|
||||
rr := httptest.NewRecorder()
|
||||
handler := ValidateClientSecret(tt.sharedSecret)(testHandler)
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != tt.status {
|
||||
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
|
||||
t.Errorf("%s", rr.Body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSignature(t *testing.T) {
|
||||
t.Parallel()
|
||||
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -28,49 +27,65 @@ const (
|
|||
// CookieStore implements the session store interface for session cookies.
|
||||
type CookieStore struct {
|
||||
Name string
|
||||
CookieDomain string
|
||||
CookieExpire time.Duration
|
||||
CookieHTTPOnly bool
|
||||
CookieSecure bool
|
||||
Encoder cryptutil.SecureEncoder
|
||||
Domain string
|
||||
Expire time.Duration
|
||||
HTTPOnly bool
|
||||
Secure bool
|
||||
encoder Marshaler
|
||||
decoder Unmarshaler
|
||||
}
|
||||
|
||||
// CookieStoreOptions holds options for CookieStore
|
||||
type CookieStoreOptions struct {
|
||||
// CookieOptions holds options for CookieStore
|
||||
type CookieOptions struct {
|
||||
Name string
|
||||
CookieDomain string
|
||||
CookieExpire time.Duration
|
||||
CookieHTTPOnly bool
|
||||
CookieSecure bool
|
||||
Encoder cryptutil.SecureEncoder
|
||||
Domain string
|
||||
Expire time.Duration
|
||||
HTTPOnly bool
|
||||
Secure bool
|
||||
}
|
||||
|
||||
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
|
||||
func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
|
||||
func NewCookieStore(opts *CookieOptions, encoder Encoder) (*CookieStore, error) {
|
||||
if opts.Name == "" {
|
||||
return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
|
||||
}
|
||||
if opts.Encoder == nil {
|
||||
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
|
||||
if encoder == nil {
|
||||
return nil, fmt.Errorf("internal/sessions: decoder cannot be nil")
|
||||
}
|
||||
|
||||
return &CookieStore{
|
||||
Name: opts.Name,
|
||||
CookieSecure: opts.CookieSecure,
|
||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||
CookieDomain: opts.CookieDomain,
|
||||
CookieExpire: opts.CookieExpire,
|
||||
Encoder: opts.Encoder,
|
||||
Secure: opts.Secure,
|
||||
HTTPOnly: opts.HTTPOnly,
|
||||
Domain: opts.Domain,
|
||||
Expire: opts.Expire,
|
||||
encoder: encoder,
|
||||
decoder: encoder,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewCookieLoader returns a new session with ciphers for each of the cookie secrets
|
||||
func NewCookieLoader(opts *CookieOptions, decoder Unmarshaler) (*CookieStore, error) {
|
||||
if opts.Name == "" {
|
||||
return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
|
||||
}
|
||||
if decoder == nil {
|
||||
return nil, fmt.Errorf("internal/sessions: decoder cannot be nil")
|
||||
}
|
||||
return &CookieStore{
|
||||
Name: opts.Name,
|
||||
Secure: opts.Secure,
|
||||
HTTPOnly: opts.HTTPOnly,
|
||||
Domain: opts.Domain,
|
||||
Expire: opts.Expire,
|
||||
decoder: decoder,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (cs *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
domain := req.Host
|
||||
|
||||
if cs.CookieDomain != "" {
|
||||
domain = cs.CookieDomain
|
||||
} else {
|
||||
domain = ParentSubdomain(domain)
|
||||
if cs.Domain != "" {
|
||||
domain = cs.Domain
|
||||
}
|
||||
|
||||
if h, _, err := net.SplitHostPort(domain); err == nil {
|
||||
|
@ -81,8 +96,8 @@ func (cs *CookieStore) makeCookie(req *http.Request, name string, value string,
|
|||
Value: value,
|
||||
Path: "/",
|
||||
Domain: domain,
|
||||
HttpOnly: cs.CookieHTTPOnly,
|
||||
Secure: cs.CookieSecure,
|
||||
HttpOnly: cs.HTTPOnly,
|
||||
Secure: cs.Secure,
|
||||
}
|
||||
// only set an expiration if we want one, otherwise default to non perm session based
|
||||
if expiration != 0 {
|
||||
|
@ -98,23 +113,38 @@ func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
|
|||
|
||||
// LoadSession returns a State from the cookie in the request.
|
||||
func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) {
|
||||
cipherText := loadChunkedCookie(req, cs.Name)
|
||||
if cipherText == "" {
|
||||
data := loadChunkedCookie(req, cs.Name)
|
||||
if data == "" {
|
||||
return nil, ErrNoSessionFound
|
||||
}
|
||||
session, err := UnmarshalSession(cipherText, cs.Encoder)
|
||||
var session State
|
||||
err := cs.decoder.Unmarshal([]byte(data), &session)
|
||||
if err != nil {
|
||||
return nil, ErrMalformed
|
||||
}
|
||||
return session, nil
|
||||
|
||||
return &session, err
|
||||
}
|
||||
|
||||
// SaveSession saves a session state to a request sessions.
|
||||
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error {
|
||||
value, err := MarshalSession(s, cs.Encoder)
|
||||
// SaveSession saves a session state to a request's cookie store.
|
||||
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, x interface{}) error {
|
||||
var value string
|
||||
if cs.encoder != nil {
|
||||
data, err := cs.encoder.Marshal(x)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value = string(data)
|
||||
} else {
|
||||
switch v := x.(type) {
|
||||
case []byte:
|
||||
value = string(v)
|
||||
case string:
|
||||
value = v
|
||||
default:
|
||||
return errors.New("internal/sessions: cannot save non-string type")
|
||||
}
|
||||
}
|
||||
cs.setSessionCookie(w, req, value)
|
||||
return nil
|
||||
}
|
||||
|
@ -125,7 +155,7 @@ func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expira
|
|||
}
|
||||
|
||||
func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
|
||||
cs.setCookie(w, cs.makeSessionCookie(req, val, cs.CookieExpire, time.Now()))
|
||||
cs.setCookie(w, cs.makeSessionCookie(req, val, cs.Expire, time.Now()))
|
||||
}
|
||||
|
||||
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
||||
|
@ -153,11 +183,11 @@ func loadChunkedCookie(r *http.Request, cookieName string) string {
|
|||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
cipherText := c.Value
|
||||
data := c.Value
|
||||
// if the first byte is our canary byte, we need to handle the multipart bit
|
||||
if []byte(c.Value)[0] == ChunkedCanaryByte {
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "%s", cipherText[1:])
|
||||
fmt.Fprintf(&b, "%s", data[1:])
|
||||
for i := 1; i <= MaxNumChunks; i++ {
|
||||
next, err := r.Cookie(fmt.Sprintf("%s_%d", cookieName, i))
|
||||
if err != nil {
|
||||
|
@ -165,9 +195,9 @@ func loadChunkedCookie(r *http.Request, cookieName string) string {
|
|||
}
|
||||
fmt.Fprintf(&b, "%s", next.Value)
|
||||
}
|
||||
cipherText = b.String()
|
||||
data = b.String()
|
||||
}
|
||||
return cipherText
|
||||
return data
|
||||
}
|
||||
|
||||
func chunk(s string, size int) []string {
|
||||
|
|
|
@ -6,86 +6,43 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||
)
|
||||
|
||||
type MockEncoder struct{}
|
||||
|
||||
func (a MockEncoder) Marshal(s interface{}) (string, error) { return "", errors.New("error") }
|
||||
func (a MockEncoder) Unmarshal(s string, i interface{}) error {
|
||||
if s == "unmarshal error" || s == "error" {
|
||||
return errors.New("error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func TestNewCookieStore(t *testing.T) {
|
||||
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
encoder := cryptutil.NewSecureJSONEncoder(cipher)
|
||||
encoder := ecjson.New(cipher)
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *CookieStoreOptions
|
||||
opts *CookieOptions
|
||||
encoder Encoder
|
||||
want *CookieStore
|
||||
wantErr bool
|
||||
}{
|
||||
{"good",
|
||||
&CookieStoreOptions{
|
||||
Name: "_cookie",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
Encoder: encoder,
|
||||
},
|
||||
&CookieStore{
|
||||
Name: "_cookie",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
Encoder: encoder,
|
||||
},
|
||||
false},
|
||||
{"missing name",
|
||||
&CookieStoreOptions{
|
||||
Name: "",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
Encoder: encoder,
|
||||
},
|
||||
nil,
|
||||
true},
|
||||
{"missing cipher",
|
||||
&CookieStoreOptions{
|
||||
Name: "_pomerium",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
Encoder: nil,
|
||||
},
|
||||
nil,
|
||||
true},
|
||||
{"good", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &CookieStore{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
|
||||
{"missing name", &CookieOptions{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
|
||||
{"missing encoder", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewCookieStore(tt.opts)
|
||||
got, err := NewCookieStore(tt.opts, tt.encoder)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
cmpOpts := []cmp.Option{
|
||||
cmpopts.IgnoreUnexported(cryptutil.SecureJSONEncoder{}),
|
||||
cmpopts.IgnoreUnexported(CookieStore{}),
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
|
||||
|
@ -94,6 +51,40 @@ func TestNewCookieStore(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
func TestNewCookieLoader(t *testing.T) {
|
||||
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
encoder := ecjson.New(cipher)
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *CookieOptions
|
||||
encoder Encoder
|
||||
want *CookieStore
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &CookieStore{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
|
||||
{"missing name", &CookieOptions{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
|
||||
{"missing encoder", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewCookieLoader(tt.opts, tt.encoder)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewCookieLoader() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
cmpOpts := []cmp.Option{
|
||||
cmpopts.IgnoreUnexported(CookieStore{}),
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
|
||||
t.Errorf("NewCookieLoader() = %s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCookieStore_makeCookie(t *testing.T) {
|
||||
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
|
||||
|
@ -114,10 +105,10 @@ func TestCookieStore_makeCookie(t *testing.T) {
|
|||
want *http.Cookie
|
||||
wantCSRF *http.Cookie
|
||||
}{
|
||||
{"good", "http://httpbin.corp.pomerium.io", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
|
||||
{"domains with https", "https://httpbin.corp.pomerium.io", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
|
||||
{"domain with port", "http://httpbin.corp.pomerium.io:443", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
|
||||
{"expiration set", "http://httpbin.corp.pomerium.io:443", "", "_pomerium", "value", 10 * time.Second, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
|
||||
{"good", "http://httpbin.corp.pomerium.io", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
|
||||
{"domains with https", "https://httpbin.corp.pomerium.io", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
|
||||
{"domain with port", "http://httpbin.corp.pomerium.io:443", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
|
||||
{"expiration set", "http://httpbin.corp.pomerium.io:443", "", "_pomerium", "value", 10 * time.Second, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
|
||||
{"good", "http://httpbin.corp.pomerium.io", "pomerium.io", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -125,13 +116,14 @@ func TestCookieStore_makeCookie(t *testing.T) {
|
|||
r := httptest.NewRequest("GET", tt.domain, nil)
|
||||
|
||||
s, err := NewCookieStore(
|
||||
&CookieStoreOptions{
|
||||
&CookieOptions{
|
||||
Name: "_pomerium",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: tt.cookieDomain,
|
||||
CookieExpire: 10 * time.Second,
|
||||
Encoder: cryptutil.NewSecureJSONEncoder(cipher)})
|
||||
Secure: true,
|
||||
HTTPOnly: true,
|
||||
Domain: tt.cookieDomain,
|
||||
Expire: 10 * time.Second,
|
||||
},
|
||||
ecjson.New(cipher))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -151,7 +143,6 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cipher := cryptutil.NewSecureJSONEncoder(c)
|
||||
|
||||
hugeString := make([]byte, 4097)
|
||||
if _, err := rand.Read(hugeString); err != nil {
|
||||
|
@ -160,23 +151,28 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
State *State
|
||||
cipher cryptutil.SecureEncoder
|
||||
encoder Encoder
|
||||
decoder Encoder
|
||||
wantErr bool
|
||||
wantLoadErr bool
|
||||
}{
|
||||
{"good", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
|
||||
{"bad cipher", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, MockEncoder{}, true, true},
|
||||
{"huge cookie", &State{AccessToken: fmt.Sprintf("%x", hugeString), RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
|
||||
{"good", &State{Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
|
||||
{"bad cipher", &State{Email: "user@domain.com", User: "user"}, nil, nil, true, true},
|
||||
{"huge cookie", &State{Subject: fmt.Sprintf("%x", hugeString), Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
|
||||
{"marshal error", &State{Email: "user@domain.com", User: "user"}, encoding.MockEncoder{MarshalError: errors.New("error")}, ecjson.New(c), true, true},
|
||||
{"nil encoder cannot save non string type", &State{Email: "user@domain.com", User: "user"}, nil, ecjson.New(c), true, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &CookieStore{
|
||||
Name: "_pomerium",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
Encoder: tt.cipher}
|
||||
Secure: true,
|
||||
HTTPOnly: true,
|
||||
Domain: "pomerium.io",
|
||||
Expire: 10 * time.Second,
|
||||
encoder: tt.encoder,
|
||||
decoder: tt.encoder,
|
||||
}
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
@ -195,54 +191,19 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
|||
t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr)
|
||||
return
|
||||
}
|
||||
cmpOpts := []cmp.Option{
|
||||
cmpopts.IgnoreUnexported(State{}),
|
||||
}
|
||||
if err == nil {
|
||||
if diff := cmp.Diff(state, tt.State); diff != "" {
|
||||
if diff := cmp.Diff(state, tt.State, cmpOpts...); diff != "" {
|
||||
t.Errorf("CookieStore.LoadSession() got = %s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMockSessionStore(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mockCSRF *MockSessionStore
|
||||
saveSession *State
|
||||
wantLoadErr bool
|
||||
wantSaveErr bool
|
||||
}{
|
||||
{"basic",
|
||||
&MockSessionStore{
|
||||
ResponseSession: "test",
|
||||
Session: &State{AccessToken: "AccessToken"},
|
||||
SaveError: nil,
|
||||
LoadError: nil,
|
||||
},
|
||||
&State{AccessToken: "AccessToken"},
|
||||
false,
|
||||
false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ms := tt.mockCSRF
|
||||
|
||||
err := ms.SaveSession(nil, nil, tt.saveSession)
|
||||
if (err != nil) != tt.wantSaveErr {
|
||||
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantSaveErr %v", err, tt.wantSaveErr)
|
||||
return
|
||||
}
|
||||
got, err := ms.LoadSession(nil)
|
||||
if (err != nil) != tt.wantLoadErr {
|
||||
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantLoadErr %v", err, tt.wantLoadErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.mockCSRF.Session) {
|
||||
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Session)
|
||||
}
|
||||
ms.ClearSession(nil, nil)
|
||||
if ms.ResponseSession != "" {
|
||||
t.Errorf("ResponseSession not empty! %s", ms.ResponseSession)
|
||||
w = httptest.NewRecorder()
|
||||
s.ClearSession(w, r)
|
||||
x := w.Header().Get("Set-Cookie")
|
||||
if !strings.Contains(x, "_pomerium=; Path=/;") {
|
||||
t.Errorf(x)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -3,14 +3,9 @@ package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
|||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultAuthHeader and defaultAuthType are default header name for the
|
||||
// authorization bearer token header as defined in rfc2617
|
||||
// https://tools.ietf.org/html/rfc6750#section-2.1
|
||||
defaultAuthHeader = "Authorization"
|
||||
defaultAuthType = "Bearer"
|
||||
)
|
||||
|
@ -20,41 +15,44 @@ const (
|
|||
type HeaderStore struct {
|
||||
authHeader string
|
||||
authType string
|
||||
encoder cryptutil.SecureEncoder
|
||||
encoder Unmarshaler
|
||||
}
|
||||
|
||||
// NewHeaderStore returns a new header store for loading sessions from
|
||||
// authorization headers.
|
||||
func NewHeaderStore(enc cryptutil.SecureEncoder) *HeaderStore {
|
||||
// authorization header as defined in as defined in rfc2617
|
||||
//
|
||||
// NOTA BENE: While most servers do not log Authorization headers by default,
|
||||
// you should ensure no other services are logging or leaking your auth headers.
|
||||
func NewHeaderStore(enc Unmarshaler, headerType string) *HeaderStore {
|
||||
if headerType == "" {
|
||||
headerType = defaultAuthType
|
||||
}
|
||||
return &HeaderStore{
|
||||
authHeader: defaultAuthHeader,
|
||||
authType: defaultAuthType,
|
||||
authType: headerType,
|
||||
encoder: enc,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadSession tries to retrieve the token string from the Authorization header.
|
||||
//
|
||||
// NOTA BENE: While most servers do not log Authorization headers by default,
|
||||
// you should ensure no other services are logging or leaking your auth headers.
|
||||
func (as *HeaderStore) LoadSession(r *http.Request) (*State, error) {
|
||||
cipherText := as.tokenFromHeader(r)
|
||||
cipherText := TokenFromHeader(r, as.authHeader, as.authType)
|
||||
if cipherText == "" {
|
||||
return nil, ErrNoSessionFound
|
||||
}
|
||||
session, err := UnmarshalSession(cipherText, as.encoder)
|
||||
if err != nil {
|
||||
var session State
|
||||
if err := as.encoder.Unmarshal([]byte(cipherText), &session); err != nil {
|
||||
return nil, ErrMalformed
|
||||
}
|
||||
return session, nil
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// retrieve the value of the authorization header
|
||||
func (as *HeaderStore) tokenFromHeader(r *http.Request) string {
|
||||
bearer := r.Header.Get(as.authHeader)
|
||||
atSize := len(as.authType)
|
||||
if len(bearer) > atSize && strings.EqualFold(bearer[0:atSize], as.authType) {
|
||||
// TokenFromHeader retrieves the value of the authorization header from a given
|
||||
// request, header key, and authentication type.
|
||||
func TokenFromHeader(r *http.Request, authHeader, authType string) string {
|
||||
bearer := r.Header.Get(authHeader)
|
||||
atSize := len(authType)
|
||||
if len(bearer) > atSize && strings.EqualFold(bearer[0:atSize], authType) {
|
||||
return bearer[atSize+1:]
|
||||
}
|
||||
return ""
|
||||
|
|
|
@ -12,10 +12,8 @@ var (
|
|||
ErrorCtxKey = &contextKey{"Error"}
|
||||
)
|
||||
|
||||
// RetrieveSession will search for a auth session in a http request, in the order:
|
||||
// 1. `pomerium_session` URI query parameter
|
||||
// 2. `Authorization: BEARER` request header
|
||||
// 3. Cookie `_pomerium` value
|
||||
// RetrieveSession takes a slice of session loaders and tries to find a valid
|
||||
// session in the order they were supplied and is added to the request's context
|
||||
func RetrieveSession(s ...SessionLoader) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return retrieve(s...)(next)
|
||||
|
@ -34,35 +32,21 @@ func retrieve(s ...SessionLoader) func(http.Handler) http.Handler {
|
|||
}
|
||||
}
|
||||
|
||||
// retrieveFromRequest extracts sessions state from the request by calling
|
||||
// token find functions in the order they where provided.
|
||||
func retrieveFromRequest(r *http.Request, sessions ...SessionLoader) (*State, error) {
|
||||
state := new(State)
|
||||
var err error
|
||||
|
||||
// Extract sessions state from the request by calling token find functions in
|
||||
// the order they where provided. Further extraction stops if a function
|
||||
// returns a non-empty string.
|
||||
for _, s := range sessions {
|
||||
state, err = s.LoadSession(r)
|
||||
state, err := s.LoadSession(r)
|
||||
if err != nil && !errors.Is(err, ErrNoSessionFound) {
|
||||
// unexpected error
|
||||
return nil, err
|
||||
}
|
||||
// break, we found a session state
|
||||
if state != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
// no session found if state is still empty
|
||||
if state == nil {
|
||||
return nil, ErrNoSessionFound
|
||||
}
|
||||
|
||||
if err = state.Valid(); err != nil {
|
||||
// a little unusual but we want to return the expired state too
|
||||
return state, err
|
||||
}
|
||||
if state != nil {
|
||||
err := state.Verify(r.Host)
|
||||
return state, err // N.B.: state is _not nil_
|
||||
}
|
||||
}
|
||||
|
||||
return state, nil
|
||||
return nil, ErrNoSessionFound
|
||||
}
|
||||
|
||||
// NewContext sets context values for the user session state and error.
|
||||
|
|
|
@ -11,6 +11,8 @@ import (
|
|||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func TestNewContext(t *testing.T) {
|
||||
|
@ -27,7 +29,7 @@ func TestNewContext(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctxOut := NewContext(tt.ctx, tt.t, tt.err)
|
||||
stateOut, errOut := FromContext(ctxOut)
|
||||
if diff := cmp.Diff(tt.t, stateOut); diff != "" {
|
||||
if diff := cmp.Diff(tt.t.Email, stateOut.Email); diff != "" {
|
||||
t.Errorf("NewContext() = %s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tt.err, errOut); diff != "" {
|
||||
|
@ -67,56 +69,54 @@ func TestVerifier(t *testing.T) {
|
|||
wantBody string
|
||||
wantStatus int
|
||||
}{
|
||||
{"good cookie session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, true, false, false, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||
{"malformed cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"good auth header session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, false, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||
{"malformed auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"good auth query param session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"no session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||
{"good cookie session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true, false, false, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired cookie", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, true, false, false, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
||||
{"malformed cookie", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, true, false, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"good auth header session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false, true, false, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired auth header", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, true, false, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
||||
{"malformed auth header", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"good auth query param session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired auth query param", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, true, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
||||
{"malformed auth query param", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"no session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||
encoder := cryptutil.NewSecureJSONEncoder(cipher)
|
||||
|
||||
encoder := ecjson.New(cipher)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
encSession, err := MarshalSession(&tt.state, encoder)
|
||||
encSession, err := encoder.Marshal(&tt.state)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if strings.Contains(tt.name, "malformed") {
|
||||
// add some garbage to the end of the string
|
||||
encSession += cryptutil.NewBase64Key()
|
||||
encSession = append(encSession, cryptutil.NewKey()...)
|
||||
}
|
||||
|
||||
cs, err := NewCookieStore(&CookieStoreOptions{
|
||||
cs, err := NewCookieStore(&CookieOptions{
|
||||
Name: "_pomerium",
|
||||
Encoder: encoder,
|
||||
})
|
||||
}, encoder)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
as := NewHeaderStore(encoder)
|
||||
as := NewHeaderStore(encoder, "")
|
||||
|
||||
qp := NewQueryParamStore(encoder)
|
||||
qp := NewQueryParamStore(encoder, "")
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
if tt.cookie {
|
||||
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: encSession})
|
||||
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: string(encSession)})
|
||||
} else if tt.header {
|
||||
r.Header.Set("Authorization", "Bearer "+encSession)
|
||||
r.Header.Set("Authorization", "Bearer "+string(encSession))
|
||||
} else if tt.param {
|
||||
q := r.URL.Query()
|
||||
|
||||
q.Set("pomerium_session", encSession)
|
||||
q.Set("pomerium_session", string(encSession))
|
||||
r.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
|
|
|
@ -23,6 +23,6 @@ func (ms MockSessionStore) LoadSession(*http.Request) (*State, error) {
|
|||
}
|
||||
|
||||
// SaveSession returns a save error.
|
||||
func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *State) error {
|
||||
func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, interface{}) error {
|
||||
return ms.SaveError
|
||||
}
|
||||
|
|
50
internal/sessions/mock_store_test.go
Normal file
50
internal/sessions/mock_store_test.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMockSessionStore(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mockCSRF *MockSessionStore
|
||||
saveSession *State
|
||||
wantLoadErr bool
|
||||
wantSaveErr bool
|
||||
}{
|
||||
{"basic",
|
||||
&MockSessionStore{
|
||||
ResponseSession: "test",
|
||||
Session: &State{Subject: "0101"},
|
||||
SaveError: nil,
|
||||
LoadError: nil,
|
||||
},
|
||||
&State{Subject: "0101"},
|
||||
false,
|
||||
false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ms := tt.mockCSRF
|
||||
|
||||
err := ms.SaveSession(nil, nil, tt.saveSession)
|
||||
if (err != nil) != tt.wantSaveErr {
|
||||
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantSaveErr %v", err, tt.wantSaveErr)
|
||||
return
|
||||
}
|
||||
got, err := ms.LoadSession(nil)
|
||||
if (err != nil) != tt.wantLoadErr {
|
||||
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantLoadErr %v", err, tt.wantLoadErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.mockCSRF.Session) {
|
||||
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Session)
|
||||
}
|
||||
ms.ClearSession(nil, nil)
|
||||
if ms.ResponseSession != "" {
|
||||
t.Errorf("ResponseSession not empty! %s", ms.ResponseSession)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -2,8 +2,6 @@ package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -14,31 +12,55 @@ const (
|
|||
// query strings / query parameters.
|
||||
type QueryParamStore struct {
|
||||
queryParamKey string
|
||||
encoder cryptutil.SecureEncoder
|
||||
encoder Marshaler
|
||||
decoder Unmarshaler
|
||||
}
|
||||
|
||||
// NewQueryParamStore returns a new query param store for loading sessions from
|
||||
// query strings / query parameters.
|
||||
func NewQueryParamStore(enc cryptutil.SecureEncoder) *QueryParamStore {
|
||||
//
|
||||
// NOTA BENE: By default, most servers _DO_ log query params, the leaking or
|
||||
// accidental logging of which should be considered a security issue.
|
||||
func NewQueryParamStore(enc Encoder, qp string) *QueryParamStore {
|
||||
if qp == "" {
|
||||
qp = defaultQueryParamKey
|
||||
}
|
||||
return &QueryParamStore{
|
||||
queryParamKey: defaultQueryParamKey,
|
||||
queryParamKey: qp,
|
||||
encoder: enc,
|
||||
decoder: enc,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadSession tries to retrieve the token string from URL query parameters.
|
||||
//
|
||||
// NOTA BENE: By default, most servers _DO_ log query params, the leaking or
|
||||
// accidental logging of which should be considered a security issue.
|
||||
func (qp *QueryParamStore) LoadSession(r *http.Request) (*State, error) {
|
||||
cipherText := r.URL.Query().Get(qp.queryParamKey)
|
||||
if cipherText == "" {
|
||||
return nil, ErrNoSessionFound
|
||||
}
|
||||
session, err := UnmarshalSession(cipherText, qp.encoder)
|
||||
if err != nil {
|
||||
var session State
|
||||
if err := qp.decoder.Unmarshal([]byte(cipherText), &session); err != nil {
|
||||
return nil, ErrMalformed
|
||||
}
|
||||
return session, nil
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// ClearSession clears the session cookie from a request's query param key `pomerium_session`.
|
||||
func (qp *QueryParamStore) ClearSession(w http.ResponseWriter, r *http.Request) {
|
||||
params := r.URL.Query()
|
||||
params.Del(qp.queryParamKey)
|
||||
r.URL.RawQuery = params.Encode()
|
||||
}
|
||||
|
||||
// SaveSession sets a session to a request's query param key `pomerium_session`
|
||||
func (qp *QueryParamStore) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
|
||||
data, err := qp.encoder.Marshal(x)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.URL.Query().Get(qp.queryParamKey)
|
||||
params := r.URL.Query()
|
||||
params.Set(qp.queryParamKey, string(data))
|
||||
r.URL.RawQuery = params.Encode()
|
||||
return nil
|
||||
}
|
||||
|
|
47
internal/sessions/query_store_test.go
Normal file
47
internal/sessions/query_store_test.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package sessions
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
)
|
||||
|
||||
func TestNewQueryParamStore(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
State *State
|
||||
|
||||
enc Encoder
|
||||
qp string
|
||||
wantErr bool
|
||||
wantURL *url.URL
|
||||
}{
|
||||
{"simple good", &State{Email: "user@domain.com", User: "user"}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, "", false, &url.URL{Path: "/", RawQuery: "pomerium_session=ok"}},
|
||||
{"marshall error", &State{Email: "user@domain.com", User: "user"}, encoding.MockEncoder{MarshalError: errors.New("error")}, "", true, &url.URL{Path: "/"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewQueryParamStore(tt.enc, tt.qp)
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewQueryParamStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(r.URL, tt.wantURL); diff != "" {
|
||||
t.Errorf("NewQueryParamStore() = %v", diff)
|
||||
}
|
||||
got.ClearSession(w, r)
|
||||
if diff := cmp.Diff(r.URL, &url.URL{Path: "/"}); diff != "" {
|
||||
t.Errorf("NewQueryParamStore() = %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,46 +1,147 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultLeeway defines the default leeway for matching NotBefore/Expiry claims.
|
||||
DefaultLeeway = 1.0 * time.Minute
|
||||
)
|
||||
|
||||
// timeNow is time.Now but pulled out as a variable for tests.
|
||||
var timeNow = time.Now
|
||||
|
||||
// State is our object that keeps track of a user's session state
|
||||
type State struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshDeadline time.Time `json:"refresh_deadline"`
|
||||
// Public claim values (as specified in RFC 7519).
|
||||
Issuer string `json:"iss,omitempty"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
Audience jwt.Audience `json:"aud,omitempty"`
|
||||
Expiry *jwt.NumericDate `json:"exp,omitempty"`
|
||||
NotBefore *jwt.NumericDate `json:"nbf,omitempty"`
|
||||
IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
|
||||
ID string `json:"jti,omitempty"`
|
||||
|
||||
// core pomerium identity claims ; not standard to RFC 7519
|
||||
Email string `json:"email"`
|
||||
User string `json:"user"`
|
||||
Groups []string `json:"groups"`
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
User string `json:"user,omitempty"` // google
|
||||
|
||||
ImpersonateEmail string
|
||||
ImpersonateGroups []string
|
||||
// commonly supported IdP information
|
||||
// https://www.iana.org/assignments/jwt/jwt.xhtml#claims
|
||||
Name string `json:"name,omitempty"` // google
|
||||
GivenName string `json:"given_name,omitempty"` // google
|
||||
FamilyName string `json:"family_name,omitempty"` // google
|
||||
Picture string `json:"picture,omitempty"` // google
|
||||
EmailVerified bool `json:"email_verified,omitempty"` // google
|
||||
|
||||
// Impersonate-able fields
|
||||
ImpersonateEmail string `json:"impersonate_email,omitempty"`
|
||||
ImpersonateGroups []string `json:"impersonate_groups,omitempty"`
|
||||
|
||||
// Programmatic whether this state is used for machine-to-machine
|
||||
// programatic access.
|
||||
Programmatic bool `json:"programatic"`
|
||||
|
||||
AccessToken *oauth2.Token `json:"access_token,omitempty"`
|
||||
|
||||
idToken *oidc.IDToken
|
||||
}
|
||||
|
||||
// Valid returns an error if the users's session state is not valid.
|
||||
func (s *State) Valid() error {
|
||||
if s.Expired() {
|
||||
return ErrExpired
|
||||
// NewStateFromTokens returns a session state built from oidc and oauth2
|
||||
// tokens as part of OpenID Connect flow with a new audience appended to the
|
||||
// audience claim.
|
||||
func NewStateFromTokens(idToken *oidc.IDToken, accessToken *oauth2.Token, audience string) (*State, error) {
|
||||
if idToken == nil {
|
||||
return nil, errors.New("sessions: oidc id token missing")
|
||||
}
|
||||
if accessToken == nil {
|
||||
return nil, errors.New("sessions: oauth2 token missing")
|
||||
}
|
||||
s := &State{}
|
||||
if err := idToken.Claims(s); err != nil {
|
||||
return nil, fmt.Errorf("sessions: couldn't unmarshal extra claims %w", err)
|
||||
}
|
||||
s.Audience = []string{audience}
|
||||
s.idToken = idToken
|
||||
s.AccessToken = accessToken
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// UpdateState updates the current state given a new identity (oidc) and authorization
|
||||
// (oauth2) tokens following a oidc refresh. NB, unlike during authentication,
|
||||
// refresh typically provides fewer claims in the token so we want to build from
|
||||
// our previous state.
|
||||
func (s *State) UpdateState(idToken *oidc.IDToken, accessToken *oauth2.Token) error {
|
||||
if idToken == nil {
|
||||
return errors.New("sessions: oidc id token missing")
|
||||
}
|
||||
if accessToken == nil {
|
||||
return errors.New("sessions: oauth2 token missing")
|
||||
}
|
||||
audience := append(s.Audience[:0:0], s.Audience...)
|
||||
s.AccessToken = accessToken
|
||||
if err := idToken.Claims(s); err != nil {
|
||||
return fmt.Errorf("sessions: update state failed %w", err)
|
||||
}
|
||||
s.Audience = audience
|
||||
s.Expiry = jwt.NewNumericDate(accessToken.Expiry)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForceRefresh sets the refresh deadline to now.
|
||||
func (s *State) ForceRefresh() {
|
||||
s.RefreshDeadline = time.Now().Truncate(time.Second)
|
||||
// NewSession updates issuer, audience, and issuance timestamps but keeps
|
||||
// parent expiry.
|
||||
func (s State) NewSession(issuer string, audience []string) *State {
|
||||
s.IssuedAt = jwt.NewNumericDate(timeNow())
|
||||
s.NotBefore = s.IssuedAt
|
||||
s.Audience = audience
|
||||
s.Issuer = issuer
|
||||
return &s
|
||||
}
|
||||
|
||||
// Expired returns true if the refresh period has expired
|
||||
func (s *State) Expired() bool {
|
||||
return s.RefreshDeadline.Before(time.Now())
|
||||
// RouteSession creates a route session with access tokens stripped and a
|
||||
// custom validity period.
|
||||
func (s State) RouteSession(validity time.Duration) *State {
|
||||
s.Expiry = jwt.NewNumericDate(timeNow().Add(validity))
|
||||
s.AccessToken = nil
|
||||
return &s
|
||||
}
|
||||
|
||||
// Verify returns an error if the users's session state is not valid.
|
||||
func (s *State) Verify(audience string) error {
|
||||
if s.NotBefore != nil && timeNow().Add(DefaultLeeway).Before(s.NotBefore.Time()) {
|
||||
return ErrNotValidYet
|
||||
}
|
||||
|
||||
if s.Expiry != nil && timeNow().Add(-DefaultLeeway).After(s.Expiry.Time()) {
|
||||
return ErrExpired
|
||||
}
|
||||
|
||||
if s.IssuedAt != nil && timeNow().Add(DefaultLeeway).Before(s.IssuedAt.Time()) {
|
||||
return ErrIssuedInTheFuture
|
||||
}
|
||||
|
||||
// if we have an associated access token, check if that token has expired as well
|
||||
if s.AccessToken != nil && timeNow().Add(-DefaultLeeway).After(s.AccessToken.Expiry) {
|
||||
return ErrExpired
|
||||
}
|
||||
|
||||
if len(s.Audience) != 0 {
|
||||
if !s.Audience.Contains(audience) {
|
||||
return ErrInvalidAudience
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Impersonating returns if the request is impersonating.
|
||||
|
@ -65,79 +166,12 @@ func (s *State) RequestGroups() string {
|
|||
return strings.Join(s.Groups, ",")
|
||||
}
|
||||
|
||||
type idToken struct {
|
||||
Issuer string `json:"iss"`
|
||||
Subject string `json:"sub"`
|
||||
Expiry jsonTime `json:"exp"`
|
||||
IssuedAt jsonTime `json:"iat"`
|
||||
Nonce string `json:"nonce"`
|
||||
AtHash string `json:"at_hash"`
|
||||
}
|
||||
|
||||
// IssuedAt parses the IDToken's issue date and returns a valid go time.Time.
|
||||
func (s *State) IssuedAt() (time.Time, error) {
|
||||
payload, err := parseJWT(s.IDToken)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("internal/sessions: malformed jwt: %v", err)
|
||||
}
|
||||
var token idToken
|
||||
if err := json.Unmarshal(payload, &token); err != nil {
|
||||
return time.Time{}, fmt.Errorf("internal/sessions: failed to unmarshal claims: %v", err)
|
||||
}
|
||||
return time.Time(token.IssuedAt), nil
|
||||
}
|
||||
|
||||
// MarshalSession marshals the session state as JSON, encrypts the JSON using the
|
||||
// given cipher, and base64-encodes the result
|
||||
func MarshalSession(s *State, c cryptutil.SecureEncoder) (string, error) {
|
||||
v, err := c.Marshal(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
||||
// byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct
|
||||
func UnmarshalSession(value string, c cryptutil.SecureEncoder) (*State, error) {
|
||||
s := &State{}
|
||||
err := c.Unmarshal(value, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func parseJWT(p string) ([]byte, error) {
|
||||
parts := strings.Split(p, ".")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("internal/sessions: malformed jwt, expected 3 parts got %d", len(parts))
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("internal/sessions: malformed jwt payload: %v", err)
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
type jsonTime time.Time
|
||||
|
||||
func (j *jsonTime) UnmarshalJSON(b []byte) error {
|
||||
var n json.Number
|
||||
if err := json.Unmarshal(b, &n); err != nil {
|
||||
return err
|
||||
}
|
||||
var unix int64
|
||||
|
||||
if t, err := n.Int64(); err == nil {
|
||||
unix = t
|
||||
// SetImpersonation sets impersonation user and groups.
|
||||
func (s *State) SetImpersonation(email, groups string) {
|
||||
s.ImpersonateEmail = email
|
||||
if groups == "" {
|
||||
s.ImpersonateGroups = nil
|
||||
} else {
|
||||
f, err := n.Float64()
|
||||
if err != nil {
|
||||
return err
|
||||
s.ImpersonateGroups = strings.Split(groups, ",")
|
||||
}
|
||||
unix = int64(f)
|
||||
}
|
||||
*j = jsonTime(time.Unix(unix, 0))
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,90 +1,16 @@
|
|||
package sessions
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func TestStateSerialization(t *testing.T) {
|
||||
secret := cryptutil.NewKey()
|
||||
cipher, err := cryptutil.NewAEADCipher(secret)
|
||||
c := cryptutil.NewSecureJSONEncoder(cipher)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||
}
|
||||
|
||||
want := &State{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
}
|
||||
|
||||
ciphertext, err := MarshalSession(want, c)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be encode session: %v", err)
|
||||
}
|
||||
|
||||
got, err := UnmarshalSession(ciphertext, c)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be decode session: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Logf("want: %#v", want)
|
||||
t.Logf(" got: %#v", got)
|
||||
t.Errorf("encoding and decoding session resulted in unexpected output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateExpirations(t *testing.T) {
|
||||
session := &State{
|
||||
AccessToken: "token1234",
|
||||
RefreshToken: "refresh4321",
|
||||
RefreshDeadline: time.Now().Add(-1 * time.Hour),
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
}
|
||||
if !session.Expired() {
|
||||
t.Errorf("expected lifetime period to be expired")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestState_IssuedAt(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
IDToken string
|
||||
want time.Time
|
||||
wantErr bool
|
||||
}{
|
||||
{"simple parse", "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlkzYm1qV3R4US16OW1fM1RLb0dtRWciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY3MjY4NywiZXhwIjoxNTU4Njc2Mjg3fQ.a4g8W94E7iVJhiIUmsNMwJssfx3Evi8sXeiXgXMC7kHNvftQ2CFU_LJ-dqZ5Jf61OXcrp26r7lUcTNENXuen9tyUWAiHvxk6OHTxZusdywTCY5xowpSZBO9PDWYrmmdvfhRbaKO6QVAUMkbKr1Tr8xqfoaYVXNZhERXhcVReDznI0ccbwCGrNx5oeqiL4eRdZY9eqFXi4Yfee0mkef9oyVPc2HvnpwcpM0eckYa_l_ZQChGjXVGBFIus_Ao33GbWDuc9gs-_Vp2ev4KUT2qWb7AXMCGDLx0tWI9umm7mCBi_7xnaanGKUYcVwcSrv45arllAAwzuNxO0BVw3oRWa5Q", time.Unix(1558672687, 0), false},
|
||||
{"bad jwt", "x.x.x-x-x", time.Time{}, true},
|
||||
{"malformed jwt", "x", time.Time{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &State{IDToken: tt.IDToken}
|
||||
got, err := s.IssuedAt()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("State.IssuedAt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("State.IssuedAt() = %v, want %v", got.Format(time.RFC3339), tt.want.Format(time.RFC3339))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_Impersonating(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
|
@ -107,9 +33,8 @@ func TestState_Impersonating(t *testing.T) {
|
|||
s := &State{
|
||||
Email: tt.Email,
|
||||
Groups: tt.Groups,
|
||||
ImpersonateEmail: tt.ImpersonateEmail,
|
||||
ImpersonateGroups: tt.ImpersonateGroups,
|
||||
}
|
||||
s.SetImpersonation(tt.ImpersonateEmail, strings.Join(tt.ImpersonateGroups, ","))
|
||||
if got := s.Impersonating(); got != tt.want {
|
||||
t.Errorf("State.Impersonating() = %v, want %v", got, tt.want)
|
||||
}
|
||||
|
@ -123,84 +48,80 @@ func TestState_Impersonating(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestMarshalSession(t *testing.T) {
|
||||
secret := cryptutil.NewKey()
|
||||
cipher, err := cryptutil.NewAEADCipher(secret)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||
}
|
||||
c := cryptutil.NewSecureJSONEncoder(cipher)
|
||||
|
||||
hugeString := make([]byte, 4097)
|
||||
if _, err := rand.Read(hugeString); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
func TestState_Verify(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
s *State
|
||||
Audience jwt.Audience
|
||||
Expiry *jwt.NumericDate
|
||||
NotBefore *jwt.NumericDate
|
||||
IssuedAt *jwt.NumericDate
|
||||
AccessToken *oauth2.Token
|
||||
|
||||
audience string
|
||||
wantErr bool
|
||||
}{
|
||||
{"simple", &State{}, false},
|
||||
{"too big", &State{AccessToken: fmt.Sprintf("%x", hugeString)}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
in, err := MarshalSession(tt.s, c)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("MarshalSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
out, err := UnmarshalSession(in, c)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be decode session: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tt.s, out); diff != "" {
|
||||
t.Errorf("MarshalSession() = %s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_Valid(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
RefreshDeadline time.Time
|
||||
wantErr bool
|
||||
}{
|
||||
{" good", time.Now().Add(10 * time.Second), false},
|
||||
{" expired", time.Now().Add(-10 * time.Second), true},
|
||||
{"good", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", false},
|
||||
{"bad expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true},
|
||||
{"bad audience", []string{"x", "y", "z"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true},
|
||||
{"bad not before", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true},
|
||||
{"bad issued at", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true},
|
||||
{"bad access token expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(-time.Hour)}, "a", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &State{
|
||||
RefreshDeadline: tt.RefreshDeadline,
|
||||
Audience: tt.Audience,
|
||||
Expiry: tt.Expiry,
|
||||
NotBefore: tt.NotBefore,
|
||||
IssuedAt: tt.IssuedAt,
|
||||
AccessToken: tt.AccessToken,
|
||||
}
|
||||
if err := s.Valid(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("State.Valid() error = %v, wantErr %v", err, tt.wantErr)
|
||||
if err := s.Verify(tt.audience); (err != nil) != tt.wantErr {
|
||||
t.Errorf("State.Verify() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_ForceRefresh(t *testing.T) {
|
||||
func TestState_RouteSession(t *testing.T) {
|
||||
now := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
return now
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
RefreshDeadline time.Time
|
||||
Issuer string
|
||||
Audience jwt.Audience
|
||||
Expiry *jwt.NumericDate
|
||||
AccessToken *oauth2.Token
|
||||
|
||||
issuer string
|
||||
|
||||
audience []string
|
||||
validity time.Duration
|
||||
|
||||
want *State
|
||||
}{
|
||||
{"good", time.Now().Truncate(time.Second)},
|
||||
{"good", "authenticate.x.y.z", []string{"http.x.y.z"}, jwt.NewNumericDate(timeNow()), nil, "authenticate.a.b.c", []string{"http.a.b.c"}, 20 * time.Second, &State{Issuer: "authenticate.a.b.c", Audience: []string{"http.a.b.c"}, NotBefore: jwt.NewNumericDate(timeNow()), IssuedAt: jwt.NewNumericDate(timeNow()), Expiry: jwt.NewNumericDate(timeNow().Add(20 * time.Second))}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &State{
|
||||
RefreshDeadline: tt.RefreshDeadline,
|
||||
s := State{
|
||||
Issuer: tt.Issuer,
|
||||
Audience: tt.Audience,
|
||||
Expiry: tt.Expiry,
|
||||
AccessToken: tt.AccessToken,
|
||||
}
|
||||
s.ForceRefresh()
|
||||
if s.RefreshDeadline != tt.RefreshDeadline {
|
||||
t.Errorf("refresh deadline not updated")
|
||||
cmpOpts := []cmp.Option{
|
||||
cmpopts.IgnoreUnexported(State{}),
|
||||
}
|
||||
got := s.NewSession(tt.issuer, tt.audience)
|
||||
got = got.RouteSession(tt.validity)
|
||||
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
|
||||
t.Errorf("State.RouteSession() = %s", diff)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,19 +6,30 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
// ErrExpired is the error for an expired session.
|
||||
ErrExpired = errors.New("internal/sessions: session is expired")
|
||||
// ErrNoSessionFound is the error for when no session is found.
|
||||
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
|
||||
|
||||
// ErrMalformed is the error for when a session is found but is malformed.
|
||||
ErrMalformed = errors.New("internal/sessions: session is malformed")
|
||||
|
||||
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
|
||||
ErrNotValidYet = errors.New("internal/sessions: validation failed, token not valid yet (nbf)")
|
||||
|
||||
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
|
||||
ErrExpired = errors.New("internal/sessions: validation failed, token is expired (exp)")
|
||||
|
||||
// ErrIssuedInTheFuture indicates that the iat field is in the future.
|
||||
ErrIssuedInTheFuture = errors.New("internal/sessions: validation field, token issued in the future (iat)")
|
||||
|
||||
// ErrInvalidAudience indicated invalid aud claim.
|
||||
ErrInvalidAudience = errors.New("internal/sessions: validation failed, invalid audience claim (aud)")
|
||||
)
|
||||
|
||||
// SessionStore has the functions for setting, getting, and clearing the Session cookie
|
||||
type SessionStore interface {
|
||||
ClearSession(http.ResponseWriter, *http.Request)
|
||||
LoadSession(*http.Request) (*State, error)
|
||||
SaveSession(http.ResponseWriter, *http.Request, *State) error
|
||||
SessionLoader
|
||||
SaveSession(http.ResponseWriter, *http.Request, interface{}) error
|
||||
}
|
||||
|
||||
// SessionLoader is implemented by any struct that loads a pomerium session
|
||||
|
@ -26,3 +37,19 @@ type SessionStore interface {
|
|||
type SessionLoader interface {
|
||||
LoadSession(*http.Request) (*State, error)
|
||||
}
|
||||
|
||||
// Encoder can both Marshal and Unmarshal a struct into and from a set of bytes.
|
||||
type Encoder interface {
|
||||
Marshaler
|
||||
Unmarshaler
|
||||
}
|
||||
|
||||
// Marshaler encodes a struct into a set of bytes.
|
||||
type Marshaler interface {
|
||||
Marshal(interface{}) ([]byte, error)
|
||||
}
|
||||
|
||||
// Unmarshaler decodes a set of bytes and returns a struct.
|
||||
type Unmarshaler interface {
|
||||
Unmarshal([]byte, interface{}) error
|
||||
}
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import "strings"
|
||||
|
||||
// ParentSubdomain returns the parent subdomain.
|
||||
func ParentSubdomain(s string) string {
|
||||
if strings.Count(s, ".") < 2 {
|
||||
return ""
|
||||
}
|
||||
split := strings.SplitN(s, ".", 2)
|
||||
return split[1]
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
package sessions
|
||||
|
||||
import "testing"
|
||||
|
||||
func Test_ParentSubdomain(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
s string
|
||||
want string
|
||||
}{
|
||||
{"httpbin.corp.example.com", "corp.example.com"},
|
||||
{"some.httpbin.corp.example.com", "httpbin.corp.example.com"},
|
||||
{"example.com", ""},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.s, func(t *testing.T) {
|
||||
if got := ParentSubdomain(tt.s); got != tt.want {
|
||||
t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -143,9 +143,12 @@ func New() *template.Template {
|
|||
text-align: center;
|
||||
width: 75px;
|
||||
height: auto;
|
||||
border-radius: 50%;
|
||||
}
|
||||
|
||||
.logo {
|
||||
padding-bottom: 20px;
|
||||
padding-top: 20px;
|
||||
width: 115px;
|
||||
height: auto;
|
||||
}
|
||||
|
@ -161,6 +164,7 @@ func New() *template.Template {
|
|||
p.message {
|
||||
margin-top: 10px;
|
||||
margin-bottom: 10px;
|
||||
padding-bottom: 20px;
|
||||
}
|
||||
|
||||
.field {
|
||||
|
@ -300,37 +304,119 @@ func New() *template.Template {
|
|||
<div id="main">
|
||||
<div id="info-box">
|
||||
<div class="card">
|
||||
{{if .Session.Picture }}
|
||||
<img class="icon" src="{{.Session.Picture}}" alt="user image">
|
||||
{{else}}
|
||||
<svg class="icon ok" xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24">
|
||||
<path fill="none" d="M0 0h24v24H0V0z" />
|
||||
<path d="M11 7h2v2h-2zm0 4h2v6h-2zm1-9C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm0 18c-4.41 0-8-3.59-8-8s3.59-8 8-8 8 3.59 8 8-3.59 8-8 8z" />
|
||||
</svg>
|
||||
<form method="POST" action="{{.SignoutURL}}">
|
||||
{{end}}
|
||||
|
||||
<form method="POST" action="/.pomerium/sign_out">
|
||||
<section>
|
||||
<h2>Current user</h2>
|
||||
<p class="message">Your current session details.</p>
|
||||
<fieldset>
|
||||
{{if .Session.Name}}
|
||||
<label>
|
||||
<span>Name</span>
|
||||
<input name="Name" type="text" class="field" value="{{.Session.Name}}" disabled>
|
||||
</label>
|
||||
{{else}}
|
||||
{{if .Session.GivenName}}
|
||||
<label>
|
||||
<span>Given Name</span>
|
||||
<input name="GivenName" type="text" class="field" value="{{.Session.GivenName}}" disabled>
|
||||
</label>
|
||||
{{end}}
|
||||
{{if .Session.FamilyName}}
|
||||
<label>
|
||||
<span>Family Name</span>
|
||||
<input name="FamilyName" type="text" class="field" value="{{.Session.FamilyName}}" disabled>
|
||||
</label>
|
||||
{{end}}
|
||||
{{end}}
|
||||
{{if .Session.Subject}}
|
||||
<label>
|
||||
<span>UserID</span>
|
||||
<input name="email" type="text" class="field" value="{{.Session.Subject}}" disabled>
|
||||
</label>
|
||||
{{end}}
|
||||
{{if .Session.Email}}
|
||||
<label>
|
||||
<span>Email</span>
|
||||
<input name="email" type="email" class="field" value="{{.Email}}" disabled>
|
||||
<input name="email" type="email" class="field" value="{{.Session.Email}}" disabled>
|
||||
</label>
|
||||
{{end}}
|
||||
{{if .Session.User}}
|
||||
<label>
|
||||
<span>User</span>
|
||||
<input name="user" type="text" class="field" value="{{.User}}" disabled>
|
||||
<input name="user" type="text" class="field" value="{{.Session.User}}" disabled>
|
||||
</label>
|
||||
{{end}}
|
||||
{{if .Session.Groups}}
|
||||
<label class="select">
|
||||
<span>Groups</span>
|
||||
<div id="group" class="field">
|
||||
<select name="group">
|
||||
{{range .Groups}}
|
||||
{{range .Session.Groups}}
|
||||
<option value="{{.}}">{{.}}</option>
|
||||
{{end}}
|
||||
</select>
|
||||
</div>
|
||||
</label>
|
||||
{{end}}
|
||||
{{if .Session.Expiry}}
|
||||
<label>
|
||||
<span>Expiry</span>
|
||||
<input name="session expiration" type="text" class="field" value="{{.RefreshDeadline}}" disabled>
|
||||
<input name="session expiration" type="text" class="field" value="{{.Session.Expiry.Time}}" disabled>
|
||||
</label>
|
||||
{{end}}
|
||||
{{if .Session.IssuedAt}}
|
||||
<label>
|
||||
<span>Issued</span>
|
||||
<input name="session expiration" type="text" class="field" value="{{.Session.IssuedAt.Time}}" disabled>
|
||||
</label>
|
||||
{{end}}
|
||||
{{if .Session.Issuer}}
|
||||
<label>
|
||||
<span>Issuer</span>
|
||||
<input name="session expiration" type="text" class="field" value=" {{ .Session.Issuer}}" disabled>
|
||||
</label>
|
||||
{{end}}
|
||||
{{if .Session.Audience}}
|
||||
<label class="select">
|
||||
<span>Audiences</span>
|
||||
<div id="group" class="field">
|
||||
<select name="group">
|
||||
{{range .Session.Audience}}
|
||||
<option value="{{.}}">{{ printf "%.30s" . }}</option>
|
||||
{{end}}
|
||||
</select>
|
||||
</div>
|
||||
</label>
|
||||
{{end}}
|
||||
|
||||
{{if .Session.ImpersonateEmail}}
|
||||
<label>
|
||||
<span>Impersonating Email</span>
|
||||
<input name="session expiration" type="text" class="field" value="{{.Session.ImpersonateEmail}}" disabled>
|
||||
</label>
|
||||
{{end}}
|
||||
{{if .Session.ImpersonateGroups}}
|
||||
<label class="select">
|
||||
<span>Impersonating Groups</span>
|
||||
<div id="group" class="field">
|
||||
<select name="group">
|
||||
{{range .Session.ImpersonateGroups}}
|
||||
<option value="{{.}}">{{.}}</option>
|
||||
{{end}}
|
||||
</select>
|
||||
</div>
|
||||
</label>
|
||||
{{end}}
|
||||
|
||||
</fieldset>
|
||||
</section>
|
||||
<div class="flex">
|
||||
|
@ -338,17 +424,7 @@ func New() *template.Template {
|
|||
<button class="button full" type="submit">Sign Out</button>
|
||||
</div>
|
||||
</form>
|
||||
<section>
|
||||
<h2>Refresh Identity</h2>
|
||||
<p class="message">Pomerium will automatically refresh your user session. However, if your group memberships have recently changed and haven't taken effect yet, you can refresh your session manually.</p>
|
||||
|
||||
<form method="POST" action="/.pomerium/refresh">
|
||||
<div class="flex">
|
||||
{{ .csrfField }}
|
||||
<button class="button full" type="submit">Refresh</button>
|
||||
</div>
|
||||
</form>
|
||||
</section>
|
||||
|
||||
{{if .IsAdmin}}
|
||||
<form method="POST" action="/.pomerium/impersonate">
|
||||
|
@ -358,11 +434,11 @@ func New() *template.Template {
|
|||
<fieldset>
|
||||
<label>
|
||||
<span>Email</span>
|
||||
<input name="email" type="email" class="field" value="{{.ImpersonateEmail}}" placeholder="user@example.com">
|
||||
<input name="email" type="email" class="field" value="" placeholder="user@example.com">
|
||||
</label>
|
||||
<label>
|
||||
<span>Group</span>
|
||||
<input name="group" type="text" class="field" value="{{.ImpersonateGroup}}" placeholder="engineering">
|
||||
<input name="group" type="text" class="field" value="" placeholder="engineering">
|
||||
</label>
|
||||
</fieldset>
|
||||
</section>
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/pomerium/csrf"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/templates"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
|
@ -27,15 +28,24 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
|
|||
// 3. Enforce CSRF protections for any non-idempotent http method
|
||||
h.Use(csrf.Protect(
|
||||
p.cookieSecret,
|
||||
csrf.Path("/"),
|
||||
csrf.Domain(p.cookieDomain),
|
||||
csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieName)),
|
||||
csrf.Secure(p.cookieOptions.Secure),
|
||||
csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieOptions.Name)),
|
||||
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
|
||||
))
|
||||
h.HandleFunc("/", p.UserDashboard).Methods(http.MethodGet)
|
||||
h.HandleFunc("/impersonate", p.Impersonate).Methods(http.MethodPost)
|
||||
h.HandleFunc("/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost)
|
||||
h.HandleFunc("/refresh", p.ForceRefresh).Methods(http.MethodPost)
|
||||
|
||||
// Authenticate service callback handlers and middleware
|
||||
c := r.PathPrefix(dashboardURL + "/callback").Subrouter()
|
||||
// only accept payloads that have come from a trusted service (hmac)
|
||||
c.Use(middleware.ValidateSignature(p.SharedKey))
|
||||
c.HandleFunc("/", p.Callback).Queries("redirect_uri", "{redirect_uri}").Methods(http.MethodGet)
|
||||
|
||||
// Programmatic API handlers and middleware
|
||||
a := r.PathPrefix(dashboardURL + "/api").Subrouter()
|
||||
a.HandleFunc("/v1/login", p.ProgrammaticLogin).Queries("redirect_uri", "{redirect_uri}").Methods(http.MethodGet)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
|
@ -56,6 +66,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
|||
redirectURL = uri
|
||||
}
|
||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL)
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
|
@ -74,53 +85,14 @@ func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
|
|||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
//todo(bdd): make sign out redirect a configuration option so that
|
||||
// admins can set to whatever their corporate homepage is
|
||||
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
|
||||
signoutURL := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL)
|
||||
|
||||
templates.New().ExecuteTemplate(w, "dashboard.html", map[string]interface{}{
|
||||
"Email": session.Email,
|
||||
"User": session.User,
|
||||
"Groups": session.Groups,
|
||||
"RefreshDeadline": time.Until(session.RefreshDeadline).Round(time.Second).String(),
|
||||
"SignoutURL": signoutURL.String(),
|
||||
"Session": session,
|
||||
"IsAdmin": isAdmin,
|
||||
"ImpersonateEmail": session.ImpersonateEmail,
|
||||
"ImpersonateGroup": strings.Join(session.ImpersonateGroups, ","),
|
||||
"csrfField": csrf.TemplateField(r),
|
||||
})
|
||||
}
|
||||
|
||||
// ForceRefresh redeems and extends an existing authenticated oidc session with
|
||||
// the underlying identity provider. All session details including groups,
|
||||
// timeouts, will be renewed.
|
||||
func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
iss, err := session.IssuedAt()
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// reject a refresh if it's been less than the refresh cooldown to prevent abuse
|
||||
if time.Since(iss) < p.refreshCooldown {
|
||||
errStr := fmt.Sprintf("Session must be %s old before refreshing", p.refreshCooldown)
|
||||
httpErr := httputil.Error(errStr, http.StatusBadRequest, nil)
|
||||
httputil.ErrorResponse(w, r, httpErr)
|
||||
return
|
||||
}
|
||||
session.ForceRefresh()
|
||||
if err = p.sessionStore.SaveSession(w, r, session); err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, dashboardURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// Impersonate takes the result of a form and adds user impersonation details
|
||||
// to the user's current user sessions state if the user is currently an
|
||||
// administrative user. Requests are redirected back to the user dashboard.
|
||||
|
@ -138,101 +110,112 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
// OK to impersonation
|
||||
session.ImpersonateEmail = r.FormValue("email")
|
||||
session.ImpersonateGroups = strings.Split(r.FormValue("group"), ",")
|
||||
groups := r.FormValue("group")
|
||||
if groups != "" {
|
||||
session.ImpersonateGroups = strings.Split(groups, ",")
|
||||
}
|
||||
if err := p.sessionStore.SaveSession(w, r, session); err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, dashboardURL, http.StatusFound)
|
||||
redirectURL := urlutil.GetAbsoluteURL(r)
|
||||
redirectURL.Path = dashboardURL // redirect back to the dashboard
|
||||
q := redirectURL.Query()
|
||||
q.Add("impersonate_email", r.FormValue("email"))
|
||||
q.Add("impersonate_group", r.FormValue("group"))
|
||||
redirectURL.RawQuery = q.Encode()
|
||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, redirectURL).String()
|
||||
http.Redirect(w, r, uri, http.StatusFound)
|
||||
}
|
||||
|
||||
func (p *Proxy) registerFwdAuthHandlers() http.Handler {
|
||||
r := httputil.NewRouter()
|
||||
r.StrictSlash(true)
|
||||
r.Use(sessions.RetrieveSession(p.sessionStore))
|
||||
r.HandleFunc("/", p.VerifyAndSignin).Queries("uri", "{uri}").Methods(http.MethodGet)
|
||||
r.HandleFunc("/verify", p.VerifyOnly).Queries("uri", "{uri}").Methods(http.MethodGet)
|
||||
r.Handle("/", p.Verify(false)).Queries("uri", "{uri}").Methods(http.MethodGet)
|
||||
r.Handle("/verify", p.Verify(true)).Queries("uri", "{uri}").Methods(http.MethodGet)
|
||||
return r
|
||||
}
|
||||
|
||||
// VerifyAndSignin checks a user's credentials for an arbitrary host. If the user
|
||||
// Verify checks a user's credentials for an arbitrary host. If the user
|
||||
// is properly authenticated and is authorized to access the supplied host,
|
||||
// a `200` http status code is returned. If the user is not authenticated, they
|
||||
// will be redirected to the authenticate service to sign in with their identity
|
||||
// provider. If the user is unauthorized, a `401` error is returned.
|
||||
func (p *Proxy) VerifyAndSignin(w http.ResponseWriter, r *http.Request) {
|
||||
func (p *Proxy) Verify(verifyOnly bool) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
uri, err := urlutil.ParseAndValidateURL(r.FormValue("uri"))
|
||||
if err != nil || uri.String() == "" {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("bad verification uri given", http.StatusBadRequest, nil))
|
||||
httputil.ErrorResponse(w, r, httputil.Error("bad verification uri", http.StatusBadRequest, nil))
|
||||
return
|
||||
}
|
||||
if err := p.authenticate(w, r); err != nil {
|
||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
|
||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
}
|
||||
if err := p.authorize(r, uri); err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusUnauthorized, err))
|
||||
if err := p.authenticate(verifyOnly, w, r); err != nil {
|
||||
return
|
||||
}
|
||||
// check the queryparams to see if this check immediately followed
|
||||
// authentication. If so, redirect back to the originally requested hostname.
|
||||
if isCallback := r.URL.Query().Get(callbackQueryParam); isCallback == "true" {
|
||||
q := uri.Query()
|
||||
q.Del(callbackQueryParam)
|
||||
uri.RawQuery = q.Encode()
|
||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
if err := p.authorize(uri.Host, w, r); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, fmt.Sprintf("Access to %s is allowed.", uri.Host))
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// VerifyOnly checks a user's credentials for an arbitrary host. If the user
|
||||
// is properly authenticated and is authorized to access the supplied host,
|
||||
// a `200` http status code is returned otherwise a `401` error is returned.
|
||||
func (p *Proxy) VerifyOnly(w http.ResponseWriter, r *http.Request) {
|
||||
uri, err := urlutil.ParseAndValidateURL(r.FormValue("uri"))
|
||||
if err != nil || uri.String() == "" {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("bad verification uri given", http.StatusBadRequest, nil))
|
||||
// Callback takes a `redirect_uri` query param that has been hmac'd by the
|
||||
// authenticate service. Embedded in the `redirect_uri` are query-params
|
||||
// that tell this handler how to set the per-route user session.
|
||||
// Callback is responsible for redirecting the user back to the intended
|
||||
// destination URL and path, as well as to clean up any additional query params
|
||||
// added by the authenticate service.
|
||||
func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
if err := p.authenticate(w, r); err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusUnauthorized, err))
|
||||
|
||||
q := redirectURL.Query()
|
||||
// 1. extract the base64 encoded and encrypted JWT from redirect_uri's query params
|
||||
encryptedJWT, err := base64.URLEncoding.DecodeString(q.Get("pomerium_jwt"))
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
if err := p.authorize(r, uri); err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusUnauthorized, err))
|
||||
q.Del("pomerium_jwt")
|
||||
q.Del("impersonate_email")
|
||||
q.Del("impersonate_group")
|
||||
|
||||
// 2. decrypt the JWT using the cipher using the _shared_ secret key
|
||||
rawJWT, err := cryptutil.Decrypt(p.sharedCipher, encryptedJWT, nil)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
// 3. Save the decrypted JWT to the session store directly as a string, without resigning
|
||||
if err = p.sessionStore.SaveSession(w, r, rawJWT); err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// if this is a programmatic request, don't strip the tokens before redirect
|
||||
if redirectURL.Query().Get("pomerium_programmatic_destination_url") != "" {
|
||||
q.Set("pomerium_jwt", string(rawJWT))
|
||||
}
|
||||
redirectURL.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// ProgrammaticLogin returns a signed url that can be used to login
|
||||
// using the authenticate service.
|
||||
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) {
|
||||
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
q := redirectURL.Query()
|
||||
q.Add("pomerium_programmatic_destination_url", urlutil.GetAbsoluteURL(r).String())
|
||||
redirectURL.RawQuery = q.Encode()
|
||||
response := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, redirectURL).String()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (p *Proxy) authorize(r *http.Request, uri *url.URL) error {
|
||||
// attempt to retrieve the user session from the request context, validity
|
||||
// of the identity session is asserted by the middleware chain
|
||||
s, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// query the authorization service to see if the session's user has
|
||||
// the appropriate authorization to access the given hostname
|
||||
authorized, err := p.AuthorizeClient.Authorize(r.Context(), uri.Host, s)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !authorized {
|
||||
return fmt.Errorf("%s is not authorized for %s", s.RequestEmail(), uri.String())
|
||||
}
|
||||
return nil
|
||||
w.Write([]byte(response))
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -11,11 +11,14 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/config"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/proxy/clients"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func TestProxy_RobotsTxt(t *testing.T) {
|
||||
|
@ -62,17 +65,17 @@ func TestProxy_UserDashboard(t *testing.T) {
|
|||
ctxError error
|
||||
options config.Options
|
||||
method string
|
||||
cipher cryptutil.SecureEncoder
|
||||
cipher sessions.Encoder
|
||||
session sessions.SessionStore
|
||||
authorizer clients.Authorizer
|
||||
|
||||
wantAdminForm bool
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusOK},
|
||||
{"session context error", errors.New("error"), opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
|
||||
{"want admin form good admin authorization", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
|
||||
{"is admin but authorization fails", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
|
||||
{"good", nil, opts, http.MethodGet, &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusOK},
|
||||
{"session context error", errors.New("error"), opts, http.MethodGet, &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
|
||||
{"want admin form good admin authorization", nil, opts, http.MethodGet, &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
|
||||
{"is admin but authorization fails", nil, opts, http.MethodGet, &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -109,56 +112,6 @@ func TestProxy_UserDashboard(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestProxy_ForceRefresh(t *testing.T) {
|
||||
opts := testOptions(t)
|
||||
opts.RefreshCooldown = 0
|
||||
timeSinceError := testOptions(t)
|
||||
timeSinceError.RefreshCooldown = time.Duration(int(^uint(0) >> 1))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctxError error
|
||||
options config.Options
|
||||
method string
|
||||
cipher cryptutil.SecureEncoder
|
||||
session sessions.SessionStore
|
||||
authorizer clients.Authorizer
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
|
||||
{"cannot load session", errors.New("error"), opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
||||
{"bad id token", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
||||
{"issue date too soon", nil, timeSinceError, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusBadRequest},
|
||||
{"refresh failure", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
|
||||
{"can't save refreshed session", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, err := New(tt.options)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.encoder = tt.cipher
|
||||
p.sessionStore = tt.session
|
||||
p.AuthorizeClient = tt.authorizer
|
||||
|
||||
r := httptest.NewRequest(tt.method, "/", nil)
|
||||
state, _ := tt.session.LoadSession(r)
|
||||
ctx := r.Context()
|
||||
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||
r = r.WithContext(ctx)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
p.ForceRefresh(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
|
||||
t.Errorf("\n%+v", opts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_Impersonate(t *testing.T) {
|
||||
t.Parallel()
|
||||
opts := testOptions(t)
|
||||
|
@ -171,18 +124,17 @@ func TestProxy_Impersonate(t *testing.T) {
|
|||
email string
|
||||
groups string
|
||||
csrf string
|
||||
cipher cryptutil.SecureEncoder
|
||||
cipher sessions.Encoder
|
||||
sessionStore sessions.SessionStore
|
||||
authorizer clients.Authorizer
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
|
||||
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
|
||||
{"save session failure", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &encoding.MockEncoder{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
|
||||
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
|
||||
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -276,25 +228,24 @@ func TestProxy_VerifyWithMiddleware(t *testing.T) {
|
|||
path string
|
||||
verifyURI string
|
||||
|
||||
cipher cryptutil.SecureEncoder
|
||||
cipher sessions.Encoder
|
||||
sessionStore sessions.SessionStore
|
||||
authorizer clients.Authorizer
|
||||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{"good", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
|
||||
{"good verify only", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
|
||||
{"bad naked domain uri given", opts, nil, http.MethodGet, "", "/", "a.naked.domain", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri given\"}\n"},
|
||||
{"bad naked domain uri given verify only", opts, nil, http.MethodGet, "", "/verify", "a.naked.domain", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri given\"}\n"},
|
||||
{"bad empty verification uri given", opts, nil, http.MethodGet, "", "/", " ", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri given\"}\n"},
|
||||
{"bad empty verification uri given verify only", opts, nil, http.MethodGet, "", "/verify", " ", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri given\"}\n"},
|
||||
{"good post auth redirect", opts, nil, http.MethodGet, callbackQueryParam, "/", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, "<a href=\"https://some.domain.example\">Found</a>.\n\n"},
|
||||
{"not authorized", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"Unauthorized\"}\n"},
|
||||
{"not authorized verify endpoint", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"Unauthorized\"}\n"},
|
||||
{"not authorized expired, redirect to auth", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
|
||||
{"not authorized expired, don't redirect!", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"Unauthorized\"}\n"},
|
||||
{"not authorized because of error", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusUnauthorized, "{\"error\":\"Unauthorized\"}\n"},
|
||||
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"Unauthorized\"}\n"},
|
||||
{"good", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
|
||||
{"good verify only", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
|
||||
{"bad naked domain uri", opts, nil, http.MethodGet, "", "/", "a.naked.domain", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
|
||||
{"bad naked domain uri verify only", opts, nil, http.MethodGet, "", "/verify", "a.naked.domain", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
|
||||
{"bad empty verification uri", opts, nil, http.MethodGet, "", "/", " ", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
|
||||
{"bad empty verification uri verify only", opts, nil, http.MethodGet, "", "/verify", " ", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
|
||||
{"not authorized", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"},
|
||||
{"not authorized verify endpoint", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"},
|
||||
{"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, "", "/", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
|
||||
{"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, "", "/verify", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"},
|
||||
{"not authorized because of error", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"error\":\"authz error\"}\n"},
|
||||
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -346,3 +297,133 @@ func TestProxy_VerifyWithMiddleware(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
func TestProxy_Callback(t *testing.T) {
|
||||
t.Parallel()
|
||||
opts := testOptions(t)
|
||||
tests := []struct {
|
||||
name string
|
||||
options config.Options
|
||||
|
||||
method string
|
||||
|
||||
scheme string
|
||||
host string
|
||||
path string
|
||||
|
||||
qp map[string]string
|
||||
|
||||
cipher sessions.Encoder
|
||||
sessionStore sessions.SessionStore
|
||||
authorizer clients.Authorizer
|
||||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{"good", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_programmatic_destination_url": "ok", "pomerium_jwt": "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &encoding.MockEncoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||
{"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_jwt": "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &encoding.MockEncoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"bad save session", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_jwt": "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &encoding.MockEncoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusInternalServerError, ""},
|
||||
{"bad base64", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_jwt": "^"}, &encoding.MockEncoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, err := New(tt.options)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.encoder = tt.cipher
|
||||
p.sessionStore = tt.sessionStore
|
||||
p.AuthorizeClient = tt.authorizer
|
||||
p.UpdateOptions(tt.options)
|
||||
redirectURI := &url.URL{Scheme: tt.scheme, Host: tt.host, Path: tt.path}
|
||||
queryString := redirectURI.Query()
|
||||
for k, v := range tt.qp {
|
||||
queryString.Set(k, v)
|
||||
}
|
||||
redirectURI.RawQuery = queryString.Encode()
|
||||
|
||||
uri := &url.URL{Path: "/"}
|
||||
if tt.qp != nil {
|
||||
qu := uri.Query()
|
||||
qu.Set("redirect_uri", redirectURI.String())
|
||||
uri.RawQuery = qu.Encode()
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(tt.method, uri.String(), nil)
|
||||
|
||||
r.Header.Set("Accept", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
p.Callback(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
|
||||
t.Errorf("\n%+v", w.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantBody != "" {
|
||||
body := w.Body.String()
|
||||
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
|
||||
t.Errorf("wrong body\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_ProgrammaticLogin(t *testing.T) {
|
||||
t.Parallel()
|
||||
opts := testOptions(t)
|
||||
tests := []struct {
|
||||
name string
|
||||
options config.Options
|
||||
|
||||
method string
|
||||
|
||||
scheme string
|
||||
host string
|
||||
path string
|
||||
qp map[string]string
|
||||
|
||||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{"good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"redirect_uri": "http://localhost"}, http.StatusOK, ""},
|
||||
{"router miss, bad redirect_uri query", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"bad_redirect_uri": "http://localhost"}, http.StatusNotFound, ""},
|
||||
{"bad redirect_uri missing scheme", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"redirect_uri": "localhost"}, http.StatusBadRequest, "{\"error\":\"malformed redirect_uri\"}\n"},
|
||||
{"bad http method", opts, http.MethodPost, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"redirect_uri": "http://localhost"}, http.StatusMethodNotAllowed, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, err := New(tt.options)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
redirectURI := &url.URL{Scheme: tt.scheme, Host: tt.host, Path: tt.path}
|
||||
queryString := redirectURI.Query()
|
||||
for k, v := range tt.qp {
|
||||
queryString.Set(k, v)
|
||||
}
|
||||
redirectURI.RawQuery = queryString.Encode()
|
||||
|
||||
r := httptest.NewRequest(tt.method, redirectURI.String(), nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router := httputil.NewRouter()
|
||||
router = p.registerDashboardHandlers(router)
|
||||
router.ServeHTTP(w, r)
|
||||
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
|
||||
t.Errorf("\n%+v", w.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantBody != "" {
|
||||
body := w.Body.String()
|
||||
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
|
||||
t.Errorf("wrong body\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,9 +3,8 @@ package proxy // import "github.com/pomerium/pomerium/proxy"
|
|||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
|
@ -30,34 +29,35 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
|
|||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession")
|
||||
defer span.End()
|
||||
if err := p.authenticate(w, r); err != nil {
|
||||
if err := p.authenticate(false, w, r.WithContext(ctx)); err != nil {
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
log.FromRequest(r).Debug().Err(err).Msg("proxy: authenticate session")
|
||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
|
||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request) error {
|
||||
// authenticate authenticates a user and sets an appropriate response type,
|
||||
// redirect to authenticate or error handler depending on if err on failure is set.
|
||||
func (p *Proxy) authenticate(errOnFailure bool, w http.ResponseWriter, r *http.Request) error {
|
||||
s, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
if errOnFailure || (s != nil && s.Programmatic) {
|
||||
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
|
||||
return err
|
||||
}
|
||||
if s == nil {
|
||||
return fmt.Errorf("empty session state")
|
||||
}
|
||||
if err := s.Valid(); err != nil {
|
||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
|
||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
return err
|
||||
}
|
||||
// add pomerium's headers to the downstream request
|
||||
r.Header.Set(HeaderUserID, s.User)
|
||||
r.Header.Set(HeaderUserID, s.Subject)
|
||||
r.Header.Set(HeaderEmail, s.RequestEmail())
|
||||
r.Header.Set(HeaderGroups, s.RequestGroups())
|
||||
// and upstream
|
||||
w.Header().Set(HeaderUserID, s.User)
|
||||
w.Header().Set(HeaderUserID, s.Subject)
|
||||
w.Header().Set(HeaderEmail, s.RequestEmail())
|
||||
w.Header().Set(HeaderGroups, s.RequestGroups())
|
||||
return nil
|
||||
|
@ -69,27 +69,35 @@ func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
|
|||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthorizeSession")
|
||||
defer span.End()
|
||||
s, err := sessions.FromContext(r.Context())
|
||||
if err != nil || s == nil {
|
||||
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err))
|
||||
return
|
||||
}
|
||||
authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r.WithContext(ctx), err)
|
||||
return
|
||||
} else if !authorized {
|
||||
errMsg := fmt.Sprintf("%s is not authorized for this route", s.RequestEmail())
|
||||
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error(errMsg, http.StatusForbidden, nil))
|
||||
if err := p.authorize(r.Host, w, r.WithContext(ctx)); err != nil {
|
||||
log.FromRequest(r).Debug().Err(err).Msg("proxy: AuthorizeSession")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Proxy) authorize(host string, w http.ResponseWriter, r *http.Request) error {
|
||||
s, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusUnauthorized, err))
|
||||
return err
|
||||
}
|
||||
authorized, err := p.AuthorizeClient.Authorize(r.Context(), host, s)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return err
|
||||
} else if !authorized {
|
||||
err = fmt.Errorf("%s is not authorized for %s", s.RequestEmail(), host)
|
||||
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SignRequest is middleware that signs a JWT that contains a user's id,
|
||||
// email, and group. Session state is retrieved from the users's request context
|
||||
func (p *Proxy) SignRequest(signer cryptutil.JWTSigner) func(next http.Handler) http.Handler {
|
||||
func (p *Proxy) SignRequest(signer sessions.Marshaler) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "proxy.SignRequest")
|
||||
|
@ -99,12 +107,13 @@ func (p *Proxy) SignRequest(signer cryptutil.JWTSigner) func(next http.Handler)
|
|||
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err))
|
||||
return
|
||||
}
|
||||
jwt, err := signer.SignJWT(s.User, s.Email, strings.Join(s.Groups, ","))
|
||||
newSession := s.NewSession(r.Host, []string{r.Host})
|
||||
jwt, err := signer.Marshal(newSession.RouteSession(time.Minute))
|
||||
if err != nil {
|
||||
log.FromRequest(r).Warn().Err(err).Msg("proxy: failed signing jwt")
|
||||
log.FromRequest(r).Error().Err(err).Msg("proxy: failed signing jwt")
|
||||
} else {
|
||||
r.Header.Set(HeaderJWT, jwt)
|
||||
w.Header().Set(HeaderJWT, jwt)
|
||||
r.Header.Set(HeaderJWT, string(jwt))
|
||||
w.Header().Set(HeaderJWT, string(jwt))
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -14,6 +13,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/proxy/clients"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func TestProxy_AuthenticateSession(t *testing.T) {
|
||||
|
@ -27,15 +27,18 @@ func TestProxy_AuthenticateSession(t *testing.T) {
|
|||
|
||||
tests := []struct {
|
||||
name string
|
||||
errOnFailure bool
|
||||
session sessions.SessionStore
|
||||
ctxError error
|
||||
provider identity.Authenticator
|
||||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK},
|
||||
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
||||
{"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusFound},
|
||||
{"good", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, nil, identity.MockProvider{}, http.StatusOK},
|
||||
{"invalid session", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
||||
{"expired", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusFound},
|
||||
{"expired and programmatic", false, &sessions.MockSessionStore{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusUnauthorized},
|
||||
{"invalid session and programmatic", false, &sessions.MockSessionStore{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -81,10 +84,10 @@ func TestProxy_AuthorizeSession(t *testing.T) {
|
|||
|
||||
wantStatus int
|
||||
}{
|
||||
{"user is authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK},
|
||||
{"user is not authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusForbidden},
|
||||
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusForbidden},
|
||||
{"authz client error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError},
|
||||
{"user is authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK},
|
||||
{"user is not authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusUnauthorized},
|
||||
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized},
|
||||
{"authz client error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -119,8 +122,9 @@ type mockJWTSigner struct {
|
|||
|
||||
// Sign implements the JWTSigner interface from the cryptutil package, but just
|
||||
// base64's the inputs instead for stesting.
|
||||
func (s *mockJWTSigner) SignJWT(user, email, groups string) (string, error) {
|
||||
return base64.StdEncoding.EncodeToString([]byte(fmt.Sprint(user, email, groups))), s.SignError
|
||||
func (s *mockJWTSigner) Marshal(v interface{}) ([]byte, error) {
|
||||
|
||||
return []byte("ok"), s.SignError
|
||||
}
|
||||
|
||||
func TestProxy_SignRequest(t *testing.T) {
|
||||
|
@ -142,7 +146,7 @@ func TestProxy_SignRequest(t *testing.T) {
|
|||
wantStatus int
|
||||
wantHeaders string
|
||||
}{
|
||||
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "dGVzdA=="},
|
||||
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "ok"},
|
||||
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, errors.New("err"), http.StatusForbidden, ""},
|
||||
{"signature failure, warn but ok", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, errors.New("err"), nil, http.StatusOK, ""},
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
@ -10,8 +11,10 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/config"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
|
@ -30,8 +33,6 @@ const (
|
|||
signinURL = "/.pomerium/sign_in"
|
||||
// signoutURL is the path to authenticate's sign out endpoint
|
||||
signoutURL = "/.pomerium/sign_out"
|
||||
|
||||
callbackQueryParam = "pomerium-auth-callback"
|
||||
)
|
||||
|
||||
// ValidateOptions checks that proper configuration settings are set to create
|
||||
|
@ -54,7 +55,7 @@ func ValidateOptions(o config.Options) error {
|
|||
}
|
||||
|
||||
if len(o.SigningKey) != 0 {
|
||||
if _, err := cryptutil.NewES256Signer(o.SigningKey, ""); err != nil {
|
||||
if _, err := jws.NewES256Signer(o.SigningKey, ""); err != nil {
|
||||
return fmt.Errorf("proxy: invalid 'SIGNING_KEY': %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -65,6 +66,8 @@ func ValidateOptions(o config.Options) error {
|
|||
type Proxy struct {
|
||||
// SharedKey used to mutually authenticate service communication
|
||||
SharedKey string
|
||||
sharedCipher cipher.AEAD
|
||||
|
||||
authenticateURL *url.URL
|
||||
authenticateSigninURL *url.URL
|
||||
authenticateSignoutURL *url.URL
|
||||
|
@ -72,9 +75,8 @@ type Proxy struct {
|
|||
|
||||
AuthorizeClient clients.Authorizer
|
||||
|
||||
encoder cryptutil.SecureEncoder
|
||||
cookieName string
|
||||
cookieDomain string
|
||||
encoder sessions.Encoder
|
||||
cookieOptions *sessions.CookieOptions
|
||||
cookieSecret []byte
|
||||
defaultUpstreamTimeout time.Duration
|
||||
refreshCooldown time.Duration
|
||||
|
@ -92,50 +94,48 @@ func New(opts config.Options) (*Proxy, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// errors checked in ValidateOptions
|
||||
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(opts.SharedKey)
|
||||
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||
cipher, _ := cryptutil.NewAEADCipherFromBase64(opts.CookieSecret)
|
||||
|
||||
encoder := cryptutil.NewSecureJSONEncoder(cipher)
|
||||
|
||||
if opts.CookieDomain == "" {
|
||||
opts.CookieDomain = sessions.ParentSubdomain(opts.AuthenticateURL.String())
|
||||
}
|
||||
|
||||
cookieStore, err := sessions.NewCookieStore(
|
||||
&sessions.CookieStoreOptions{
|
||||
Name: opts.CookieName,
|
||||
CookieDomain: opts.CookieDomain,
|
||||
CookieSecure: opts.CookieSecure,
|
||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||
CookieExpire: opts.CookieExpire,
|
||||
Encoder: encoder,
|
||||
})
|
||||
|
||||
// used to load and verify JWT tokens signed by the authenticate service
|
||||
encoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.AuthenticateURL.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cookieOptions := &sessions.CookieOptions{
|
||||
Name: opts.CookieName,
|
||||
Domain: opts.CookieDomain,
|
||||
Secure: opts.CookieSecure,
|
||||
HTTPOnly: opts.CookieHTTPOnly,
|
||||
Expire: opts.CookieExpire,
|
||||
}
|
||||
|
||||
cookieStore, err := sessions.NewCookieLoader(cookieOptions, encoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := &Proxy{
|
||||
SharedKey: opts.SharedKey,
|
||||
|
||||
sharedCipher: sharedCipher,
|
||||
encoder: encoder,
|
||||
|
||||
cookieSecret: decodedCookieSecret,
|
||||
cookieDomain: opts.CookieDomain,
|
||||
cookieName: opts.CookieName,
|
||||
cookieOptions: cookieOptions,
|
||||
defaultUpstreamTimeout: opts.DefaultUpstreamTimeout,
|
||||
refreshCooldown: opts.RefreshCooldown,
|
||||
sessionStore: cookieStore,
|
||||
sessionLoaders: []sessions.SessionLoader{
|
||||
cookieStore,
|
||||
sessions.NewHeaderStore(encoder),
|
||||
sessions.NewQueryParamStore(encoder)},
|
||||
sessions.NewHeaderStore(encoder, "Pomerium"),
|
||||
sessions.NewQueryParamStore(encoder, "pomerium_session")},
|
||||
signingKey: opts.SigningKey,
|
||||
templates: templates.New(),
|
||||
}
|
||||
// errors checked in ValidateOptions
|
||||
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
|
||||
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
|
||||
|
||||
p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
|
||||
p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL})
|
||||
|
||||
|
@ -238,14 +238,14 @@ func (p *Proxy) reverseProxyHandler(r *mux.Router, policy *config.Policy) (*mux.
|
|||
// 4. Retrieve the user session and add it to the request context
|
||||
rp.Use(sessions.RetrieveSession(p.sessionLoaders...))
|
||||
// 5. Strip the user session cookie from the downstream request
|
||||
rp.Use(middleware.StripCookie(p.cookieName))
|
||||
rp.Use(middleware.StripCookie(p.cookieOptions.Name))
|
||||
// 6. AuthN - Verify the user is authenticated. Set email, group, & id headers
|
||||
rp.Use(p.AuthenticateSession)
|
||||
// 7. AuthZ - Verify the user is authorized for route
|
||||
rp.Use(p.AuthorizeSession)
|
||||
// Optional: Add a signed JWT attesting to the user's id, email, and group
|
||||
if len(p.signingKey) != 0 {
|
||||
signer, err := cryptutil.NewES256Signer(p.signingKey, policy.Source.Host)
|
||||
signer, err := jws.NewES256Signer(p.signingKey, policy.Destination.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -172,7 +172,10 @@ func Test_UpdateOptions(t *testing.T) {
|
|||
corsPreflight.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", CORSAllowPreflight: true}}
|
||||
disableAuth := testOptions(t)
|
||||
disableAuth.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", AllowPublicUnauthenticatedAccess: true}}
|
||||
|
||||
fwdAuth := testOptions(t)
|
||||
fwdAuth.ForwardAuthURL = &url.URL{Scheme: "https", Host: "corp.example.example"}
|
||||
reqHeaders := testOptions(t)
|
||||
reqHeaders.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", SetRequestHeaders: map[string]string{"x": "y"}}}
|
||||
tests := []struct {
|
||||
name string
|
||||
originalOptions config.Options
|
||||
|
@ -198,6 +201,8 @@ func Test_UpdateOptions(t *testing.T) {
|
|||
{"no websockets, custom timeout", good, customTimeout, "", "https://corp.example.example", false, true},
|
||||
{"enable cors preflight", good, corsPreflight, "", "https://corp.example.example", false, true},
|
||||
{"disable auth", good, disableAuth, "", "https://corp.example.example", false, true},
|
||||
{"enable forward auth", good, fwdAuth, "", "https://corp.example.example", false, true},
|
||||
{"set request headers", good, reqHeaders, "", "https://corp.example.example", false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -1,79 +1,136 @@
|
|||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import http.server
|
||||
import json
|
||||
import sys
|
||||
|
||||
import urllib.parse
|
||||
import webbrowser
|
||||
from urllib.parse import urlparse
|
||||
import requests
|
||||
|
||||
done = False
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--openid-configuration',
|
||||
default="https://accounts.google.com/.well-known/openid-configuration")
|
||||
parser.add_argument('--client-id')
|
||||
parser.add_argument('--client-secret')
|
||||
parser.add_argument('--pomerium-client-id')
|
||||
parser.add_argument('--code')
|
||||
parser.add_argument('--pomerium-token-url',
|
||||
default="https://authenticate.corp.beyondperimeter.com/api/v1/token")
|
||||
parser.add_argument('--pomerium-token')
|
||||
parser.add_argument('--pomerium-url', default="https://httpbin.corp.beyondperimeter.com/get")
|
||||
parser.add_argument("--login", action="store_true")
|
||||
parser.add_argument(
|
||||
"--dst", default="https://httpbin.imac.bdd.io/headers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--refresh-endpoint", default="https://authenticate.imac.bdd.io/api/v1/refresh",
|
||||
)
|
||||
parser.add_argument("--server", default="localhost", type=str)
|
||||
parser.add_argument("--port", default=8000, type=int)
|
||||
parser.add_argument(
|
||||
"--cred", default="pomerium-cred.json",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
class PomeriumSession:
|
||||
def __init__(self, jwt, refresh_token):
|
||||
self.jwt = jwt
|
||||
self.refresh_token = refresh_token
|
||||
|
||||
def to_json(self):
|
||||
return json.dumps(self.__dict__, indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, fn):
|
||||
with open(fn) as f:
|
||||
data = json.load(f)
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class Callback(http.server.BaseHTTPRequestHandler):
|
||||
def log_message(self, format, *args):
|
||||
# silence http server logs for now
|
||||
return
|
||||
|
||||
def do_GET(self):
|
||||
global args
|
||||
global done
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
response = b"OK"
|
||||
if "pomerium" in self.path:
|
||||
path = urllib.parse.urlparse(self.path).query
|
||||
path_qp = urllib.parse.parse_qs(path)
|
||||
session = PomeriumSession(
|
||||
path_qp.get("pomerium_jwt")[0],
|
||||
path_qp.get("pomerium_refresh_token")[0],
|
||||
)
|
||||
done = True
|
||||
response = session.to_json().encode()
|
||||
with open(args.cred, "w", encoding="utf-8") as f:
|
||||
f.write(session.to_json())
|
||||
print("=> pomerium json credential saved to:\n{}".format(f.name))
|
||||
|
||||
self.wfile.write(response)
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
code = args.code
|
||||
pomerium_token = args.pomerium_token
|
||||
oidc_document = requests.get(args.openid_configuration).json()
|
||||
token_url = oidc_document['token_endpoint']
|
||||
print(token_url)
|
||||
sign_in_url = oidc_document['authorization_endpoint']
|
||||
global args
|
||||
|
||||
if not code and not pomerium_token:
|
||||
if not args.client_id:
|
||||
print("client-id is required")
|
||||
sys.exit(1)
|
||||
dst = urllib.parse.urlparse(args.dst)
|
||||
try:
|
||||
cred = PomeriumSession.from_json_file(args.cred)
|
||||
except:
|
||||
print("=> no credential found, let's login")
|
||||
args.login = True
|
||||
|
||||
sign_in_url = "{}?response_type=code&scope=openid%20email&access_type=offline&redirect_uri=urn:ietf:wg:oauth:2.0:oob&client_id={}".format(
|
||||
oidc_document['authorization_endpoint'], args.client_id)
|
||||
print("Access code not set, so we'll do the process interactively!")
|
||||
print("Go to the url : {}".format(sign_in_url))
|
||||
code = input("Complete the login and enter your code:")
|
||||
print(code)
|
||||
# initial login to make sure we have our credential
|
||||
if args.login:
|
||||
dst = urllib.parse.urlparse(args.dst)
|
||||
query_params = {"redirect_uri": "http://{}:{}".format(args.server, args.port)}
|
||||
enc_query_params = urllib.parse.urlencode(query_params)
|
||||
dst_login = "{}://{}{}?{}".format(
|
||||
dst.scheme, dst.hostname, "/.pomerium/api/v1/login", enc_query_params,
|
||||
)
|
||||
response = requests.get(dst_login)
|
||||
print("=> Your browser has been opened to visit:\n{}".format(response.text))
|
||||
webbrowser.open(response.text)
|
||||
|
||||
if not pomerium_token:
|
||||
req = requests.post(
|
||||
token_url, {
|
||||
'client_id': args.client_id,
|
||||
'client_secret': args.client_secret,
|
||||
'code': code,
|
||||
'redirect_uri': 'urn:ietf:wg:oauth:2.0:oob',
|
||||
'grant_type': 'authorization_code'
|
||||
})
|
||||
with http.server.HTTPServer((args.server, args.port), Callback) as httpd:
|
||||
while not done:
|
||||
httpd.handle_request()
|
||||
|
||||
refresh_token = req.json()['refresh_token']
|
||||
print("refresh token: {}".format(refresh_token))
|
||||
|
||||
print("create a new id_token with our pomerium app as the audience")
|
||||
req = requests.post(
|
||||
token_url, {
|
||||
'refresh_token': refresh_token,
|
||||
'client_id': args.client_id,
|
||||
'client_secret': args.client_secret,
|
||||
'audience': args.pomerium_client_id,
|
||||
'grant_type': 'refresh_token'
|
||||
})
|
||||
id_token = req.json()['id_token']
|
||||
print("pomerium id_token: {}".format(id_token))
|
||||
|
||||
print("exchange our identity providers id token for a pomerium bearer token")
|
||||
req = requests.post(args.pomerium_token_url, {'id_token': id_token})
|
||||
pomerium_token = req.json()['Token']
|
||||
print("pomerium bearer token is: {}".format(pomerium_token))
|
||||
|
||||
req = requests.get(args.pomerium_url, headers={'Authorization': 'Bearer ' + pomerium_token})
|
||||
json_formatted = json.dumps(req.json(), indent=1)
|
||||
print(json_formatted)
|
||||
cred = PomeriumSession.from_json_file(args.cred)
|
||||
response = requests.get(
|
||||
args.dst,
|
||||
headers={
|
||||
"Authorization": "Pomerium {}".format(cred.jwt),
|
||||
"Content-type": "application/json",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
print(
|
||||
"==> request\n{}\n==> response.status_code\n{}\n==>response.text\n{}\n".format(
|
||||
args.dst, response.status_code, response.text
|
||||
)
|
||||
)
|
||||
# if response.status_code == 200:
|
||||
if response.status_code == 401:
|
||||
# user our refresh token to get a new cred
|
||||
print("==> got a 401, let's try to refresh that credential")
|
||||
response = requests.get(
|
||||
args.refresh_endpoint,
|
||||
headers={
|
||||
"Authorization": "Pomerium {}".format(cred.refresh_token),
|
||||
"Content-type": "application/json",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
print(
|
||||
"==>request\n{}\n ==> response.status_code\n{}\nresponse.text==>\n{}\n".format(
|
||||
args.refresh_endpoint, response.status_code, response.text
|
||||
)
|
||||
)
|
||||
# update our cred!
|
||||
with open(args.cred, "w", encoding="utf-8") as f:
|
||||
f.write(response.text)
|
||||
print("=> pomerium json credential saved to:\n{}".format(f.name))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -1,53 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Create a new OAUTH2 provider DISTINCT from your pomerium configuration
|
||||
# Select type as "OTHER"
|
||||
CLIENT_ID='REPLACE-ME.apps.googleusercontent.com'
|
||||
CLIENT_SECRET='REPLACE-ME'
|
||||
SIGNIN_URL='https://accounts.google.com/o/oauth2/v2/auth?client_id='$CLIENT_ID'&response_type=code&scope=openid%20email&access_type=offline&redirect_uri=urn:ietf:wg:oauth:2.0:oob'
|
||||
|
||||
# This would be your pomerium client id
|
||||
POMERIUM_CLIENT_ID='REPLACE-ME.apps.googleusercontent.com'
|
||||
|
||||
echo "Follow the following URL to get an offline auth code from your IdP"
|
||||
echo $SIGNIN_URL
|
||||
|
||||
read -p 'Enter the authorization code as a result of logging in: ' CODE
|
||||
echo $CODE
|
||||
|
||||
echo "Exchange our authorization code to get a refresh_token"
|
||||
echo "refresh_tokens can be used to generate indefinite access tokens / id_tokens"
|
||||
curl \
|
||||
-d client_id=$CLIENT_ID \
|
||||
-d client_secret=$CLIENT_SECRET \
|
||||
-d code=$CODE \
|
||||
-d redirect_uri=urn:ietf:wg:oauth:2.0:oob \
|
||||
-d grant_type=authorization_code \
|
||||
https://www.googleapis.com/oauth2/v4/token
|
||||
|
||||
read -p 'Enter the refresh token result:' REFRESH_TOKEN
|
||||
echo $REFRESH_TOKEN
|
||||
|
||||
echo "Use our refresh_token to create a new id_token with an audience of pomerium's oauth client"
|
||||
curl \
|
||||
-d client_id=$CLIENT_ID \
|
||||
-d client_secret=$CLIENT_SECRET \
|
||||
-d refresh_token=$REFRESH_TOKEN \
|
||||
-d grant_type=refresh_token \
|
||||
-d audience=$POMERIUM_CLIENT_ID \
|
||||
https://www.googleapis.com/oauth2/v4/token
|
||||
|
||||
echo "now we have an id_token with an audience that matches that of our pomerium app"
|
||||
read -p 'Enter the resulting id_token:' ID_TOKEN
|
||||
echo $ID_TOKEN
|
||||
|
||||
curl -X POST \
|
||||
-d id_token=$ID_TOKEN \
|
||||
https://authenticate.corp.beyondperimeter.com/api/v1/token
|
||||
|
||||
read -p 'Enter the resulting Token:' POMERIUM_ACCESS_TOKEN
|
||||
echo $POMERIUM_ACCESS_TOKEN
|
||||
|
||||
echo "we have our bearer token that can be used with pomerium now"
|
||||
curl \
|
||||
-H "Authorization: Bearer ${POMERIUM_ACCESS_TOKEN}" \
|
||||
"https://httpbin.corp.beyondperimeter.com/"
|
Loading…
Add table
Add a link
Reference in a new issue