all: support route scoped sessions

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2019-11-06 17:30:27 -08:00 committed by Bobby DeSimone
parent 83342112bb
commit d3d60d1055
53 changed files with 2092 additions and 2416 deletions

View file

@ -7,9 +7,12 @@ import (
"fmt"
"html/template"
"net/url"
"time"
"github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/templates"
@ -18,6 +21,10 @@ import (
const callbackPath = "/oauth2/callback"
// DefaultSessionDuration is the default time a managed route session is
// valid for.
var DefaultSessionDuration = time.Minute * 10
// ValidateOptions checks that configuration are complete and valid.
// Returns on first error found.
func ValidateOptions(o config.Options) error {
@ -41,18 +48,34 @@ func ValidateOptions(o config.Options) error {
// Authenticate contains data required to run the authenticate service.
type Authenticate struct {
SharedKey string
// RedirectURL is the authenticate service's externally accessible
// url that the identity provider (IdP) will callback to following
// authentication flow
RedirectURL *url.URL
cookieName string
cookieSecure bool
cookieDomain string
// sharedKey is used to encrypt and authenticate data between services
sharedKey string
// sharedCipher is used to encrypt data for use between services
sharedCipher cipher.AEAD
// sharedEncoder is the encoder to use to serialize data to be consumed
// by other services
sharedEncoder sessions.Encoder
// data related to this service only
cookieOptions *sessions.CookieOptions
// cookieSecret is the secret to encrypt and authenticate data for this service
cookieSecret []byte
templates *template.Template
// is the cipher to use to encrypt data for this service
cookieCipher cipher.AEAD
sessionStore sessions.SessionStore
cipher cipher.AEAD
encoder cryptutil.SecureEncoder
encryptedEncoder sessions.Encoder
sessionStores []sessions.SessionStore
sessionLoaders []sessions.SessionLoader
// provider is the interface to interacting with the identity provider (IdP)
provider identity.Authenticator
templates *template.Template
}
// New validates and creates a new authenticate service from a set of Options.
@ -60,29 +83,37 @@ func New(opts config.Options) (*Authenticate, error) {
if err := ValidateOptions(opts); err != nil {
return nil, err
}
// shared state encoder setup
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(opts.SharedKey)
signedEncoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.AuthenticateURL.Host)
if err != nil {
return nil, err
}
// private state encoder setup
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
cipher, err := cryptutil.NewAEADCipher(decodedCookieSecret)
encoder := cryptutil.NewSecureJSONEncoder(cipher)
if err != nil {
return nil, err
}
if opts.CookieDomain == "" {
opts.CookieDomain = sessions.ParentSubdomain(opts.AuthenticateURL.String())
}
cookieStore, err := sessions.NewCookieStore(
&sessions.CookieStoreOptions{
cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret)
encryptedEncoder := ecjson.New(cookieCipher)
cookieOptions := &sessions.CookieOptions{
Name: opts.CookieName,
CookieDomain: opts.CookieDomain,
CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly,
CookieExpire: opts.CookieExpire,
Encoder: encoder,
})
Domain: opts.CookieDomain,
Secure: opts.CookieSecure,
HTTPOnly: opts.CookieHTTPOnly,
Expire: opts.CookieExpire,
}
cookieStore, err := sessions.NewCookieStore(cookieOptions, encryptedEncoder)
if err != nil {
return nil, err
}
qpStore := sessions.NewQueryParamStore(encryptedEncoder, "pomerium_programmatic_token")
headerStore := sessions.NewHeaderStore(encryptedEncoder, "Pomerium")
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
redirectURL.Path = callbackPath
// configure our identity provider
provider, err := identity.New(
opts.Provider,
&identity.Provider{
@ -99,16 +130,22 @@ func New(opts config.Options) (*Authenticate, error) {
}
return &Authenticate{
SharedKey: opts.SharedKey,
RedirectURL: redirectURL,
templates: templates.New(),
sessionStore: cookieStore,
cipher: cipher,
encoder: encoder,
provider: provider,
// shared state
sharedKey: opts.SharedKey,
sharedCipher: sharedCipher,
sharedEncoder: signedEncoder,
// private state
cookieSecret: decodedCookieSecret,
cookieName: opts.CookieName,
cookieDomain: opts.CookieDomain,
cookieSecure: opts.CookieSecure,
cookieCipher: cookieCipher,
cookieOptions: cookieOptions,
sessionStore: cookieStore,
encryptedEncoder: encryptedEncoder,
sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore},
sessionStores: []sessions.SessionStore{cookieStore, qpStore},
// IdP
provider: provider,
templates: templates.New(),
}, nil
}

View file

@ -31,35 +31,37 @@ var CSPHeaders = map[string]string{
"Referrer-Policy": "Same-origin",
}
// Handler returns the authenticate service's HTTP multiplexer, and routes.
// Handler returns the authenticate service's handler chain.
func (a *Authenticate) Handler() http.Handler {
r := httputil.NewRouter()
r.Use(middleware.SetHeaders(CSPHeaders))
r.Use(csrf.Protect(
a.cookieSecret,
csrf.Secure(a.cookieSecure),
csrf.Secure(a.cookieOptions.Secure),
csrf.Path("/"),
csrf.Domain(a.cookieDomain),
csrf.UnsafePaths([]string{callbackPath}), // enforce CSRF on "safe" handler
csrf.FormValueName("state"), // rfc6749 section-10.12
csrf.CookieName(fmt.Sprintf("%s_csrf", a.cookieName)),
csrf.CookieName(fmt.Sprintf("%s_csrf", a.cookieOptions.Name)),
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
))
r.HandleFunc("/robots.txt", a.RobotsTxt).Methods(http.MethodGet)
// Identity Provider (IdP) endpoints
r.HandleFunc("/oauth2/callback", a.OAuthCallback).Methods(http.MethodGet)
r.HandleFunc("/api/v1/token", a.ExchangeToken)
// Proxy service endpoints
v := r.PathPrefix("/.pomerium").Subrouter()
v.Use(middleware.ValidateSignature(a.SharedKey))
v.Use(middleware.ValidateRedirectURI(a.RedirectURL))
v.Use(sessions.RetrieveSession(a.sessionStore))
v.Use(middleware.ValidateSignature(a.sharedKey))
v.Use(sessions.RetrieveSession(a.sessionLoaders...))
v.Use(a.VerifySession)
v.HandleFunc("/sign_in", a.SignIn)
v.HandleFunc("/sign_out", a.SignOut)
// programmatic access api endpoint
api := r.PathPrefix("/api").Subrouter()
api.Use(sessions.RetrieveSession(a.sessionLoaders...))
api.HandleFunc("/v1/refresh", a.RefreshAPI)
return r
}
@ -71,14 +73,14 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
if errors.Is(err, sessions.ErrExpired) {
if err := a.refresh(w, r, state); err != nil {
log.FromRequest(r).Debug().Str("cause", err.Error()).Msg("authenticate: couldn't refresh session")
a.sessionStore.ClearSession(w, r)
a.redirectToIdentityProvider(w, r)
return
}
// redirect to restart middleware-chain following refresh
http.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound)
return
} else if err != nil {
log.FromRequest(r).Err(err).Msg("authenticate: unexpected session state")
a.sessionStore.ClearSession(w, r)
log.FromRequest(r).Err(err).Msg("authenticate: malformed session")
a.redirectToIdentityProvider(w, r)
return
}
@ -95,7 +97,6 @@ func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessio
return fmt.Errorf("authenticate: refresh save failed: %w", err)
}
return nil
}
// RobotsTxt handles the /robots.txt route.
@ -108,19 +109,64 @@ func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
// SignIn handles to authenticating a user.
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
// grab and parse our redirect_uri
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
return
}
// Add query param to let downstream apps (or auth endpoints) know
// this request followed authentication. Useful for auth-forward-endpoint
// redirecting
// create a clone of the redirect URI, unless this is a programmatic request
// in which case we will redirect back to proxy's callback endpoint
callbackURL, _ := urlutil.DeepCopy(redirectURL)
callbackURL.Path = "/.pomerium/callback"
q := redirectURL.Query()
q.Add("pomerium-auth-callback", "true")
if q.Get("pomerium_programmatic_destination_url") != "" {
callbackURL, err = urlutil.ParseAndValidateURL(q.Get("pomerium_programmatic_destination_url"))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
return
}
}
s, err := sessions.FromContext(r.Context())
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
return
}
s.SetImpersonation(q.Get("impersonate_email"), q.Get("impersonate_group"))
newSession := s.NewSession(a.RedirectURL.Host, []string{a.RedirectURL.Host, callbackURL.Host})
if q.Get("pomerium_programmatic_destination_url") != "" {
newSession.Programmatic = true
encSession, err := a.encryptedEncoder.Marshal(newSession)
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
return
}
q.Set("pomerium_refresh_token", string(encSession))
}
// sign the route session, as a JWT
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession(DefaultSessionDuration))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
return
}
// encrypt our route-based token JWT avoiding any accidental logging
encryptedJWT := cryptutil.Encrypt(a.sharedCipher, signedJWT, nil)
// base64 our encrypted payload for URL-friendlyness
encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT)
// add our encoded and encrypted route-session JWT to a query param
q.Set("pomerium_jwt", encodedJWT)
redirectURL.RawQuery = q.Encode()
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
// build our hmac-d redirect URL with our session, pointing back to the
// proxy's callback URL which is responsible for setting our new route-session
uri := urlutil.SignedRedirectURL(a.sharedKey, callbackURL, redirectURL)
http.Redirect(w, r, uri.String(), http.StatusFound)
}
// SignOut signs the user out and attempts to revoke the user's identity session
@ -132,7 +178,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
return
}
a.sessionStore.ClearSession(w, r)
err = a.provider.Revoke(session.AccessToken)
err = a.provider.Revoke(r.Context(), session.AccessToken)
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("could not revoke user session", http.StatusBadRequest, err))
return
@ -152,11 +198,12 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest
// https://tools.ietf.org/html/rfc6749#section-4.2.1
func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http.Request) {
a.sessionStore.ClearSession(w, r)
redirectURL := a.RedirectURL.ResolveReference(r.URL)
nonce := csrf.Token(r)
now := time.Now().Unix()
b := []byte(fmt.Sprintf("%s|%d|", nonce, now))
enc := cryptutil.Encrypt(a.cipher, []byte(redirectURL.String()), b)
enc := cryptutil.Encrypt(a.cookieCipher, []byte(redirectURL.String()), b)
b = append(b, enc...)
encodedState := base64.URLEncoding.EncodeToString(b)
http.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound)
@ -201,7 +248,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
return nil, httputil.Error("malformed state", http.StatusBadRequest, err)
}
// split state into its it's components, e.g.
// split state into concat'd components
// (nonce|timestamp|redirect_url|encrypted_data(redirect_url)+mac(nonce,ts))
statePayload := strings.SplitN(string(bytes), "|", 3)
if len(statePayload) != 3 {
@ -209,16 +256,16 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
fmt.Errorf("state malformed, size: %d", len(statePayload)))
}
// verify that the returned timestamp is valid (replay attack)
// verify that the returned timestamp is valid
if err := cryptutil.ValidTimestamp(statePayload[1]); err != nil {
return nil, httputil.Error(err.Error(), http.StatusBadRequest, err)
}
// Use our AEAD construct to enforce secrecy and authenticity,:
// Use our AEAD construct to enforce secrecy and authenticity:
// mac: to validate the nonce again, and above timestamp
// decrypt: to prevent leaking 'redirect_uri' to IdP or logs)
// decrypt: to prevent leaking 'redirect_uri' to IdP or logs
b := []byte(fmt.Sprint(statePayload[0], "|", statePayload[1], "|"))
redirectString, err := cryptutil.Decrypt(a.cipher, []byte(statePayload[2]), b)
redirectString, err := cryptutil.Decrypt(a.cookieCipher, []byte(statePayload[2]), b)
if err != nil {
return nil, httputil.Error("'state' has invalid hmac", http.StatusBadRequest, err)
}
@ -235,38 +282,45 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
return redirectURL, nil
}
// ExchangeToken takes an identity provider issued JWT as input ('id_token)
// and exchanges that token for a pomerium session. The provided token's
// audience ('aud') attribute must match Pomerium's client_id.
func (a *Authenticate) ExchangeToken(w http.ResponseWriter, r *http.Request) {
code := r.FormValue("id_token")
if code == "" {
httputil.ErrorResponse(w, r, httputil.Error("missing id token", http.StatusBadRequest, nil))
// RefreshAPI loads a global state, and attempts to refresh the session's access
// tokens and state with the identity provider. If successful, a new signed JWT
// and refresh token (`refresh_token`) are returned as JSON
func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) {
s, err := sessions.FromContext(r.Context())
if err != nil && !errors.Is(err, sessions.ErrExpired) {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
return
}
session, err := a.provider.IDTokenToSession(r.Context(), code)
newSession, err := a.provider.Refresh(r.Context(), s)
if err != nil {
httputil.ErrorResponse(w, r, err)
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusInternalServerError, err))
return
}
encToken, err := sessions.MarshalSession(session, a.encoder)
newSession = newSession.NewSession(s.Issuer, s.Audience)
encSession, err := a.encryptedEncoder.Marshal(newSession)
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusInternalServerError, err))
return
}
restSession := struct {
Token string
Expiry time.Time `json:",omitempty"`
}{
Token: encToken,
Expiry: session.RefreshDeadline,
}
jsonBytes, err := json.Marshal(restSession)
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession(DefaultSessionDuration))
if err != nil {
httputil.ErrorResponse(w, r, err)
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusInternalServerError, err))
return
}
var response struct {
JWT string `json:"jwt"`
RefreshToken string `json:"refresh_token"`
}
response.RefreshToken = string(encSession)
response.JWT = string(signedJWT)
jsonResponse, err := json.Marshal(&response)
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(jsonBytes)
w.Write(jsonResponse)
}

View file

@ -7,22 +7,27 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/templates"
"github.com/google/go-cmp/cmp"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt"
)
func testAuthenticate() *Authenticate {
var auth Authenticate
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
auth.SharedKey = "IzY7MOZwzfOkmELXgozHDKTxoT3nOYhwkcmUVINsRww="
auth.cookieSecret = []byte(auth.SharedKey)
auth.sharedKey = cryptutil.NewBase64Key()
auth.cookieSecret = cryptutil.NewKey()
auth.cookieOptions = &sessions.CookieOptions{Name: "name"}
auth.templates = templates.New()
return &auth
}
@ -67,29 +72,30 @@ func TestAuthenticate_Handler(t *testing.T) {
func TestAuthenticate_SignIn(t *testing.T) {
t.Parallel()
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
state string
redirectURI string
scheme string
host string
qp map[string]string
session sessions.SessionStore
restStore sessions.SessionStore
provider identity.MockProvider
encoder cryptutil.SecureEncoder
encoder sessions.Encoder
wantCode int
}{
{"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusFound},
{"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockEncoder{}, http.StatusFound},
{"session expired good refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockEncoder{}, http.StatusFound},
{"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockEncoder{}, http.StatusFound}, // mocking hmac is meh
{"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockEncoder{}, http.StatusFound},
// {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusInternalServerError},
{"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusFound},
{"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusFound},
{"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusBadRequest},
// actually caught by go's handler, but we should keep the test.
{"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusBadRequest},
{"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{MarshalError: errors.New("error")}, http.StatusFound},
{"good", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &encoding.MockEncoder{}, http.StatusFound},
{"session not valid", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &encoding.MockEncoder{}, http.StatusFound},
{"bad redirect uri query", "", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &encoding.MockEncoder{}, http.StatusBadRequest},
{"bad marshal", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &encoding.MockEncoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"session error", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{LoadError: errors.New("error")}, identity.MockProvider{}, &encoding.MockEncoder{}, http.StatusBadRequest},
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{"state": "example", "pomerium_programmatic_destination_url": "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &encoding.MockEncoder{}, http.StatusFound},
{"encrypted encoder error", "https", "corp.example.example", map[string]string{"state": "example", "pomerium_programmatic_destination_url": "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &encoding.MockEncoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{"state": "example", "pomerium_programmatic_destination_url": "some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &encoding.MockEncoder{}, http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -97,16 +103,28 @@ func TestAuthenticate_SignIn(t *testing.T) {
sessionStore: tt.session,
provider: tt.provider,
RedirectURL: uriParseHelper("https://some.example"),
SharedKey: "secret",
encoder: tt.encoder,
sharedKey: "secret",
sharedEncoder: tt.encoder,
encryptedEncoder: tt.encoder,
sharedCipher: aead,
cookieOptions: &sessions.CookieOptions{
Name: "cookie",
Domain: "foo",
},
}
uri := &url.URL{Host: "corp.some.example", Scheme: "https", Path: "/"}
uri.RawQuery = fmt.Sprintf("%s&redirect_uri=%s", tt.state, tt.redirectURI)
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
uri := &url.URL{Scheme: tt.scheme, Host: tt.host}
queryString := uri.Query()
for k, v := range tt.qp {
queryString.Set(k, v)
}
uri.RawQuery = queryString.Encode()
r := httptest.NewRequest(http.MethodGet, "/?redirect_uri="+uri.String(), nil)
r.Header.Set("Accept", "application/json")
state, _ := tt.session.LoadSession(r)
state, err := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, nil)
ctx = sessions.NewContext(ctx, state, err)
r = r.WithContext(ctx)
w := httptest.NewRecorder()
@ -141,10 +159,10 @@ func TestAuthenticate_SignOut(t *testing.T) {
wantCode int
wantBody string
}{
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusFound, ""},
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, "could not revoke"},
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{LoadError: errors.New("hi"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, ""},
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, http.StatusBadRequest, ""},
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""},
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"error\":\"could not revoke user session\"}\n"},
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"error\":\"Bad Request\"}\n"},
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"error\":\"malformed redirect_uri\"}\n"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -164,15 +182,19 @@ func TestAuthenticate_SignOut(t *testing.T) {
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
a.SignOut(w, r)
if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
}
if body := w.Body.String(); !strings.Contains(body, tt.wantBody) {
t.Errorf("handler returned wrong body Body: got \n%s \n%s", body, tt.wantBody)
body := w.Body.String()
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
t.Errorf("handler returned wrong body Body: %s", diff)
}
})
}
}
@ -199,19 +221,19 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
want string
wantCode int
}{
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusFound},
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound},
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusInternalServerError},
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError},
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -224,7 +246,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
RedirectURL: authURL,
sessionStore: tt.session,
provider: tt.provider,
cipher: aead,
cookieCipher: aead,
}
u, _ := url.Parse("/oauthGet")
params, _ := url.ParseQuery(u.RawQuery)
@ -235,7 +257,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
// (nonce|timestamp|redirect_url|encrypt(redirect_url),mac(nonce,ts))
b := []byte(fmt.Sprintf("%s|%d|%s", nonce, tt.ts, tt.extraMac))
enc := cryptutil.Encrypt(a.cipher, []byte(tt.redirectURI), b)
enc := cryptutil.Encrypt(a.cookieCipher, []byte(tt.redirectURI), b)
b = append(b, enc...)
encodedState := base64.URLEncoding.EncodeToString(b)
if tt.extraState != "" {
@ -261,59 +283,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
}
}
func TestAuthenticate_ExchangeToken(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
idToken string
restStore sessions.SessionStore
encoder cryptutil.SecureEncoder
provider identity.MockProvider
want string
}{
{"good", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""},
{"could not exchange identity for session", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionError: errors.New("error")}, ""},
{"missing token", http.MethodPost, "", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "missing id token"},
{"malformed form", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""},
{"can't marshal token", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{MarshalError: errors.New("can't marshal token")}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "can't marshal token"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
a := &Authenticate{
encoder: tt.encoder,
provider: tt.provider,
sessionStore: tt.restStore,
cipher: aead,
}
form := url.Values{}
if tt.idToken != "" {
form.Add("id_token", tt.idToken)
}
rawForm := form.Encode()
if tt.name == "malformed form" {
rawForm = "example=%zzzzz"
}
r := httptest.NewRequest(tt.method, "/", strings.NewReader(rawForm))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
a.ExchangeToken(w, r)
got := w.Body.String()
if !strings.Contains(got, tt.want) {
t.Errorf("Authenticate.ExchangeToken() = %v, want %v", got, tt.want)
}
})
}
}
func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -331,11 +300,11 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
wantStatus int
}{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusOK},
{"expired,refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
{"expired,save error", &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusFound},
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"good refresh expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
{"expired,refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
{"expired,save error", &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -344,12 +313,12 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
t.Fatal(err)
}
a := Authenticate{
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
sharedKey: cryptutil.NewBase64Key(),
cookieSecret: cryptutil.NewKey(),
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
sessionStore: tt.session,
provider: tt.provider,
cipher: aead,
cookieCipher: aead,
}
r := httptest.NewRequest("GET", "/", nil)
state, _ := tt.session.LoadSession(r)
@ -370,3 +339,57 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
})
}
}
func TestAuthenticate_RefreshAPI(t *testing.T) {
t.Parallel()
tests := []struct {
name string
session sessions.SessionStore
ctxError error
provider identity.Authenticator
secretEncoder sessions.Encoder
sharedEncoder sessions.Encoder
wantStatus int
}{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, http.StatusOK},
{"refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
{"session is not refreshable error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
{"secret encoder failed", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, encoding.MockEncoder{MarshalError: errors.New("error")}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
{"shared encoder failed", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, encoding.MockEncoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
a := Authenticate{
sharedKey: cryptutil.NewBase64Key(),
cookieSecret: cryptutil.NewKey(),
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
encryptedEncoder: tt.secretEncoder,
sharedEncoder: tt.sharedEncoder,
sessionStore: tt.session,
provider: tt.provider,
cookieCipher: aead,
}
r := httptest.NewRequest("GET", "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
a.RefreshAPI(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
})
}
}

View file

@ -1,5 +1,30 @@
# Changelog
## vUnreleased
### New
- Session state is now route-scoped. Each managed route uses a transparent, signed JSON Web Token (JWT) to assert identity.
- Managed routes no longer need to be under the same subdomain! Access can be delegated to any route, on any domain.
- Programmatic access now also uses JWT tokens. Access tokens are now generated via a standard oauth2 token flow, and credentials can be refreshed for as long as is permitted by the underlying identity provider.
- User dashboard now pulls in additional user context fields (where supported) like the profile picture, first and last name, and so on.
### Security
- Some identity providers (Okta and Azure) previously used mutable signifiers to set and assert group membership. Group membership for all providers now use globally unique and immutable identifiers when available.
### Changed
- Azure AD identity provider now uses globally unique and immutable `ID` for [group membership](https://docs.microsoft.com/en-us/graph/api/group-get?view=graph-rest-1.0&tabs=http).
- Okta no longer uses tokens to retrieve group membership. Group membership is now fetched using Okta's HTTP API. [Group membership](https://developer.okta.com/docs/reference/api/groups/) is now determined by the globally unique and immutable `ID` field.
- Okta now requires an additional set of credentials to be used to query for group membership set as a [service account](https://www.pomerium.io/docs/reference/reference.html#identity-provider-service-account).
- URLs are no longer validated to be on the same domain-tree as the authenticate service. Managed routes can live on any domain.
### Removed
- Force refresh has been removed from the dashboard.
- Previous programmatic authentication endpoints (`/api/v1/token`) has been removed and is no longer supported.
## v0.4.2
### Security

View file

@ -11,6 +11,30 @@ description: >-
### Breaking
#### Subdomain requirement dropped
- Pomerium services and managed routes are no longer required to be on the same domain-tree. Access can be delegated to any route, on any domain (that you have access to, of course).
#### Azure AD
- The Azure AD provider now uses the globally unique and immutable`ID` instead of `group name` to attest a user's [group membership](https://docs.microsoft.com/en-us/graph/api/group-get?view=graph-rest-1.0&tabs=http). Please update your policies to use Group `ID`s instead of group names.
#### Okta
- Okta no longer uses tokens to retrieve group membership. [Group membership](https://developer.okta.com/docs/reference/api/groups/) is now fetched using Okta's API. Please update your policies to use Group `ID`s instead of group names.
- Okta's group membership is now determined by the globally unique and immutable ID field.
- Okta now requires an additional set of credentials to be used to query for group membership set as a [service account](https://www.pomerium.io/docs/reference/reference.html#identity-provider-service-account).
#### Force Refresh Removed
Force refresh has been removed from the dashboard. Logging out and back in again should have the equivalent desired effect.
#### Programmatic Access API changed
Previous programmatic authentication endpoints (`/api/v1/token`) has been removed and has been replaced by a per-route, oauth2 based auth flow. Please see updated [programmatic documentation](https://www.pomerium.io/docs/reference/programmatic-access.html) how to use the new programmatic access api.
#### Forward-auth route change
Previously, routes were verified by taking the downstream applications hostname in the form of a path `(e.g. ${fwdauth}/.pomerium/verify/httpbin.some.example`) variable. The new method for verifying a route using forward authentication is to pass the entire requested url in the form of a query string `(e.g. ${fwdauth}/.pomerium/verify?url=https://httpbin.some.example)` where the routed domain is the value of the `uri` key.
Note that the verification URL is no longer nested under the `.pomerium` endpoint.

16
go.mod
View file

@ -3,7 +3,7 @@ module github.com/pomerium/pomerium
go 1.12
require (
cloud.google.com/go v0.40.0 // indirect
cloud.google.com/go v0.47.0 // indirect
contrib.go.opencensus.io/exporter/jaeger v0.1.0
contrib.go.opencensus.io/exporter/prometheus v0.1.0
github.com/fsnotify/fsnotify v1.4.7
@ -25,14 +25,14 @@ require (
github.com/spf13/viper v1.4.0
github.com/stretchr/testify v1.4.0 // indirect
go.opencensus.io v0.22.0
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8
golang.org/x/net v0.0.0-20191003171128-d98b1b443823
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4
golang.org/x/net v0.0.0-20191028085509-fe3aa8a45271
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9 // indirect
google.golang.org/api v0.6.0
google.golang.org/appengine v1.6.1 // indirect
google.golang.org/genproto v0.0.0-20191002211648-c459b9ce5143 // indirect
golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c // indirect
google.golang.org/api v0.13.0
google.golang.org/appengine v1.6.5 // indirect
google.golang.org/genproto v0.0.0-20191028173616-919d9bdd9fe6 // indirect
google.golang.org/grpc v1.24.0
gopkg.in/square/go-jose.v2 v2.3.1
gopkg.in/square/go-jose.v2 v2.4.0
gopkg.in/yaml.v2 v2.2.3
)

85
go.sum
View file

@ -2,14 +2,24 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT
cloud.google.com/go v0.34.0 h1:eOI3/cP2VTU6uZLDYAoic+eyzzB9YyGmJ7eIjl8rOPg=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU=
cloud.google.com/go v0.40.0 h1:FjSY7bOj+WzJe6TZRVtXI2b9kAYvtNg4lMbcH2+MUkk=
cloud.google.com/go v0.40.0/go.mod h1:Tk58MuI9rbLMKlAjeO/bDnteAx7tX2gJIXw4T5Jwlro=
cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU=
cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY=
cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc=
cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0=
cloud.google.com/go v0.47.0 h1:1JUtpcY9E7+eTospEwWS2QXP3DEn7poB3E2j0jN74mM=
cloud.google.com/go v0.47.0/go.mod h1:5p3Ky/7f3N10VBkhuR5LFtddroTiMyjZV/Kj5qOQFxU=
cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o=
cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I=
cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw=
contrib.go.opencensus.io/exporter/jaeger v0.1.0 h1:WNc9HbA38xEQmsI40Tjd/MNU/g8byN2Of7lwIjv0Jdc=
contrib.go.opencensus.io/exporter/jaeger v0.1.0/go.mod h1:VYianECmuFPwU37O699Vc1GOcy+y8kOsfaxHRImmjbA=
contrib.go.opencensus.io/exporter/prometheus v0.1.0 h1:SByaIoWwNgMdPSgl5sMqM2KDE5H/ukPWBRo314xiDvg=
contrib.go.opencensus.io/exporter/prometheus v0.1.0/go.mod h1:cGFniUXGZlKRjzOyuZJ6mgB+PgBcCIa79kEKR8YCW+A=
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
@ -40,6 +50,7 @@ github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFP
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
@ -71,7 +82,11 @@ github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/mux v1.6.2 h1:Pgr17XVTNXAk3q/r4CpKzC5xBM/qW1uVLV+IhRZpIIk=
@ -157,6 +172,7 @@ github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7z
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/zerolog v1.14.3 h1:4EGfSkR2hJDB0s3oFfrlPqjU1e4WLncergLil3nEKW0=
github.com/rs/zerolog v1.14.3/go.mod h1:3WXPzbXEEliJ+a6UFE4vhIxV8qR1EML6ngzP9ug4eYg=
@ -199,16 +215,30 @@ go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmVhcMLU6v5fEb/ok4wyqtNU=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8 h1:1wopBVtVdWnn03fZelqdXTqk7U7zPQCb+T4rbU9ZEoU=
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191029031824-8986dd9e96cf h1:fnPsqIDRbCSgumaMCRpoIoF2s4qxv0xSSS0BVZUE/ss=
golang.org/x/crypto v0.0.0-20191029031824-8986dd9e96cf/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4 h1:PDpCLFAH/YIX0QpHPf2eO7L4rC2OOirBrKtXTLLiNTY=
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek=
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -225,8 +255,9 @@ golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65 h1:+rhAzEzT3f4JtomfC371qB+0Ola2caSKcY69NUBZrRQ=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20191003171128-d98b1b443823 h1:Ypyv6BNJh07T1pUSrehkLemqPKXhus2MkfktJ91kRh4=
golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191028085509-fe3aa8a45271 h1:N66aaryRB3Ax92gH0v3hp1QYZ3zWWCCUR/j8Ifh45Ss=
golang.org/x/net v0.0.0-20191028085509-fe3aa8a45271/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421 h1:Wo7BWFiOk0QRFMLYMqJGFMd9CgUAcGx7V+qEg/h5IBI=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -247,12 +278,14 @@ golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9 h1:L2auWcuQIvxz9xSEqzESnV/QN/gNRXNApHi3fYwl2w0=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c h1:S/FtSvpNLtFBgjTqcKsRpsa6aVsI6iztaz1bQd9BJwE=
golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -266,35 +299,51 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191010171213-8abd42400456/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk=
google.golang.org/api v0.3.2/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.6.0 h1:2tJEkRfnZL5g1GeBUlITh/rqT5HG3sFcoVCUUxmgJ2g=
google.golang.org/api v0.6.0/go.mod h1:btoxGiFvQNVUZQ8W08zLtrVS08CNpINPEfxXxgJL1Q4=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
google.golang.org/api v0.13.0 h1:Q3Ui3V3/CVinFWFiW39Iw0kMuVrRzYX0wN6OPFp0lTA=
google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.6.1 h1:QzqyMA1tlu6CgqCDUtU9V+ZKhLFT2dkJuANu5QaxI3I=
google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0=
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20190530194941-fb225487d101/go.mod h1:z3L6/3dTEVtUr6QSP8miRzeRqwQOioJ9I66odjN4I7s=
google.golang.org/genproto v0.0.0-20191002211648-c459b9ce5143 h1:tikhlQEJeezbnu0Zcblj7g5vm/L7xt6g1vnfq8mRCS4=
google.golang.org/genproto v0.0.0-20191002211648-c459b9ce5143/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8=
google.golang.org/genproto v0.0.0-20191009194640-548a555dbc03/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
google.golang.org/genproto v0.0.0-20191028173616-919d9bdd9fe6 h1:UXl+Zk3jqqcbEVV7ace5lrt4YdA4tXiz3f/KbmD29Vo=
google.golang.org/genproto v0.0.0-20191028173616-919d9bdd9fe6/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
google.golang.org/grpc v1.24.0 h1:vb/1TCsVn3DcJlQ0Gs1yB1pKI6Do2/QNwxdKqmc/b0s=
google.golang.org/grpc v1.24.0/go.mod h1:XDChyiUovWa60DnaeDeZmSW86xtLtjtZbwvSiRnRtcA=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
@ -302,10 +351,11 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4=
gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/square/go-jose.v2 v2.4.0 h1:0kXPskUMGAXXWJlP05ktEMOV0vmzFQUWw6d+aZJQU8A=
gopkg.in/square/go-jose.v2 v2.4.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
@ -319,4 +369,5 @@ honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=

View file

@ -1,9 +1,12 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
import (
"crypto/ecdsa"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"io/ioutil"
)
@ -49,3 +52,80 @@ func bytesToCertPool(b []byte) (*x509.CertPool, error) {
}
return certPool, nil
}
// DecodePublicKey decodes a PEM-encoded ECDSA public key.
func DecodePublicKey(encodedKey []byte) (*ecdsa.PublicKey, error) {
block, _ := pem.Decode(encodedKey)
if block == nil {
return nil, fmt.Errorf("cryptutil: decoded nil PEM block")
}
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, err
}
ecdsaPub, ok := pub.(*ecdsa.PublicKey)
if !ok {
return nil, errors.New("cryptutil: data was not an ECDSA public key")
}
return ecdsaPub, nil
}
// EncodePublicKey encodes an ECDSA public key to PEM format.
func EncodePublicKey(key *ecdsa.PublicKey) ([]byte, error) {
derBytes, err := x509.MarshalPKIXPublicKey(key)
if err != nil {
return nil, err
}
block := &pem.Block{
Type: "PUBLIC KEY",
Bytes: derBytes,
}
return pem.EncodeToMemory(block), nil
}
// DecodePrivateKey decodes a PEM-encoded ECDSA private key.
func DecodePrivateKey(encodedKey []byte) (*ecdsa.PrivateKey, error) {
var skippedTypes []string
var block *pem.Block
for {
block, encodedKey = pem.Decode(encodedKey)
if block == nil {
return nil, fmt.Errorf("cryptutil: failed to find EC PRIVATE KEY in PEM data after skipping types %v", skippedTypes)
}
if block.Type == "EC PRIVATE KEY" {
break
} else {
skippedTypes = append(skippedTypes, block.Type)
continue
}
}
privKey, err := x509.ParseECPrivateKey(block.Bytes)
if err != nil {
return nil, err
}
return privKey, nil
}
// EncodePrivateKey encodes an ECDSA private key to PEM format.
func EncodePrivateKey(key *ecdsa.PrivateKey) ([]byte, error) {
derKey, err := x509.MarshalECPrivateKey(key)
if err != nil {
return nil, err
}
keyBlock := &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: derKey,
}
return pem.EncodeToMemory(keyBlock), nil
}

View file

@ -1,10 +1,36 @@
package cryptutil
import (
"bytes"
"crypto/tls"
"strings"
"testing"
)
// A keypair for NIST P-256 / secp256r1
// Generated using:
// openssl ecparam -genkey -name prime256v1 -outform PEM
var pemECPrivateKeyP256 = `-----BEGIN EC PARAMETERS-----
BggqhkjOPQMBBw==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIOI+EZsjyN3jvWJI/KDihFmqTuDpUe/if6f/pgGTBta/oAoGCCqGSM49
AwEHoUQDQgAEhhObKJ1r1PcUw+3REd/TbmSZnDvXnFUSTwqQFo5gbfIlP+gvEYba
+Rxj2hhqjfzqxIleRK40IRyEi3fJM/8Qhg==
-----END EC PRIVATE KEY-----
`
var pemECPublicKeyP256 = `-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEhhObKJ1r1PcUw+3REd/TbmSZnDvX
nFUSTwqQFo5gbfIlP+gvEYba+Rxj2hhqjfzqxIleRK40IRyEi3fJM/8Qhg==
-----END PUBLIC KEY-----
`
var garbagePEM = `-----BEGIN GARBAGE-----
TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQ=
-----END GARBAGE-----
`
func TestCertifcateFromBase64(t *testing.T) {
tests := []struct {
@ -91,3 +117,39 @@ func TestCertificateFromFile(t *testing.T) {
}
_ = listener
}
func TestPublicKeyMarshaling(t *testing.T) {
ecKey, err := DecodePublicKey([]byte(pemECPublicKeyP256))
if err != nil {
t.Fatal(err)
}
_, err = DecodePublicKey(nil)
if err == nil {
t.Fatal("expected error")
}
pemBytes, _ := EncodePublicKey(ecKey)
if !bytes.Equal(pemBytes, []byte(pemECPublicKeyP256)) {
t.Fatal("public key encoding did not match")
}
}
func TestPrivateKeyBadDecode(t *testing.T) {
_, err := DecodePrivateKey([]byte(garbagePEM))
if err == nil {
t.Fatal("decoded garbage data without complaint")
}
}
func TestPrivateKeyMarshaling(t *testing.T) {
ecKey, err := DecodePrivateKey([]byte(pemECPrivateKeyP256))
if err != nil {
t.Fatal(err)
}
pemBytes, _ := EncodePrivateKey(ecKey)
if !strings.HasSuffix(pemECPrivateKeyP256, string(pemBytes)) {
t.Fatal("private key encoding did not match")
}
}

View file

@ -1,13 +1,9 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
import (
"bytes"
"compress/gzip"
"crypto/cipher"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"golang.org/x/crypto/chacha20poly1305"
)
@ -30,106 +26,6 @@ func NewAEADCipherFromBase64(s string) (cipher.AEAD, error) {
return NewAEADCipher(decoded)
}
// SecureEncoder provides and interface for to encrypt and decrypting structures .
type SecureEncoder interface {
Marshal(interface{}) (string, error)
Unmarshal(string, interface{}) error
}
// SecureJSONEncoder implements SecureEncoder for JSON using an AEAD cipher.
//
// See https://en.wikipedia.org/wiki/Authenticated_encryption
type SecureJSONEncoder struct {
aead cipher.AEAD
}
// NewSecureJSONEncoder takes a base64 encoded secret key and returns a new XChacha20poly1305 cipher.
func NewSecureJSONEncoder(aead cipher.AEAD) SecureEncoder {
return &SecureJSONEncoder{aead: aead}
}
// Marshal marshals the interface state as JSON, encrypts the JSON using the cipher
// and base64 encodes the binary value as a string and returns the result
//
// can panic if source of random entropy is exhausted generating a nonce.
func (c *SecureJSONEncoder) Marshal(s interface{}) (string, error) {
// encode json value
plaintext, err := json.Marshal(s)
if err != nil {
return "", err
}
// compress the plaintext bytes
compressed, err := compress(plaintext)
if err != nil {
return "", err
}
// encrypt the compressed JSON bytes
ciphertext := Encrypt(c.aead, compressed, nil)
// base64-encode the result
encoded := base64.RawURLEncoding.EncodeToString(ciphertext)
return encoded, nil
}
// Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the
// byte slice the passed cipher, and unmarshals the resulting JSON into the struct pointer passed
func (c *SecureJSONEncoder) Unmarshal(value string, s interface{}) error {
// convert base64 string value to bytes
ciphertext, err := base64.RawURLEncoding.DecodeString(value)
if err != nil {
return err
}
// decrypt the bytes
compressed, err := Decrypt(c.aead, ciphertext, nil)
if err != nil {
return err
}
// decompress the unencrypted bytes
plaintext, err := decompress(compressed)
if err != nil {
return err
}
// unmarshal the unencrypted bytes
err = json.Unmarshal(plaintext, s)
if err != nil {
return err
}
return nil
}
// compress gzips a set of bytes
func compress(data []byte) ([]byte, error) {
var buf bytes.Buffer
writer, err := gzip.NewWriterLevel(&buf, gzip.DefaultCompression)
if err != nil {
return nil, fmt.Errorf("cryptutil: failed to create a gzip writer: %q", err)
}
if writer == nil {
return nil, fmt.Errorf("cryptutil: failed to create a gzip writer")
}
if _, err = writer.Write(data); err != nil {
return nil, fmt.Errorf("cryptutil: failed to compress data with err: %q", err)
}
if err = writer.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// decompress un-gzips a set of bytes
func decompress(data []byte) ([]byte, error) {
reader, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return nil, fmt.Errorf("cryptutil: failed to create a gzip reader: %q", err)
}
defer reader.Close()
var buf bytes.Buffer
if _, err = io.Copy(&buf, reader); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// Encrypt encrypts a value with optional associated data
//
// Panics if source of randomness fails.

View file

@ -39,106 +39,6 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
}
}
func TestMarshalAndUnmarshalStruct(t *testing.T) {
key := NewKey()
a, err := NewAEADCipher(key)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
c := SecureJSONEncoder{aead: a}
type TC struct {
Field string `json:"field"`
}
tc := &TC{
Field: "my plain text value",
}
value1, err := c.Marshal(tc)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
value2, err := c.Marshal(tc)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if value1 == value2 {
t.Fatalf("expected marshaled values to not be equal %v != %v", value1, value2)
}
got1 := &TC{}
err = c.Unmarshal(value1, got1)
if err != nil {
t.Fatalf("unexpected err unmarshalling struct: %v", err)
}
if !reflect.DeepEqual(got1, tc) {
t.Logf("want: %#v", tc)
t.Logf(" got: %#v", got1)
t.Fatalf("expected structs to be equal")
}
got2 := &TC{}
err = c.Unmarshal(value2, got2)
if err != nil {
t.Fatalf("unexpected err unmarshalling struct: %v", err)
}
if !reflect.DeepEqual(got1, got2) {
t.Logf("got2: %#v", got2)
t.Logf("got1: %#v", got1)
t.Fatalf("expected structs to be equal")
}
}
func TestSecureJSONEncoder_Marshal(t *testing.T) {
t.Parallel()
tests := []struct {
name string
s interface{}
wantErr bool
}{
{"unsupported type",
struct {
Animal string `json:"animal"`
Func func() `json:"sound"`
}{
Animal: "cat",
Func: func() {},
},
true},
{"simple",
struct {
Animal string `json:"animal"`
Sound string `json:"sound"`
}{
Animal: "cat",
Sound: "meow",
},
false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewAEADCipher(NewKey())
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
e := SecureJSONEncoder{aead: c}
_, err = e.Marshal(tt.s)
if (err != nil) != tt.wantErr {
t.Errorf("SecureJSONEncoder.Marshal() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func TestNewAEADCipher(t *testing.T) {
t.Parallel()
tests := []struct {

View file

@ -7,6 +7,11 @@ import (
"time"
)
const (
// DefaultLeeway defines the default leeway for matching NotBefore/Expiry claims.
DefaultLeeway = 1.0 * time.Minute
)
var (
errTimestampMalformed = errors.New("internal/cryptutil: timestamp malformed")
errTimestampExpired = errors.New("internal/cryptutil: timestamp expired")
@ -31,7 +36,6 @@ func CheckHMAC(data, suppliedMAC []byte, key string) bool {
// ValidTimestamp is a helper function often used in conjunction with an HMAC
// function to verify that the timestamp (in unix seconds) is within leeway
// period.
// todo(bdd) : should leeway be configurable?
func ValidTimestamp(ts string) error {
var timeStamp int64
var err error

View file

@ -1,100 +0,0 @@
// Package cryptutil provides encoding and decoding routines for various cryptographic structures.
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
import (
"crypto/ecdsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
)
// DecodePublicKey decodes a PEM-encoded ECDSA public key.
func DecodePublicKey(encodedKey []byte) (*ecdsa.PublicKey, error) {
block, _ := pem.Decode(encodedKey)
if block == nil {
return nil, fmt.Errorf("cryptutil: decoded nil PEM block")
}
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, err
}
ecdsaPub, ok := pub.(*ecdsa.PublicKey)
if !ok {
return nil, errors.New("cryptutil: data was not an ECDSA public key")
}
return ecdsaPub, nil
}
// EncodePublicKey encodes an ECDSA public key to PEM format.
func EncodePublicKey(key *ecdsa.PublicKey) ([]byte, error) {
derBytes, err := x509.MarshalPKIXPublicKey(key)
if err != nil {
return nil, err
}
block := &pem.Block{
Type: "PUBLIC KEY",
Bytes: derBytes,
}
return pem.EncodeToMemory(block), nil
}
// DecodePrivateKey decodes a PEM-encoded ECDSA private key.
func DecodePrivateKey(encodedKey []byte) (*ecdsa.PrivateKey, error) {
var skippedTypes []string
var block *pem.Block
for {
block, encodedKey = pem.Decode(encodedKey)
if block == nil {
return nil, fmt.Errorf("cryptutil: failed to find EC PRIVATE KEY in PEM data after skipping types %v", skippedTypes)
}
if block.Type == "EC PRIVATE KEY" {
break
} else {
skippedTypes = append(skippedTypes, block.Type)
continue
}
}
privKey, err := x509.ParseECPrivateKey(block.Bytes)
if err != nil {
return nil, err
}
return privKey, nil
}
// EncodePrivateKey encodes an ECDSA private key to PEM format.
func EncodePrivateKey(key *ecdsa.PrivateKey) ([]byte, error) {
derKey, err := x509.MarshalECPrivateKey(key)
if err != nil {
return nil, err
}
keyBlock := &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: derKey,
}
return pem.EncodeToMemory(keyBlock), nil
}
// EncodeSignatureJWT encodes an ECDSA signature according to
// https://tools.ietf.org/html/rfc7515#appendix-A.3.1
func EncodeSignatureJWT(sig []byte) string {
return base64.RawURLEncoding.EncodeToString(sig)
}
// DecodeSignatureJWT decodes an ECDSA signature according to
// https://tools.ietf.org/html/rfc7515#appendix-A.3.1
func DecodeSignatureJWT(b64sig string) ([]byte, error) {
return base64.RawURLEncoding.DecodeString(b64sig)
}

View file

@ -1,105 +0,0 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
import (
"bytes"
"strings"
"testing"
)
// A keypair for NIST P-256 / secp256r1
// Generated using:
// openssl ecparam -genkey -name prime256v1 -outform PEM
var pemECPrivateKeyP256 = `-----BEGIN EC PARAMETERS-----
BggqhkjOPQMBBw==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIOI+EZsjyN3jvWJI/KDihFmqTuDpUe/if6f/pgGTBta/oAoGCCqGSM49
AwEHoUQDQgAEhhObKJ1r1PcUw+3REd/TbmSZnDvXnFUSTwqQFo5gbfIlP+gvEYba
+Rxj2hhqjfzqxIleRK40IRyEi3fJM/8Qhg==
-----END EC PRIVATE KEY-----
`
var pemECPublicKeyP256 = `-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEhhObKJ1r1PcUw+3REd/TbmSZnDvX
nFUSTwqQFo5gbfIlP+gvEYba+Rxj2hhqjfzqxIleRK40IRyEi3fJM/8Qhg==
-----END PUBLIC KEY-----
`
var garbagePEM = `-----BEGIN GARBAGE-----
TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQ=
-----END GARBAGE-----
`
func TestPublicKeyMarshaling(t *testing.T) {
ecKey, err := DecodePublicKey([]byte(pemECPublicKeyP256))
if err != nil {
t.Fatal(err)
}
_, err = DecodePublicKey(nil)
if err == nil {
t.Fatal("expected error")
}
pemBytes, _ := EncodePublicKey(ecKey)
if !bytes.Equal(pemBytes, []byte(pemECPublicKeyP256)) {
t.Fatal("public key encoding did not match")
}
}
func TestPrivateKeyBadDecode(t *testing.T) {
_, err := DecodePrivateKey([]byte(garbagePEM))
if err == nil {
t.Fatal("decoded garbage data without complaint")
}
}
func TestPrivateKeyMarshaling(t *testing.T) {
ecKey, err := DecodePrivateKey([]byte(pemECPrivateKeyP256))
if err != nil {
t.Fatal(err)
}
pemBytes, _ := EncodePrivateKey(ecKey)
if !strings.HasSuffix(pemECPrivateKeyP256, string(pemBytes)) {
t.Fatal("private key encoding did not match")
}
}
// Test vector from https://tools.ietf.org/html/rfc7515#appendix-A.3.1
var jwtTest = []struct {
sigBytes []byte
b64sig string
}{
{
sigBytes: []byte{14, 209, 33, 83, 121, 99, 108, 72, 60, 47, 127, 21,
88, 7, 212, 2, 163, 178, 40, 3, 58, 249, 124, 126, 23, 129, 154, 195, 22, 158,
166, 101, 197, 10, 7, 211, 140, 60, 112, 229, 216, 241, 45, 175,
8, 74, 84, 128, 166, 101, 144, 197, 242, 147, 80, 154, 143, 63, 127, 138, 131,
163, 84, 213},
b64sig: "DtEhU3ljbEg8L38VWAfUAqOyKAM6-Xx-F4GawxaepmXFCgfTjDxw5djxLa8ISlSApmWQxfKTUJqPP3-Kg6NU1Q",
},
}
func TestJWTEncoding(t *testing.T) {
for _, tt := range jwtTest {
result := EncodeSignatureJWT(tt.sigBytes)
if strings.Compare(result, tt.b64sig) != 0 {
t.Fatalf("expected %s, got %s\n", tt.b64sig, result)
}
}
}
func TestJWTDecoding(t *testing.T) {
for _, tt := range jwtTest {
resultSig, err := DecodeSignatureJWT(tt.b64sig)
if err != nil {
t.Error(err)
}
if !bytes.Equal(resultSig, tt.sigBytes) {
t.Fatalf("decoded signature was incorrect")
}
}
}

View file

@ -1,107 +0,0 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
import (
"encoding/base64"
"fmt"
"sync"
"time"
jose "gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
)
const (
// DefaultLeeway defines the default leeway for matching NotBefore/Expiry claims.
DefaultLeeway = 5.0 * time.Minute
)
// JWTSigner implements JWT signing according to JSON Web Token (JWT) RFC7519
// https://tools.ietf.org/html/rfc7519
type JWTSigner interface {
SignJWT(string, string, string) (string, error)
}
// ES256Signer is struct containing the required fields to create a ES256 signed JSON Web Tokens
type ES256Signer struct {
signer jose.Signer
mu sync.Mutex
// User (sub) is unique, stable identifier for the user.
// Use in place of the x-pomerium-authenticated-user-id header.
User string `json:"sub,omitempty"`
// Email (email) is a **custom** claim name identifier for the user email address.
// Use in place of the x-pomerium-authenticated-user-email header.
Email string `json:"email,omitempty"`
// Groups (groups) is a **custom** claim name identifier for the user's groups.
// Use in place of the x-pomerium-authenticated-user-groups header.
Groups string `json:"groups,omitempty"`
// Audience (aud) must be the destination of the upstream proxy locations.
// e.g. `helloworld.corp.example.com`
Audience jwt.Audience `json:"aud,omitempty"`
// Issuer (iss) is the URL of the proxy.
// e.g. `proxy.corp.example.com`
Issuer string `json:"iss,omitempty"`
// Expiry (exp) is the expiration time in seconds since the UNIX epoch.
// Allow 1 minute for skew. The maximum lifetime of a token is 10 minutes + 2 * skew.
Expiry jwt.NumericDate `json:"exp,omitempty"`
// IssuedAt (iat) is the time is measured in seconds since the UNIX epoch.
// Allow 1 minute for skew.
IssuedAt jwt.NumericDate `json:"iat,omitempty"`
// IssuedAt (nbf) is the time is measured in seconds since the UNIX epoch.
// Allow 1 minute for skew.
NotBefore jwt.NumericDate `json:"nbf,omitempty"`
}
// NewES256Signer creates a NIST P-256 (aka secp256r1 aka prime256v1) JWT signer
// from a base64 encoded private key.
//
// RSA is not supported due to performance considerations of needing to sign each request.
// Go's P-256 is constant-time and SHA-256 is faster on 64-bit machines and immune
// to length extension attacks.
// See also:
// - https://cloud.google.com/iot/docs/how-tos/credentials/keys
func NewES256Signer(privKey, audience string) (*ES256Signer, error) {
decodedSigningKey, err := base64.StdEncoding.DecodeString(privKey)
if err != nil {
return nil, err
}
key, err := DecodePrivateKey(decodedSigningKey)
if err != nil {
return nil, fmt.Errorf("cryptutil: parsing key failed %v", err)
}
signer, err := jose.NewSigner(
jose.SigningKey{
Algorithm: jose.ES256, // ECDSA using P-256 and SHA-256
Key: key,
},
(&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
return nil, fmt.Errorf("cryptutil: new signer failed %v", err)
}
return &ES256Signer{
Issuer: "pomerium-proxy",
Audience: jwt.Audience{audience},
signer: signer,
}, nil
}
// SignJWT creates a signed JWT containing claims for the logged in
// user id (`sub`), email (`email`) and groups (`groups`).
func (s *ES256Signer) SignJWT(user, email, groups string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.User = user
s.Email = email
s.Groups = groups
now := time.Now()
s.IssuedAt = *jwt.NewNumericDate(now)
s.Expiry = *jwt.NewNumericDate(now.Add(DefaultLeeway))
s.NotBefore = *jwt.NewNumericDate(now.Add(-1 * DefaultLeeway))
rawJWT, err := jwt.Signed(s.signer).Claims(s).CompactSerialize()
if err != nil {
return "", fmt.Errorf("cryptutil: sign failed %v", err)
}
return rawJWT, nil
}

View file

@ -1,46 +0,0 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
import (
"encoding/base64"
"testing"
)
func TestES256Signer(t *testing.T) {
signer, err := NewES256Signer(base64.StdEncoding.EncodeToString([]byte(pemECPrivateKeyP256)), "destination-url")
if err != nil {
t.Fatal(err)
}
if signer == nil {
t.Fatal("signer should not be nil")
}
rawJwt, err := signer.SignJWT("joe-user", "joe-user@example.com", "group1,group2")
if err != nil {
t.Fatal(err)
}
if rawJwt == "" {
t.Fatal("jwt should not be nil")
}
}
func TestNewES256Signer(t *testing.T) {
t.Parallel()
tests := []struct {
name string
privKey string
audience string
wantErr bool
}{
{"working example", base64.StdEncoding.EncodeToString([]byte(pemECPrivateKeyP256)), "some-domain.com", false},
{"bad private key", base64.StdEncoding.EncodeToString([]byte(garbagePEM)), "some-domain.com", true},
{"bad base64 key", garbagePEM, "some-domain.com", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewES256Signer(tt.privKey, tt.audience)
if (err != nil) != tt.wantErr {
t.Errorf("NewES256Signer() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}

View file

@ -0,0 +1,110 @@
// Package ecjson represents encrypted and compressed content using JSON-based
package ecjson // import "github.com/pomerium/pomerium/internal/encoding/ecjson"
import (
"bytes"
"compress/gzip"
"crypto/cipher"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"github.com/pomerium/pomerium/internal/cryptutil"
)
// EncryptedCompressedJSON implements SecureEncoder for JSON using an AEAD cipher.
//
// See https://en.wikipedia.org/wiki/Authenticated_encryption
type EncryptedCompressedJSON struct {
aead cipher.AEAD
}
// New takes a base64 encoded secret key and returns a new XChacha20poly1305 cipher.
func New(aead cipher.AEAD) *EncryptedCompressedJSON {
return &EncryptedCompressedJSON{aead: aead}
}
// Marshal marshals the interface state as JSON, encrypts the JSON using the cipher
// and base64 encodes the binary value as a string and returns the result
//
// can panic if source of random entropy is exhausted generating a nonce.
func (c *EncryptedCompressedJSON) Marshal(s interface{}) ([]byte, error) {
// encode json value
plaintext, err := json.Marshal(s)
if err != nil {
return nil, err
}
// compress the plaintext bytes
compressed, err := compress(plaintext)
if err != nil {
return nil, err
}
// encrypt the compressed JSON bytes
ciphertext := cryptutil.Encrypt(c.aead, compressed, nil)
// base64-encode the result
encoded := base64.RawURLEncoding.EncodeToString(ciphertext)
return []byte(encoded), nil
}
// Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the
// byte slice the passed cipher, and unmarshals the resulting JSON into the struct pointer passed
func (c *EncryptedCompressedJSON) Unmarshal(data []byte, s interface{}) error {
// convert base64 string value to bytes
ciphertext, err := base64.RawURLEncoding.DecodeString(string(data))
if err != nil {
return err
}
// decrypt the bytes
compressed, err := cryptutil.Decrypt(c.aead, ciphertext, nil)
if err != nil {
return err
}
// decompress the unencrypted bytes
plaintext, err := decompress(compressed)
if err != nil {
return err
}
// unmarshal the unencrypted bytes
err = json.Unmarshal(plaintext, s)
if err != nil {
return err
}
return nil
}
// compress gzips a set of bytes
func compress(data []byte) ([]byte, error) {
var buf bytes.Buffer
writer, err := gzip.NewWriterLevel(&buf, gzip.DefaultCompression)
if err != nil {
return nil, fmt.Errorf("cryptutil: failed to create a gzip writer: %q", err)
}
if writer == nil {
return nil, fmt.Errorf("cryptutil: failed to create a gzip writer")
}
if _, err = writer.Write(data); err != nil {
return nil, fmt.Errorf("cryptutil: failed to compress data with err: %q", err)
}
if err = writer.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// decompress un-gzips a set of bytes
func decompress(data []byte) ([]byte, error) {
reader, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return nil, fmt.Errorf("cryptutil: failed to create a gzip reader: %q", err)
}
defer reader.Close()
var buf bytes.Buffer
if _, err = io.Copy(&buf, reader); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View file

@ -0,0 +1,70 @@
// Package jws represents content secured with digitalsignatures
// using JSON-based data structures as specified by rfc7515
package jws // import "github.com/pomerium/pomerium/internal/encoding/jws"
import (
"encoding/base64"
jose "gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/pomerium/pomerium/internal/cryptutil"
)
// JSONWebSigner is the struct representing a signed JWT.
// https://tools.ietf.org/html/rfc7519
type JSONWebSigner struct {
Signer jose.Signer
Issuer string
key interface{}
}
// NewHS256Signer creates a SHA256 JWT signer from a 32 byte key.
func NewHS256Signer(key []byte, issuer string) (*JSONWebSigner, error) {
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key},
(&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
return nil, err
}
return &JSONWebSigner{Signer: sig, key: key, Issuer: issuer}, nil
}
// NewES256Signer creates a NIST P-256 (aka secp256r1 aka prime256v1) JWT signer
// from a base64 encoded private key.
//
// RSA is not supported due to performance considerations of needing to sign each request.
// Go's P-256 is constant-time and SHA-256 is faster on 64-bit machines and immune
// to length extension attacks.
// See : https://cloud.google.com/iot/docs/how-tos/credentials/keys
func NewES256Signer(privKey, issuer string) (*JSONWebSigner, error) {
decodedSigningKey, err := base64.StdEncoding.DecodeString(privKey)
if err != nil {
return nil, err
}
key, err := cryptutil.DecodePrivateKey(decodedSigningKey)
if err != nil {
return nil, err
}
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key},
(&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
return nil, err
}
return &JSONWebSigner{Signer: sig, key: key, Issuer: issuer}, nil
}
// Marshal signs, and serializes a JWT.
func (c *JSONWebSigner) Marshal(x interface{}) ([]byte, error) {
s, err := jwt.Signed(c.Signer).Claims(x).CompactSerialize()
return []byte(s), err
}
// Unmarshal parses and validates a signed JWT.
func (c *JSONWebSigner) Unmarshal(value []byte, s interface{}) error {
tok, err := jwt.ParseSigned(string(value))
if err != nil {
return err
}
return tok.Claims(c.key, s)
}

View file

@ -1,18 +1,18 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
package encoding // import "github.com/pomerium/pomerium/internal/encoding"
// MockEncoder MockCSRFStore is a mock implementation of Cipher.
type MockEncoder struct {
MarshalResponse string
MarshalResponse []byte
MarshalError error
UnmarshalError error
}
// Marshal is a mock implementation of MockEncoder.
func (mc MockEncoder) Marshal(i interface{}) (string, error) {
func (mc MockEncoder) Marshal(i interface{}) ([]byte, error) {
return mc.MarshalResponse, mc.MarshalError
}
// Unmarshal is a mock implementation of MockEncoder.
func (mc MockEncoder) Unmarshal(s string, i interface{}) error {
func (mc MockEncoder) Unmarshal(s []byte, i interface{}) error {
return mc.UnmarshalError
}

View file

@ -1,4 +1,4 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
package encoding // import "github.com/pomerium/pomerium/internal/encoding"
import (
"errors"
@ -8,7 +8,7 @@ import (
func TestMockEncoder(t *testing.T) {
e := errors.New("err")
mc := MockEncoder{
MarshalResponse: "MarshalResponse",
MarshalResponse: []byte("MarshalResponse"),
MarshalError: e,
UnmarshalError: e,
}
@ -16,10 +16,10 @@ func TestMockEncoder(t *testing.T) {
if err != e {
t.Error("unexpected Marshal error")
}
if s != "MarshalResponse" {
if string(s) != "MarshalResponse" {
t.Error("unexpected MarshalResponse error")
}
err = mc.Unmarshal("s", "s")
err = mc.Unmarshal([]byte("s"), "s")
if err != e {
t.Error("unexpected Unmarshal error")
}

View file

@ -2,6 +2,7 @@ package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@ -27,7 +28,7 @@ var httpClient = &http.Client{
}
// Client provides a simple helper interface to make HTTP requests
func Client(method, endpoint, userAgent string, headers map[string]string, params url.Values, response interface{}) error {
func Client(ctx context.Context, method, endpoint, userAgent string, headers map[string]string, params url.Values, response interface{}) error {
var body io.Reader
switch method {
case http.MethodPost:
@ -41,7 +42,7 @@ func Client(method, endpoint, userAgent string, headers map[string]string, param
default:
return fmt.Errorf(http.StatusText(http.StatusBadRequest))
}
req, err := http.NewRequest(method, endpoint, body)
req, err := http.NewRequestWithContext(ctx, method, endpoint, body)
if err != nil {
return err
}

View file

@ -9,9 +9,9 @@ import (
"github.com/pomerium/pomerium/internal/log"
)
// HeaderForwardHost is the header key the identifies the originating
// IP addresses of a client connecting to a web server through an HTTP proxy
// or a load balancer.
// HeaderForwardHost is the header key that identifies the original host requested
// by the client in the Host HTTP request header.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-Host
const HeaderForwardHost = "X-Forwarded-Host"
// NewReverseProxy returns a new ReverseProxy that routes

View file

@ -4,14 +4,13 @@ import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"time"
oidc "github.com/pomerium/go-oidc"
"golang.org/x/oauth2"
"golang.org/x/oauth2/jwt"
"golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1"
"github.com/pomerium/pomerium/internal/httputil"
@ -22,14 +21,12 @@ import (
const defaultGoogleProviderURL = "https://accounts.google.com"
// JWTTokenURL is Google's OAuth 2.0 token URL to use with the JWT flow.
const JWTTokenURL = "https://accounts.google.com/o/oauth2/token"
// GoogleProvider is an implementation of the Provider interface.
type GoogleProvider struct {
*Provider
// non-standard oidc fields
RevokeURL *url.URL
RevokeURL string `json:"revocation_endpoint"`
apiClient *admin.Service
}
@ -61,13 +58,9 @@ func NewGoogleProvider(p *Provider) (*GoogleProvider, error) {
gp := &GoogleProvider{
Provider: p,
}
// google supports a revocation endpoint
var claims struct {
RevokeURL string `json:"revocation_endpoint"`
}
// build api client to make group membership api calls
if err := p.provider.Claims(&claims); err != nil {
if err := p.provider.Claims(&gp); err != nil {
return nil, err
}
// if service account set, configure admin sdk calls
@ -78,34 +71,37 @@ func NewGoogleProvider(p *Provider) (*GoogleProvider, error) {
}
// Required scopes for groups api
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
conf, err := JWTConfigFromJSON(apiCreds, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope)
conf, err := google.JWTConfigFromJSON(apiCreds, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope)
if err != nil {
return nil, fmt.Errorf("identity/google: failed making jwt config from json %v", err)
}
var credentialsFile struct {
ImpersonateUser string `json:"impersonate_user"`
}
if err := json.Unmarshal(apiCreds, &credentialsFile); err != nil {
return nil, err
}
conf.Subject = credentialsFile.ImpersonateUser
client := conf.Client(context.TODO())
gp.apiClient, err = admin.New(client)
if err != nil {
return nil, fmt.Errorf("identity/google: failed creating admin service %v", err)
}
gp.UserGroupFn = gp.UserGroups
} else {
log.Warn().Msg("identity/google: no service account, cannot retrieve groups")
}
gp.RevokeURL, err = url.Parse(claims.RevokeURL)
if err != nil {
return nil, err
}
return gp, nil
}
// Revoke revokes the access token a given session state.
//
// https://developers.google.com/identity/protocols/OAuth2WebServer#tokenrevoke
func (p *GoogleProvider) Revoke(accessToken string) error {
func (p *GoogleProvider) Revoke(ctx context.Context, token *oauth2.Token) error {
params := url.Values{}
params.Add("token", accessToken)
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), nil, params, nil)
params.Add("token", token.AccessToken)
err := httputil.Client(ctx, http.MethodPost, p.RevokeURL, version.UserAgent(), nil, params, nil)
if err != nil && err != httputil.ErrTokenRevoked {
return err
}
@ -127,95 +123,14 @@ func (p *GoogleProvider) GetSignInURL(state string) string {
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "select_account consent"))
}
// Authenticate creates an identity session with google from a authorization code, and follows up
// call to the admin/group api to check what groups the user is in.
func (p *GoogleProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
oauth2Token, err := p.oauth.Exchange(ctx, code)
if err != nil {
return nil, fmt.Errorf("identity/google: token exchange failed %v", err)
}
// id_token is a JWT that contains identity information about the user
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("identity/google: response did not contain an id_token")
}
session, err := p.IDTokenToSession(ctx, rawIDToken)
if err != nil {
return nil, err
}
session.AccessToken = oauth2Token.AccessToken
session.RefreshToken = oauth2Token.RefreshToken
return session, nil
}
// Refresh renews a user's session using an oidc refresh token withoutreprompting the user.
// Group membership is also refreshed.
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
func (p *GoogleProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
if s.RefreshToken == "" {
return nil, errors.New("identity: missing refresh token")
}
t := oauth2.Token{RefreshToken: s.RefreshToken}
newToken, err := p.oauth.TokenSource(ctx, &t).Token()
if err != nil {
log.Error().Err(err).Msg("identity: refresh failed")
return nil, err
}
// id_token contains claims about the authenticated user
rawIDToken, ok := newToken.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("identity/google: response did not contain an id_token")
}
newSession, err := p.IDTokenToSession(ctx, rawIDToken)
if err != nil {
return nil, err
}
newSession.AccessToken = newToken.AccessToken
newSession.RefreshToken = s.RefreshToken
return newSession, nil
}
// IDTokenToSession takes an identity provider issued JWT as input ('id_token')
// and returns a session state. The provided token's audience ('aud') must
// match Pomerium's client_id.
func (p *GoogleProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.State, error) {
idToken, err := p.verifier.Verify(ctx, rawIDToken)
if err != nil {
return nil, fmt.Errorf("identity/google: could not verify id_token %v", err)
}
var claims struct {
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
}
// parse claims from the raw, encoded jwt token
if err := idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("identity/google: failed to parse id_token claims %v", err)
}
// google requires additional call to retrieve groups.
groups, err := p.UserGroups(ctx, claims.Email)
if err != nil {
return nil, fmt.Errorf("identity/google: could not retrieve groups %v", err)
}
return &sessions.State{
IDToken: rawIDToken,
RefreshDeadline: idToken.Expiry.Truncate(time.Second),
Email: claims.Email,
User: idToken.Subject,
Groups: groups,
}, nil
}
// UserGroups returns a slice of group names a given user is in
// NOTE: groups via Directory API is limited to 1 QPS!
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
// https://developers.google.com/admin-sdk/directory/v1/limits
func (p *GoogleProvider) UserGroups(ctx context.Context, user string) ([]string, error) {
func (p *GoogleProvider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
var groups []string
if p.apiClient != nil {
req := p.apiClient.Groups.List().UserKey(user).MaxResults(100)
req := p.apiClient.Groups.List().UserKey(s.Subject).MaxResults(100)
resp, err := req.Do()
if err != nil {
return nil, fmt.Errorf("identity/google: group api request failed %v", err)
@ -226,57 +141,3 @@ func (p *GoogleProvider) UserGroups(ctx context.Context, user string) ([]string,
}
return groups, nil
}
// JWTConfigFromJSON uses a Google Developers service account JSON key file to read
// the credentials that authorize and authenticate the requests.
// Create a service account on "Credentials" for your project at
// https://console.developers.google.com to download a JSON key file.
func JWTConfigFromJSON(jsonKey []byte, scope ...string) (*jwt.Config, error) {
var f credentialsFile
if err := json.Unmarshal(jsonKey, &f); err != nil {
return nil, err
}
if f.Type != "service_account" {
return nil, fmt.Errorf("identity/google: 'type' field is %q (expected %q)", f.Type, "service_account")
}
// Service account must impersonate a user : https://stackoverflow.com/a/48601364
if f.ImpersonateUser == "" {
return nil, errors.New("identity/google: impersonate_user not found in json config")
}
scope = append([]string(nil), scope...) // copy
return f.jwtConfig(scope), nil
}
// credentialsFile is the unmarshalled representation of a credentials file.
type credentialsFile struct {
Type string `json:"type"` // serviceAccountKey or userCredentialsKey
// Service account must impersonate a user
ImpersonateUser string `json:"impersonate_user"`
// Service Account fields
ClientEmail string `json:"client_email"`
PrivateKeyID string `json:"private_key_id"`
PrivateKey string `json:"private_key"`
TokenURL string `json:"token_uri"`
ProjectID string `json:"project_id"`
// User Credential fields
ClientSecret string `json:"client_secret"`
ClientID string `json:"client_id"`
RefreshToken string `json:"refresh_token"`
}
func (f *credentialsFile) jwtConfig(scopes []string) *jwt.Config {
cfg := &jwt.Config{
Subject: f.ImpersonateUser,
Email: f.ClientEmail,
PrivateKey: []byte(f.PrivateKey),
PrivateKeyID: f.PrivateKeyID,
Scopes: scopes,
TokenURL: f.TokenURL,
}
if cfg.TokenURL == "" {
cfg.TokenURL = JWTTokenURL
}
return cfg
}

View file

@ -27,7 +27,7 @@ const defaultAzureGroupURL = "https://graph.microsoft.com/v1.0/me/memberOf"
type AzureProvider struct {
*Provider
// non-standard oidc fields
RevokeURL *url.URL
RevokeURL string `json:"end_session_endpoint"`
}
// NewAzureProvider returns a new AzureProvider and sets the provider url endpoints.
@ -54,84 +54,22 @@ func NewAzureProvider(p *Provider) (*AzureProvider, error) {
Scopes: p.Scopes,
}
azureProvider := &AzureProvider{
Provider: p,
}
// azure has a "end session endpoint"
var claims struct {
RevokeURL string `json:"end_session_endpoint"`
}
if err := p.provider.Claims(&claims); err != nil {
return nil, err
}
azureProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
if err != nil {
azureProvider := &AzureProvider{Provider: p}
if err := p.provider.Claims(&azureProvider); err != nil {
return nil, err
}
p.UserGroupFn = azureProvider.UserGroups
return azureProvider, nil
}
// Authenticate creates an identity session with azure from a authorization code, and follows up
// call to the groups api to check what groups the user is in.
func (p *AzureProvider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
// convert authorization code into a token
oauth2Token, err := p.oauth.Exchange(ctx, code)
if err != nil {
return nil, fmt.Errorf("identity/microsoft: token exchange failed %v", err)
}
// id_token contains claims about the authenticated user
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("identity/microsoft: response did not contain an id_token")
}
// Parse and verify ID Token payload.
session, err := p.IDTokenToSession(ctx, rawIDToken)
if err != nil {
return nil, fmt.Errorf("identity/microsoft: could not verify id_token %v", err)
}
session.AccessToken = oauth2Token.AccessToken
session.RefreshToken = oauth2Token.RefreshToken
session.Groups, err = p.UserGroups(ctx, session.AccessToken)
if err != nil {
return nil, fmt.Errorf("identity/microsoft: could not retrieve groups %v", err)
}
return session, nil
}
// IDTokenToSession takes an identity provider issued JWT as input ('id_token')
// and returns a session state. The provided token's audience ('aud') must
// match Pomerium's client_id.
func (p *AzureProvider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.State, error) {
idToken, err := p.verifier.Verify(ctx, rawIDToken)
if err != nil {
return nil, fmt.Errorf("identity/microsoft: could not verify id_token %v", err)
}
var claims struct {
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
}
// parse claims from the raw, encoded jwt token
if err := idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("identity/microsoft: failed to parse id_token claims %v", err)
}
return &sessions.State{
IDToken: rawIDToken,
RefreshDeadline: idToken.Expiry.Truncate(time.Second),
Email: claims.Email,
User: idToken.Subject,
}, nil
}
// Revoke revokes the access token a given session state.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-protocols-oidc#send-a-sign-out-request
func (p *AzureProvider) Revoke(token string) error {
func (p *AzureProvider) Revoke(ctx context.Context, token *oauth2.Token) error {
params := url.Values{}
params.Add("token", token)
err := httputil.Client(http.MethodPost, p.RevokeURL.String(), version.UserAgent(), nil, params, nil)
params.Add("token", token.AccessToken)
err := httputil.Client(ctx, http.MethodPost, p.RevokeURL, version.UserAgent(), nil, params, nil)
if err != nil && err != httputil.ErrTokenRevoked {
return err
}
@ -143,34 +81,14 @@ func (p *AzureProvider) GetSignInURL(state string) string {
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "select_account"))
}
// Refresh renews a user's session using an oid refresh token without reprompting the user.
// Group membership is also refreshed.
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
func (p *AzureProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
if s.RefreshToken == "" {
return nil, errors.New("identity/microsoft: missing refresh token")
}
t := oauth2.Token{RefreshToken: s.RefreshToken}
newToken, err := p.oauth.TokenSource(ctx, &t).Token()
if err != nil {
log.Error().Err(err).Msg("identity/microsoft: refresh failed")
return nil, err
}
s.AccessToken = newToken.AccessToken
s.RefreshDeadline = newToken.Expiry.Truncate(time.Second)
s.Groups, err = p.UserGroups(ctx, s.AccessToken)
if err != nil {
log.Error().Err(err).Msg("identity/microsoft: refresh failed")
return nil, err
}
return s, nil
}
// UserGroups returns a slice of group names a given user is in.
// `Directory.Read.All` is required.
// https://docs.microsoft.com/en-us/graph/api/resources/directoryobject?view=graph-rest-1.0
// https://docs.microsoft.com/en-us/graph/api/user-list-memberof?view=graph-rest-1.0
func (p *AzureProvider) UserGroups(ctx context.Context, accessToken string) ([]string, error) {
func (p *AzureProvider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
if s == nil || s.AccessToken == nil {
return nil, errors.New("identity/azure: session cannot be nil")
}
var response struct {
Groups []struct {
ID string `json:"id"`
@ -180,15 +98,15 @@ func (p *AzureProvider) UserGroups(ctx context.Context, accessToken string) ([]s
GroupTypes []string `json:"groupTypes,omitempty"`
} `json:"value"`
}
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", accessToken)}
err := httputil.Client(http.MethodGet, defaultAzureGroupURL, version.UserAgent(), headers, nil, &response)
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)}
err := httputil.Client(ctx, http.MethodGet, defaultAzureGroupURL, version.UserAgent(), headers, nil, &response)
if err != nil {
return nil, err
}
var groups []string
for _, group := range response.Groups {
log.Debug().Str("DisplayName", group.DisplayName).Str("ID", group.ID).Msg("identity/microsoft: group")
groups = append(groups, group.DisplayName)
groups = append(groups, group.ID)
}
return groups, nil
}

View file

@ -3,6 +3,8 @@ package identity // import "github.com/pomerium/pomerium/internal/identity"
import (
"context"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/sessions"
)
@ -10,11 +12,7 @@ import (
type MockProvider struct {
AuthenticateResponse sessions.State
AuthenticateError error
IDTokenToSessionResponse sessions.State
IDTokenToSessionError error
ValidateResponse bool
ValidateError error
RefreshResponse *sessions.State
RefreshResponse sessions.State
RefreshError error
RevokeError error
GetSignInURLResponse string
@ -25,23 +23,13 @@ func (mp MockProvider) Authenticate(ctx context.Context, code string) (*sessions
return &mp.AuthenticateResponse, mp.AuthenticateError
}
// IDTokenToSession is a mocked providers function.
func (mp MockProvider) IDTokenToSession(ctx context.Context, code string) (*sessions.State, error) {
return &mp.IDTokenToSessionResponse, mp.IDTokenToSessionError
}
// Validate is a mocked providers function.
func (mp MockProvider) Validate(ctx context.Context, s string) (bool, error) {
return mp.ValidateResponse, mp.ValidateError
}
// Refresh is a mocked providers function.
func (mp MockProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
return mp.RefreshResponse, mp.RefreshError
return &mp.RefreshResponse, mp.RefreshError
}
// Revoke is a mocked providers function.
func (mp MockProvider) Revoke(s string) error {
func (mp MockProvider) Revoke(ctx context.Context, s *oauth2.Token) error {
return mp.RevokeError
}

View file

@ -2,14 +2,9 @@ package identity // import "github.com/pomerium/pomerium/internal/identity"
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
oidc "github.com/pomerium/go-oidc"
"golang.org/x/oauth2"
@ -17,6 +12,7 @@ import (
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/internal/version"
)
@ -26,7 +22,8 @@ import (
type OktaProvider struct {
*Provider
RevokeURL *url.URL
RevokeURL string `json:"revocation_endpoint"`
userAPI *url.URL
}
// NewOktaProvider creates a new instance of Okta as an identity provider.
@ -53,80 +50,62 @@ func NewOktaProvider(p *Provider) (*OktaProvider, error) {
}
// okta supports a revocation endpoint
var claims struct {
RevokeURL string `json:"revocation_endpoint"`
}
if err := p.provider.Claims(&claims); err != nil {
oktaProvider := OktaProvider{Provider: p}
if err := p.provider.Claims(&oktaProvider); err != nil {
return nil, err
}
oktaProvider := OktaProvider{Provider: p}
oktaProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
if p.ServiceAccount != "" {
p.UserGroupFn = oktaProvider.UserGroups
userAPI, err := urlutil.ParseAndValidateURL(p.ProviderURL)
if err != nil {
return nil, err
}
userAPI.Path = "/api/v1/users/"
oktaProvider.userAPI = userAPI
} else {
log.Warn().Msg("identity/okta: api token provided, cannot retrieve groups")
}
return &oktaProvider, nil
}
// Revoke revokes the access token a given session state.
// https://developer.okta.com/docs/api/resources/oidc#revoke
func (p *OktaProvider) Revoke(token string) error {
func (p *OktaProvider) Revoke(ctx context.Context, token *oauth2.Token) error {
params := url.Values{}
params.Add("client_id", p.ClientID)
params.Add("client_secret", p.ClientSecret)
params.Add("token", token)
params.Add("token", token.AccessToken)
params.Add("token_type_hint", "refresh_token")
err := httputil.Client(http.MethodPost, p.RevokeURL.String(), version.UserAgent(), nil, params, nil)
err := httputil.Client(ctx, http.MethodPost, p.RevokeURL, version.UserAgent(), nil, params, nil)
if err != nil && err != httputil.ErrTokenRevoked {
return err
}
return nil
}
type accessToken struct {
Subject string `json:"sub"`
Groups []string `json:"groups"`
}
// Refresh renews a user's session using an oid refresh token without reprompting the user.
// Group membership is also refreshed. If configured properly, Okta is we can configure the access token
// to include group membership claims which allows us to avoid a follow up oauth2 call.
func (p *OktaProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
if s.RefreshToken == "" {
return nil, errors.New("identity/okta: missing refresh token")
// UserGroups fetches the groups of which the user is a member
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
func (p *OktaProvider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
var response []struct {
ID string `json:"id"`
Profile struct {
Name string `json:"name"`
Description string `json:"description"`
} `json:"profile"`
}
t := oauth2.Token{RefreshToken: s.RefreshToken}
newToken, err := p.oauth.TokenSource(ctx, &t).Token()
headers := map[string]string{"Authorization": fmt.Sprintf("SSWS %s", p.ServiceAccount)}
err := httputil.Client(ctx, http.MethodGet, fmt.Sprintf("%s/%s/groups", p.userAPI.String(), s.Subject), version.UserAgent(), headers, nil, &response)
if err != nil {
log.Error().Err(err).Msg("identity/okta: refresh failed")
return nil, err
}
payload, err := parseJWT(newToken.AccessToken)
if err != nil {
return nil, fmt.Errorf("identity/okta: malformed access token jwt: %v", err)
var groups []string
for _, group := range response {
log.Debug().Interface("group", group).Msg("identity/okta: group")
groups = append(groups, group.ID)
}
var token accessToken
if err := json.Unmarshal(payload, &token); err != nil {
return nil, fmt.Errorf("identity/okta: failed to unmarshal access token claims: %v", err)
}
if len(token.Groups) != 0 {
s.Groups = token.Groups
}
s.AccessToken = newToken.AccessToken
s.RefreshDeadline = newToken.Expiry.Truncate(time.Second)
return s, nil
}
func parseJWT(p string) ([]byte, error) {
parts := strings.Split(p, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
}
return payload, nil
return groups, nil
}

View file

@ -12,23 +12,22 @@ import (
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version"
)
const defaultOneLoginProviderURL = "https://openid-connect.onelogin.com/oidc"
const defaultOneloginGroupURL = "https://openid-connect.onelogin.com/oidc/me"
// OneLoginProvider provides a standard, OpenID Connect implementation
// of an authorization identity provider.
type OneLoginProvider struct {
*Provider
// non-standard oidc fields
RevokeURL *url.URL
AdminCreds *credentialsFile
RevokeURL string `json:"revocation_endpoint"`
}
const defaultOneLoginProviderURL = "https://openid-connect.onelogin.com/oidc"
// NewOneLoginProvider creates a new instance of an OpenID Connect provider.
func NewOneLoginProvider(p *Provider) (*OneLoginProvider, error) {
ctx := context.Background()
@ -52,72 +51,38 @@ func NewOneLoginProvider(p *Provider) (*OneLoginProvider, error) {
Scopes: p.Scopes,
}
// okta supports a revocation endpoint
var claims struct {
RevokeURL string `json:"revocation_endpoint"`
}
if err := p.provider.Claims(&claims); err != nil {
return nil, err
}
OneLoginProvider := OneLoginProvider{Provider: p}
olProvider := OneLoginProvider{Provider: p}
OneLoginProvider.RevokeURL, err = url.Parse(claims.RevokeURL)
if err != nil {
if err := p.provider.Claims(&olProvider); err != nil {
return nil, err
}
return &OneLoginProvider, nil
p.UserGroupFn = olProvider.UserGroups
return &olProvider, nil
}
// Revoke revokes the access token a given session state.
// https://developers.onelogin.com/openid-connect/api/revoke-session
func (p *OneLoginProvider) Revoke(token string) error {
func (p *OneLoginProvider) Revoke(ctx context.Context, token *oauth2.Token) error {
params := url.Values{}
params.Add("client_id", p.ClientID)
params.Add("client_secret", p.ClientSecret)
params.Add("token", token)
params.Add("token", token.AccessToken)
params.Add("token_type_hint", "access_token")
err := httputil.Client("POST", p.RevokeURL.String(), version.UserAgent(), nil, params, nil)
err := httputil.Client(ctx, http.MethodPost, p.RevokeURL, version.UserAgent(), nil, params, nil)
if err != nil && err != httputil.ErrTokenRevoked {
log.Error().Err(err).Msg("authenticate/providers: failed to revoke session")
return err
return fmt.Errorf("identity/onelogin: revocation error %w", err)
}
return nil
}
// GetSignInURL returns the sign in url with typical oauth parameters
func (p *OneLoginProvider) GetSignInURL(state string) string {
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline)
}
// Refresh renews a user's session using an oid refresh token without reprompting the user.
// Group membership is also refreshed.
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
func (p *OneLoginProvider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
if s.RefreshToken == "" {
return nil, errors.New("identity/microsoft: missing refresh token")
}
t := oauth2.Token{RefreshToken: s.RefreshToken}
newToken, err := p.oauth.TokenSource(ctx, &t).Token()
if err != nil {
log.Error().Err(err).Msg("identity/microsoft: refresh failed")
return nil, err
}
s.AccessToken = newToken.AccessToken
s.RefreshDeadline = newToken.Expiry.Truncate(time.Second)
s.Groups, err = p.UserGroups(ctx, s.AccessToken)
if err != nil {
log.Error().Err(err).Msg("identity/microsoft: refresh failed")
return nil, err
}
return s, nil
}
const defaultOneloginGroupURL = "https://openid-connect.onelogin.com/oidc/me"
// UserGroups returns a slice of group names a given user is in.
// https://developers.onelogin.com/openid-connect/api/user-info
func (p *OneLoginProvider) UserGroups(ctx context.Context, accessToken string) ([]string, error) {
func (p *OneLoginProvider) UserGroups(ctx context.Context, s *sessions.State) ([]string, error) {
if s == nil || s.AccessToken == nil {
return nil, errors.New("identity/onelogin: session cannot be nil")
}
var response struct {
User string `json:"sub"`
Email string `json:"email"`
@ -128,15 +93,10 @@ func (p *OneLoginProvider) UserGroups(ctx context.Context, accessToken string) (
FamilyName string `json:"family_name"`
Groups []string `json:"groups"`
}
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", accessToken)}
err := httputil.Client(http.MethodGet, defaultOneloginGroupURL, version.UserAgent(), headers, nil, &response)
headers := map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.AccessToken.AccessToken)}
err := httputil.Client(ctx, http.MethodGet, defaultOneloginGroupURL, version.UserAgent(), headers, nil, &response)
if err != nil {
return nil, err
}
var groups []string
for _, group := range response.Groups {
log.Debug().Str("ID", group).Msg("identity/onelogin: group")
groups = append(groups, group)
}
return groups, nil
return response.Groups, nil
}

View file

@ -7,11 +7,8 @@ import (
"errors"
"fmt"
"net/url"
"time"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/telemetry/trace"
oidc "github.com/pomerium/go-oidc"
"golang.org/x/oauth2"
@ -34,22 +31,13 @@ const (
// ErrMissingProviderURL is returned when an identity provider requires a provider url
// does not receive one.
var ErrMissingProviderURL = errors.New("identity: missing provider url")
// UserGrouper is an interface representing the ability to retrieve group membership information
// from an identity provider
type UserGrouper interface {
// UserGroups returns a slice of group names a given user is in
UserGroups(context.Context, string) ([]string, error)
}
var ErrMissingProviderURL = errors.New("internal/identity: missing provider url")
// Authenticator is an interface representing the ability to authenticate with an identity provider.
type Authenticator interface {
Authenticate(context.Context, string) (*sessions.State, error)
IDTokenToSession(context.Context, string) (*sessions.State, error)
Validate(context.Context, string) (bool, error)
Refresh(context.Context, *sessions.State) (*sessions.State, error)
Revoke(string) error
Revoke(context.Context, *oauth2.Token) error
GetSignInURL(state string) string
}
@ -59,8 +47,8 @@ func New(providerName string, p *Provider) (a Authenticator, err error) {
switch providerName {
case AzureProviderName:
a, err = NewAzureProvider(p)
case GitlabProviderName:
return nil, fmt.Errorf("identity: %s currently not supported", providerName)
// case GitlabProviderName:
// return nil, fmt.Errorf("internal/identity: %s currently not supported", providerName)
case GoogleProviderName:
a, err = NewGoogleProvider(p)
case OIDCProviderName:
@ -70,7 +58,7 @@ func New(providerName string, p *Provider) (a Authenticator, err error) {
case OneLoginProviderName:
a, err = NewOneLoginProvider(p)
default:
return nil, fmt.Errorf("identity: %s provider not known", providerName)
return nil, fmt.Errorf("internal/identity: %s provider not known", providerName)
}
if err != nil {
return nil, err
@ -85,13 +73,16 @@ type Provider struct {
ProviderName string
RedirectURL *url.URL
ClientID string
ClientSecret string
ProviderURL string
Scopes []string
// Some providers, such as google, require additional remote api calls to retrieve
// user details like groups. Provider is responsible for parsing.
UserGroupFn func(context.Context, *sessions.State) ([]string, error)
// ServiceAccount can be set for those providers that require additional
// credentials or tokens to do follow up API calls (e.g. Google)
ServiceAccount string
provider *oidc.Provider
@ -110,94 +101,73 @@ func (p *Provider) GetSignInURL(state string) string {
return p.oauth.AuthCodeURL(state, oauth2.AccessTypeOffline)
}
// Validate validates a given session's from it's JWT token
// The function verifies it's been signed by the provider, preforms
// any additional checks depending on the Config, and returns the payload.
//
// Validate does NOT do nonce validation.
// Validate does NOT check if revoked.
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func (p *Provider) Validate(ctx context.Context, idToken string) (bool, error) {
ctx, span := trace.StartSpan(ctx, "identity.provider.Validate")
defer span.End()
_, err := p.verifier.Verify(ctx, idToken)
if err != nil {
log.Error().Err(err).Msg("identity: failed to verify session state")
return false, err
}
return true, nil
}
// IDTokenToSession takes an identity provider issued JWT as input ('id_token')
// and returns a session state. The provided token's audience ('aud') must
// match Pomerium's client_id.
func (p *Provider) IDTokenToSession(ctx context.Context, rawIDToken string) (*sessions.State, error) {
idToken, err := p.verifier.Verify(ctx, rawIDToken)
if err != nil {
return nil, fmt.Errorf("identity: could not verify id_token: %v", err)
}
// extract additional, non-oidc standard claims
var claims struct {
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Groups []string `json:"groups"`
}
if err := idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("identity: failed to parse id_token claims: %v", err)
}
return &sessions.State{
IDToken: rawIDToken,
User: idToken.Subject,
RefreshDeadline: idToken.Expiry.Truncate(time.Second),
Email: claims.Email,
Groups: claims.Groups,
}, nil
}
// Authenticate creates a session with an identity provider from a authorization code
// Authenticate creates an identity session with google from a authorization code, and follows up
// call to the admin/group api to check what groups the user is in.
func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.State, error) {
// exchange authorization for a oidc token
oauth2Token, err := p.oauth.Exchange(ctx, code)
if err != nil {
return nil, fmt.Errorf("identity: failed token exchange: %v", err)
return nil, fmt.Errorf("internal/identity: token exchange failed: %w", err)
}
//id_token contains claims about the authenticated user
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("token response did not contain an id_token")
}
session, err := p.IDTokenToSession(ctx, rawIDToken)
idToken, err := p.IdentityFromToken(ctx, oauth2Token)
if err != nil {
return nil, fmt.Errorf("identity: could not verify id_token: %v", err)
}
session.AccessToken = oauth2Token.AccessToken
session.RefreshToken = oauth2Token.RefreshToken
return session, nil
}
// Refresh renews a user's session using therefresh_token without reprompting
// the user. If supported, group membership is also refreshed.
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
if s.RefreshToken == "" {
return nil, errors.New("identity: missing refresh token")
}
t := oauth2.Token{RefreshToken: s.RefreshToken}
newToken, err := p.oauth.TokenSource(ctx, &t).Token()
if err != nil {
log.Error().Err(err).Msg("identity: refresh failed")
return nil, err
}
s.AccessToken = newToken.AccessToken
s.RefreshDeadline = newToken.Expiry.Truncate(time.Second)
s, err := sessions.NewStateFromTokens(idToken, oauth2Token, p.RedirectURL.Host)
if err != nil {
return nil, err
}
if p.UserGroupFn != nil {
s.Groups, err = p.UserGroupFn(ctx, s)
if err != nil {
return nil, fmt.Errorf("internal/identity: could not retrieve groups %w", err)
}
}
return s, nil
}
// Refresh renews a user's session using an oidc refresh token withoutreprompting the user.
// Group membership is also refreshed.
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
func (p *Provider) Refresh(ctx context.Context, s *sessions.State) (*sessions.State, error) {
if s.AccessToken == nil || s.AccessToken.RefreshToken == "" {
return nil, errors.New("internal/identity: missing refresh token")
}
t := oauth2.Token{RefreshToken: s.AccessToken.RefreshToken}
oauthToken, err := p.oauth.TokenSource(ctx, &t).Token()
if err != nil {
return nil, fmt.Errorf("internal/identity: refresh failed %w", err)
}
idToken, err := p.IdentityFromToken(ctx, oauthToken)
if err != nil {
return nil, err
}
if err := s.UpdateState(idToken, oauthToken); err != nil {
return nil, fmt.Errorf("internal/identity: state update failed %w", err)
}
if p.UserGroupFn != nil {
s.Groups, err = p.UserGroupFn(ctx, s)
if err != nil {
return nil, fmt.Errorf("internal/identity: could not retrieve groups %w", err)
}
}
return s, nil
}
// IdentityFromToken takes an identity provider issued JWT as input ('id_token')
// and returns a session state. The provided token's audience ('aud') must
// match Pomerium's client_id.
func (p *Provider) IdentityFromToken(ctx context.Context, t *oauth2.Token) (*oidc.IDToken, error) {
rawIDToken, ok := t.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("internal/identity: id_token not found")
}
return p.verifier.Verify(ctx, rawIDToken)
}
// Revoke enables a user to revoke her token. If the identity provider supports revocation
// the endpoint is available, otherwise an error is thrown.
func (p *Provider) Revoke(token string) error {
return fmt.Errorf("identity: revoke not implemented by %s", p.ProviderName)
func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error {
return fmt.Errorf("internal/identity: revoke not implemented by %s", p.ProviderName)
}

View file

@ -4,7 +4,6 @@ import (
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strings"
"time"
@ -12,8 +11,6 @@ import (
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil"
"golang.org/x/net/publicsuffix"
)
// SetHeaders sets a map of response headers.
@ -30,72 +27,6 @@ func SetHeaders(headers map[string]string) func(next http.Handler) http.Handler
}
}
// ValidateClientSecret checks the request header for the client secret and returns
// an error if it does not match the proxy client secret
func ValidateClientSecret(sharedSecret string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateClientSecret")
defer span.End()
if err := r.ParseForm(); err != nil {
httputil.ErrorResponse(w, r, httputil.Error("couldn't parse form", http.StatusBadRequest, err))
return
}
clientSecret := r.Form.Get("shared_secret")
// check the request header for the client secret
if clientSecret == "" {
clientSecret = r.Header.Get("X-Client-Secret")
}
if clientSecret != sharedSecret {
httputil.ErrorResponse(w, r, httputil.Error("client secret mismatch", http.StatusBadRequest, nil))
return
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// ValidateRedirectURI checks the redirect uri in the query parameters and ensures that
// the its domain is in the list of proxy root domains.
func ValidateRedirectURI(rootDomain *url.URL) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateRedirectURI")
defer span.End()
err := r.ParseForm()
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("couldn't parse form", http.StatusBadRequest, err))
return
}
redirectURI, err := urlutil.ParseAndValidateURL(r.Form.Get("redirect_uri"))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("bad redirect_uri", http.StatusBadRequest, err))
return
}
if !SameDomain(redirectURI, rootDomain) {
httputil.ErrorResponse(w, r, httputil.Error("redirect uri and root domain differ", http.StatusBadRequest, nil))
return
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// SameDomain checks to see if two URLs share the top level domain (TLD Plus One).
func SameDomain(u, j *url.URL) bool {
a, err := publicsuffix.EffectiveTLDPlusOne(u.Hostname())
if err != nil {
return false
}
b, err := publicsuffix.EffectiveTLDPlusOne(j.Hostname())
if err != nil {
return false
}
return a == b
}
// ValidateSignature ensures the request is valid and has been signed with
// the correspdoning client secret key
func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler {

View file

@ -17,36 +17,6 @@ func hmacHelperFunc(rawRedirect string, timestamp time.Time, secret string) []by
return cryptutil.GenerateHMAC(data, secret)
}
func Test_SameDomain(t *testing.T) {
t.Parallel()
tests := []struct {
name string
uri string
rootDomains string
want bool
}{
{"good url redirect", "https://example.com/redirect", "https://example.com", true},
{"good multilevel", "https://httpbin.a.corp.example.com", "https://auth.b.corp.example.com", true},
{"good complex tld", "https://httpbin.a.corp.example.co.uk", "https://auth.b.corp.example.co.uk", true},
{"bad complex tld", "https://httpbin.a.corp.notexample.co.uk", "https://auth.b.corp.example.co.uk", false},
{"simple sub", "https://auth.example.com", "https://test.example.com", true},
{"bad domain", "https://auth.example.com/redirect", "https://test.notexample.com", false},
{"malformed url", "^example.com/redirect", "https://notexample.com", false},
{"empty domain list", "https://example.com/redirect", ".com", false},
{"empty domain", "https://example.com/redirect", "", false},
{"empty url", "", "example.com", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u, _ := url.Parse(tt.uri)
j, _ := url.Parse(tt.rootDomains)
if got := SameDomain(u, j); got != tt.want {
t.Errorf("SameDomain() = %v, want %v", got, tt.want)
}
})
}
}
func Test_ValidSignature(t *testing.T) {
t.Parallel()
goodURL := "https://example.com/redirect"
@ -109,87 +79,6 @@ func TestSetHeaders(t *testing.T) {
}
}
func TestValidateRedirectURI(t *testing.T) {
t.Parallel()
tests := []struct {
name string
rootDomain string
redirectURI string
status int
}{
{"simple", "https://auth.google.com", "redirect_uri=https://b.google.com", http.StatusOK},
{"deep ok", "https://a.some.really.deep.sub.domain.google.com", "redirect_uri=https://b.some.really.deep.sub.domain.google.com", http.StatusOK},
{"bad match", "https://auth.aol.com", "redirect_uri=https://test.google.com", http.StatusBadRequest},
{"bad simple", "https://auth.corp.aol.com", "redirect_uri=https://test.corp.google.com", http.StatusBadRequest},
{"deep bad", "https://a.some.really.deep.sub.domain.scroogle.com", "redirect_uri=https://b.some.really.deep.sub.domain.google.com", http.StatusBadRequest},
{"with cname", "https://auth.google.com", "redirect_uri=https://www.google.com", http.StatusOK},
{"with path", "https://auth.google.com", "redirect_uri=https://www.google.com/path", http.StatusOK},
{"http mistmatch", "https://auth.google.com", "redirect_uri=http://www.google.com/path", http.StatusOK},
{"http", "http://auth.google.com", "redirect_uri=http://www.google.com/path", http.StatusOK},
{"ip", "http://1.1.1.1", "redirect_uri=http://8.8.8.8", http.StatusBadRequest},
{"redirect get param not set", "https://auth.google.com", "not_redirect_uri!=https://b.google.com", http.StatusBadRequest},
{"malformed, invalid get params", "https://auth.google.com", "redirect_uri=https://%zzzzz", http.StatusBadRequest},
{"malformed, invalid url", "https://auth.google.com", "redirect_uri=https://accounts.google.^", http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := &http.Request{
Method: http.MethodGet,
URL: &url.URL{RawQuery: tt.redirectURI},
}
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hi"))
})
rr := httptest.NewRecorder()
u, _ := url.Parse(tt.rootDomain)
handler := ValidateRedirectURI(u)(testHandler)
handler.ServeHTTP(rr, req)
if rr.Code != tt.status {
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
t.Errorf("%s", rr.Body)
}
})
}
}
func TestValidateClientSecret(t *testing.T) {
t.Parallel()
tests := []struct {
name string
sharedSecret string
clientGetValue string
clientHeaderValue string
status int
}{
{"simple", "secret", "secret", "secret", http.StatusOK},
{"missing get param, valid header", "secret", "", "secret", http.StatusOK},
{"missing both", "secret", "", "", http.StatusBadRequest},
{"simple bad", "bad-secret", "secret", "", http.StatusBadRequest},
{"malformed, invalid hex digits", "secret", "%zzzzz", "", http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := &http.Request{
Method: http.MethodGet,
Header: http.Header{"X-Client-Secret": []string{tt.clientHeaderValue}},
URL: &url.URL{RawQuery: fmt.Sprintf("shared_secret=%s", tt.clientGetValue)},
}
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hi"))
})
rr := httptest.NewRecorder()
handler := ValidateClientSecret(tt.sharedSecret)(testHandler)
handler.ServeHTTP(rr, req)
if rr.Code != tt.status {
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
t.Errorf("%s", rr.Body)
}
})
}
}
func TestValidateSignature(t *testing.T) {
t.Parallel()
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="

View file

@ -1,13 +1,12 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"errors"
"fmt"
"net"
"net/http"
"strings"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
)
const (
@ -28,49 +27,65 @@ const (
// CookieStore implements the session store interface for session cookies.
type CookieStore struct {
Name string
CookieDomain string
CookieExpire time.Duration
CookieHTTPOnly bool
CookieSecure bool
Encoder cryptutil.SecureEncoder
Domain string
Expire time.Duration
HTTPOnly bool
Secure bool
encoder Marshaler
decoder Unmarshaler
}
// CookieStoreOptions holds options for CookieStore
type CookieStoreOptions struct {
// CookieOptions holds options for CookieStore
type CookieOptions struct {
Name string
CookieDomain string
CookieExpire time.Duration
CookieHTTPOnly bool
CookieSecure bool
Encoder cryptutil.SecureEncoder
Domain string
Expire time.Duration
HTTPOnly bool
Secure bool
}
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
func NewCookieStore(opts *CookieOptions, encoder Encoder) (*CookieStore, error) {
if opts.Name == "" {
return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
}
if opts.Encoder == nil {
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
if encoder == nil {
return nil, fmt.Errorf("internal/sessions: decoder cannot be nil")
}
return &CookieStore{
Name: opts.Name,
CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly,
CookieDomain: opts.CookieDomain,
CookieExpire: opts.CookieExpire,
Encoder: opts.Encoder,
Secure: opts.Secure,
HTTPOnly: opts.HTTPOnly,
Domain: opts.Domain,
Expire: opts.Expire,
encoder: encoder,
decoder: encoder,
}, nil
}
// NewCookieLoader returns a new session with ciphers for each of the cookie secrets
func NewCookieLoader(opts *CookieOptions, decoder Unmarshaler) (*CookieStore, error) {
if opts.Name == "" {
return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
}
if decoder == nil {
return nil, fmt.Errorf("internal/sessions: decoder cannot be nil")
}
return &CookieStore{
Name: opts.Name,
Secure: opts.Secure,
HTTPOnly: opts.HTTPOnly,
Domain: opts.Domain,
Expire: opts.Expire,
decoder: decoder,
}, nil
}
func (cs *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
domain := req.Host
if cs.CookieDomain != "" {
domain = cs.CookieDomain
} else {
domain = ParentSubdomain(domain)
if cs.Domain != "" {
domain = cs.Domain
}
if h, _, err := net.SplitHostPort(domain); err == nil {
@ -81,8 +96,8 @@ func (cs *CookieStore) makeCookie(req *http.Request, name string, value string,
Value: value,
Path: "/",
Domain: domain,
HttpOnly: cs.CookieHTTPOnly,
Secure: cs.CookieSecure,
HttpOnly: cs.HTTPOnly,
Secure: cs.Secure,
}
// only set an expiration if we want one, otherwise default to non perm session based
if expiration != 0 {
@ -98,23 +113,38 @@ func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
// LoadSession returns a State from the cookie in the request.
func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) {
cipherText := loadChunkedCookie(req, cs.Name)
if cipherText == "" {
data := loadChunkedCookie(req, cs.Name)
if data == "" {
return nil, ErrNoSessionFound
}
session, err := UnmarshalSession(cipherText, cs.Encoder)
var session State
err := cs.decoder.Unmarshal([]byte(data), &session)
if err != nil {
return nil, ErrMalformed
}
return session, nil
return &session, err
}
// SaveSession saves a session state to a request sessions.
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error {
value, err := MarshalSession(s, cs.Encoder)
// SaveSession saves a session state to a request's cookie store.
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, x interface{}) error {
var value string
if cs.encoder != nil {
data, err := cs.encoder.Marshal(x)
if err != nil {
return err
}
value = string(data)
} else {
switch v := x.(type) {
case []byte:
value = string(v)
case string:
value = v
default:
return errors.New("internal/sessions: cannot save non-string type")
}
}
cs.setSessionCookie(w, req, value)
return nil
}
@ -125,7 +155,7 @@ func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expira
}
func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
cs.setCookie(w, cs.makeSessionCookie(req, val, cs.CookieExpire, time.Now()))
cs.setCookie(w, cs.makeSessionCookie(req, val, cs.Expire, time.Now()))
}
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
@ -153,11 +183,11 @@ func loadChunkedCookie(r *http.Request, cookieName string) string {
if err != nil {
return ""
}
cipherText := c.Value
data := c.Value
// if the first byte is our canary byte, we need to handle the multipart bit
if []byte(c.Value)[0] == ChunkedCanaryByte {
var b strings.Builder
fmt.Fprintf(&b, "%s", cipherText[1:])
fmt.Fprintf(&b, "%s", data[1:])
for i := 1; i <= MaxNumChunks; i++ {
next, err := r.Cookie(fmt.Sprintf("%s_%d", cookieName, i))
if err != nil {
@ -165,9 +195,9 @@ func loadChunkedCookie(r *http.Request, cookieName string) string {
}
fmt.Fprintf(&b, "%s", next.Value)
}
cipherText = b.String()
data = b.String()
}
return cipherText
return data
}
func chunk(s string, size int) []string {

View file

@ -6,86 +6,43 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
)
type MockEncoder struct{}
func (a MockEncoder) Marshal(s interface{}) (string, error) { return "", errors.New("error") }
func (a MockEncoder) Unmarshal(s string, i interface{}) error {
if s == "unmarshal error" || s == "error" {
return errors.New("error")
}
return nil
}
func TestNewCookieStore(t *testing.T) {
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
encoder := cryptutil.NewSecureJSONEncoder(cipher)
encoder := ecjson.New(cipher)
tests := []struct {
name string
opts *CookieStoreOptions
opts *CookieOptions
encoder Encoder
want *CookieStore
wantErr bool
}{
{"good",
&CookieStoreOptions{
Name: "_cookie",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
Encoder: encoder,
},
&CookieStore{
Name: "_cookie",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
Encoder: encoder,
},
false},
{"missing name",
&CookieStoreOptions{
Name: "",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
Encoder: encoder,
},
nil,
true},
{"missing cipher",
&CookieStoreOptions{
Name: "_pomerium",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
Encoder: nil,
},
nil,
true},
{"good", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &CookieStore{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
{"missing name", &CookieOptions{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
{"missing encoder", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewCookieStore(tt.opts)
got, err := NewCookieStore(tt.opts, tt.encoder)
if (err != nil) != tt.wantErr {
t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr)
return
}
cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(cryptutil.SecureJSONEncoder{}),
cmpopts.IgnoreUnexported(CookieStore{}),
}
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
@ -94,6 +51,40 @@ func TestNewCookieStore(t *testing.T) {
})
}
}
func TestNewCookieLoader(t *testing.T) {
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
encoder := ecjson.New(cipher)
tests := []struct {
name string
opts *CookieOptions
encoder Encoder
want *CookieStore
wantErr bool
}{
{"good", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &CookieStore{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
{"missing name", &CookieOptions{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
{"missing encoder", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewCookieLoader(tt.opts, tt.encoder)
if (err != nil) != tt.wantErr {
t.Errorf("NewCookieLoader() error = %v, wantErr %v", err, tt.wantErr)
return
}
cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(CookieStore{}),
}
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
t.Errorf("NewCookieLoader() = %s", diff)
}
})
}
}
func TestCookieStore_makeCookie(t *testing.T) {
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
@ -114,10 +105,10 @@ func TestCookieStore_makeCookie(t *testing.T) {
want *http.Cookie
wantCSRF *http.Cookie
}{
{"good", "http://httpbin.corp.pomerium.io", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
{"domains with https", "https://httpbin.corp.pomerium.io", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
{"domain with port", "http://httpbin.corp.pomerium.io:443", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
{"expiration set", "http://httpbin.corp.pomerium.io:443", "", "_pomerium", "value", 10 * time.Second, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
{"good", "http://httpbin.corp.pomerium.io", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
{"domains with https", "https://httpbin.corp.pomerium.io", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
{"domain with port", "http://httpbin.corp.pomerium.io:443", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
{"expiration set", "http://httpbin.corp.pomerium.io:443", "", "_pomerium", "value", 10 * time.Second, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
{"good", "http://httpbin.corp.pomerium.io", "pomerium.io", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
}
for _, tt := range tests {
@ -125,13 +116,14 @@ func TestCookieStore_makeCookie(t *testing.T) {
r := httptest.NewRequest("GET", tt.domain, nil)
s, err := NewCookieStore(
&CookieStoreOptions{
&CookieOptions{
Name: "_pomerium",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: tt.cookieDomain,
CookieExpire: 10 * time.Second,
Encoder: cryptutil.NewSecureJSONEncoder(cipher)})
Secure: true,
HTTPOnly: true,
Domain: tt.cookieDomain,
Expire: 10 * time.Second,
},
ecjson.New(cipher))
if err != nil {
t.Fatal(err)
}
@ -151,7 +143,6 @@ func TestCookieStore_SaveSession(t *testing.T) {
if err != nil {
t.Fatal(err)
}
cipher := cryptutil.NewSecureJSONEncoder(c)
hugeString := make([]byte, 4097)
if _, err := rand.Read(hugeString); err != nil {
@ -160,23 +151,28 @@ func TestCookieStore_SaveSession(t *testing.T) {
tests := []struct {
name string
State *State
cipher cryptutil.SecureEncoder
encoder Encoder
decoder Encoder
wantErr bool
wantLoadErr bool
}{
{"good", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
{"bad cipher", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, MockEncoder{}, true, true},
{"huge cookie", &State{AccessToken: fmt.Sprintf("%x", hugeString), RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
{"good", &State{Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
{"bad cipher", &State{Email: "user@domain.com", User: "user"}, nil, nil, true, true},
{"huge cookie", &State{Subject: fmt.Sprintf("%x", hugeString), Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
{"marshal error", &State{Email: "user@domain.com", User: "user"}, encoding.MockEncoder{MarshalError: errors.New("error")}, ecjson.New(c), true, true},
{"nil encoder cannot save non string type", &State{Email: "user@domain.com", User: "user"}, nil, ecjson.New(c), true, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &CookieStore{
Name: "_pomerium",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
Encoder: tt.cipher}
Secure: true,
HTTPOnly: true,
Domain: "pomerium.io",
Expire: 10 * time.Second,
encoder: tt.encoder,
decoder: tt.encoder,
}
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
@ -195,54 +191,19 @@ func TestCookieStore_SaveSession(t *testing.T) {
t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr)
return
}
cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(State{}),
}
if err == nil {
if diff := cmp.Diff(state, tt.State); diff != "" {
if diff := cmp.Diff(state, tt.State, cmpOpts...); diff != "" {
t.Errorf("CookieStore.LoadSession() got = %s", diff)
}
}
})
}
}
func TestMockSessionStore(t *testing.T) {
tests := []struct {
name string
mockCSRF *MockSessionStore
saveSession *State
wantLoadErr bool
wantSaveErr bool
}{
{"basic",
&MockSessionStore{
ResponseSession: "test",
Session: &State{AccessToken: "AccessToken"},
SaveError: nil,
LoadError: nil,
},
&State{AccessToken: "AccessToken"},
false,
false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ms := tt.mockCSRF
err := ms.SaveSession(nil, nil, tt.saveSession)
if (err != nil) != tt.wantSaveErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantSaveErr %v", err, tt.wantSaveErr)
return
}
got, err := ms.LoadSession(nil)
if (err != nil) != tt.wantLoadErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantLoadErr %v", err, tt.wantLoadErr)
return
}
if !reflect.DeepEqual(got, tt.mockCSRF.Session) {
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Session)
}
ms.ClearSession(nil, nil)
if ms.ResponseSession != "" {
t.Errorf("ResponseSession not empty! %s", ms.ResponseSession)
w = httptest.NewRecorder()
s.ClearSession(w, r)
x := w.Header().Get("Set-Cookie")
if !strings.Contains(x, "_pomerium=; Path=/;") {
t.Errorf(x)
}
})
}

View file

@ -3,14 +3,9 @@ package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"net/http"
"strings"
"github.com/pomerium/pomerium/internal/cryptutil"
)
const (
// defaultAuthHeader and defaultAuthType are default header name for the
// authorization bearer token header as defined in rfc2617
// https://tools.ietf.org/html/rfc6750#section-2.1
defaultAuthHeader = "Authorization"
defaultAuthType = "Bearer"
)
@ -20,41 +15,44 @@ const (
type HeaderStore struct {
authHeader string
authType string
encoder cryptutil.SecureEncoder
encoder Unmarshaler
}
// NewHeaderStore returns a new header store for loading sessions from
// authorization headers.
func NewHeaderStore(enc cryptutil.SecureEncoder) *HeaderStore {
// authorization header as defined in as defined in rfc2617
//
// NOTA BENE: While most servers do not log Authorization headers by default,
// you should ensure no other services are logging or leaking your auth headers.
func NewHeaderStore(enc Unmarshaler, headerType string) *HeaderStore {
if headerType == "" {
headerType = defaultAuthType
}
return &HeaderStore{
authHeader: defaultAuthHeader,
authType: defaultAuthType,
authType: headerType,
encoder: enc,
}
}
// LoadSession tries to retrieve the token string from the Authorization header.
//
// NOTA BENE: While most servers do not log Authorization headers by default,
// you should ensure no other services are logging or leaking your auth headers.
func (as *HeaderStore) LoadSession(r *http.Request) (*State, error) {
cipherText := as.tokenFromHeader(r)
cipherText := TokenFromHeader(r, as.authHeader, as.authType)
if cipherText == "" {
return nil, ErrNoSessionFound
}
session, err := UnmarshalSession(cipherText, as.encoder)
if err != nil {
var session State
if err := as.encoder.Unmarshal([]byte(cipherText), &session); err != nil {
return nil, ErrMalformed
}
return session, nil
return &session, nil
}
// retrieve the value of the authorization header
func (as *HeaderStore) tokenFromHeader(r *http.Request) string {
bearer := r.Header.Get(as.authHeader)
atSize := len(as.authType)
if len(bearer) > atSize && strings.EqualFold(bearer[0:atSize], as.authType) {
// TokenFromHeader retrieves the value of the authorization header from a given
// request, header key, and authentication type.
func TokenFromHeader(r *http.Request, authHeader, authType string) string {
bearer := r.Header.Get(authHeader)
atSize := len(authType)
if len(bearer) > atSize && strings.EqualFold(bearer[0:atSize], authType) {
return bearer[atSize+1:]
}
return ""

View file

@ -12,10 +12,8 @@ var (
ErrorCtxKey = &contextKey{"Error"}
)
// RetrieveSession will search for a auth session in a http request, in the order:
// 1. `pomerium_session` URI query parameter
// 2. `Authorization: BEARER` request header
// 3. Cookie `_pomerium` value
// RetrieveSession takes a slice of session loaders and tries to find a valid
// session in the order they were supplied and is added to the request's context
func RetrieveSession(s ...SessionLoader) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return retrieve(s...)(next)
@ -34,35 +32,21 @@ func retrieve(s ...SessionLoader) func(http.Handler) http.Handler {
}
}
// retrieveFromRequest extracts sessions state from the request by calling
// token find functions in the order they where provided.
func retrieveFromRequest(r *http.Request, sessions ...SessionLoader) (*State, error) {
state := new(State)
var err error
// Extract sessions state from the request by calling token find functions in
// the order they where provided. Further extraction stops if a function
// returns a non-empty string.
for _, s := range sessions {
state, err = s.LoadSession(r)
state, err := s.LoadSession(r)
if err != nil && !errors.Is(err, ErrNoSessionFound) {
// unexpected error
return nil, err
}
// break, we found a session state
if state != nil {
break
}
}
// no session found if state is still empty
if state == nil {
return nil, ErrNoSessionFound
}
if err = state.Valid(); err != nil {
// a little unusual but we want to return the expired state too
return state, err
}
if state != nil {
err := state.Verify(r.Host)
return state, err // N.B.: state is _not nil_
}
}
return state, nil
return nil, ErrNoSessionFound
}
// NewContext sets context values for the user session state and error.

View file

@ -11,6 +11,8 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"gopkg.in/square/go-jose.v2/jwt"
)
func TestNewContext(t *testing.T) {
@ -27,7 +29,7 @@ func TestNewContext(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
ctxOut := NewContext(tt.ctx, tt.t, tt.err)
stateOut, errOut := FromContext(ctxOut)
if diff := cmp.Diff(tt.t, stateOut); diff != "" {
if diff := cmp.Diff(tt.t.Email, stateOut.Email); diff != "" {
t.Errorf("NewContext() = %s", diff)
}
if diff := cmp.Diff(tt.err, errOut); diff != "" {
@ -67,56 +69,54 @@ func TestVerifier(t *testing.T) {
wantBody string
wantStatus int
}{
{"good cookie session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, true, false, false, http.StatusText(http.StatusOK), http.StatusOK},
{"expired cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth header session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, false, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth query param session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"no session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
{"good cookie session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true, false, false, http.StatusText(http.StatusOK), http.StatusOK},
{"expired cookie", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, true, false, false, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"malformed cookie", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, true, false, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth header session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false, true, false, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth header", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, true, false, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"malformed auth header", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth query param session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth query param", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, true, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"malformed auth query param", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"no session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := cryptutil.NewSecureJSONEncoder(cipher)
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
encSession, err := MarshalSession(&tt.state, encoder)
encSession, err := encoder.Marshal(&tt.state)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tt.name, "malformed") {
// add some garbage to the end of the string
encSession += cryptutil.NewBase64Key()
encSession = append(encSession, cryptutil.NewKey()...)
}
cs, err := NewCookieStore(&CookieStoreOptions{
cs, err := NewCookieStore(&CookieOptions{
Name: "_pomerium",
Encoder: encoder,
})
}, encoder)
if err != nil {
t.Fatal(err)
}
as := NewHeaderStore(encoder)
as := NewHeaderStore(encoder, "")
qp := NewQueryParamStore(encoder)
qp := NewQueryParamStore(encoder, "")
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
if tt.cookie {
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: encSession})
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: string(encSession)})
} else if tt.header {
r.Header.Set("Authorization", "Bearer "+encSession)
r.Header.Set("Authorization", "Bearer "+string(encSession))
} else if tt.param {
q := r.URL.Query()
q.Set("pomerium_session", encSession)
q.Set("pomerium_session", string(encSession))
r.URL.RawQuery = q.Encode()
}

View file

@ -23,6 +23,6 @@ func (ms MockSessionStore) LoadSession(*http.Request) (*State, error) {
}
// SaveSession returns a save error.
func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, *State) error {
func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, interface{}) error {
return ms.SaveError
}

View file

@ -0,0 +1,50 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"reflect"
"testing"
)
func TestMockSessionStore(t *testing.T) {
tests := []struct {
name string
mockCSRF *MockSessionStore
saveSession *State
wantLoadErr bool
wantSaveErr bool
}{
{"basic",
&MockSessionStore{
ResponseSession: "test",
Session: &State{Subject: "0101"},
SaveError: nil,
LoadError: nil,
},
&State{Subject: "0101"},
false,
false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ms := tt.mockCSRF
err := ms.SaveSession(nil, nil, tt.saveSession)
if (err != nil) != tt.wantSaveErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantSaveErr %v", err, tt.wantSaveErr)
return
}
got, err := ms.LoadSession(nil)
if (err != nil) != tt.wantLoadErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantLoadErr %v", err, tt.wantLoadErr)
return
}
if !reflect.DeepEqual(got, tt.mockCSRF.Session) {
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Session)
}
ms.ClearSession(nil, nil)
if ms.ResponseSession != "" {
t.Errorf("ResponseSession not empty! %s", ms.ResponseSession)
}
})
}
}

View file

@ -2,8 +2,6 @@ package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"net/http"
"github.com/pomerium/pomerium/internal/cryptutil"
)
const (
@ -14,31 +12,55 @@ const (
// query strings / query parameters.
type QueryParamStore struct {
queryParamKey string
encoder cryptutil.SecureEncoder
encoder Marshaler
decoder Unmarshaler
}
// NewQueryParamStore returns a new query param store for loading sessions from
// query strings / query parameters.
func NewQueryParamStore(enc cryptutil.SecureEncoder) *QueryParamStore {
//
// NOTA BENE: By default, most servers _DO_ log query params, the leaking or
// accidental logging of which should be considered a security issue.
func NewQueryParamStore(enc Encoder, qp string) *QueryParamStore {
if qp == "" {
qp = defaultQueryParamKey
}
return &QueryParamStore{
queryParamKey: defaultQueryParamKey,
queryParamKey: qp,
encoder: enc,
decoder: enc,
}
}
// LoadSession tries to retrieve the token string from URL query parameters.
//
// NOTA BENE: By default, most servers _DO_ log query params, the leaking or
// accidental logging of which should be considered a security issue.
func (qp *QueryParamStore) LoadSession(r *http.Request) (*State, error) {
cipherText := r.URL.Query().Get(qp.queryParamKey)
if cipherText == "" {
return nil, ErrNoSessionFound
}
session, err := UnmarshalSession(cipherText, qp.encoder)
if err != nil {
var session State
if err := qp.decoder.Unmarshal([]byte(cipherText), &session); err != nil {
return nil, ErrMalformed
}
return session, nil
return &session, nil
}
// ClearSession clears the session cookie from a request's query param key `pomerium_session`.
func (qp *QueryParamStore) ClearSession(w http.ResponseWriter, r *http.Request) {
params := r.URL.Query()
params.Del(qp.queryParamKey)
r.URL.RawQuery = params.Encode()
}
// SaveSession sets a session to a request's query param key `pomerium_session`
func (qp *QueryParamStore) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
data, err := qp.encoder.Marshal(x)
if err != nil {
return err
}
r.URL.Query().Get(qp.queryParamKey)
params := r.URL.Query()
params.Set(qp.queryParamKey, string(data))
r.URL.RawQuery = params.Encode()
return nil
}

View file

@ -0,0 +1,47 @@
package sessions
import (
"errors"
"net/http/httptest"
"net/url"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/encoding"
)
func TestNewQueryParamStore(t *testing.T) {
tests := []struct {
name string
State *State
enc Encoder
qp string
wantErr bool
wantURL *url.URL
}{
{"simple good", &State{Email: "user@domain.com", User: "user"}, encoding.MockEncoder{MarshalResponse: []byte("ok")}, "", false, &url.URL{Path: "/", RawQuery: "pomerium_session=ok"}},
{"marshall error", &State{Email: "user@domain.com", User: "user"}, encoding.MockEncoder{MarshalError: errors.New("error")}, "", true, &url.URL{Path: "/"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewQueryParamStore(tt.enc, tt.qp)
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
t.Errorf("NewQueryParamStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
if diff := cmp.Diff(r.URL, tt.wantURL); diff != "" {
t.Errorf("NewQueryParamStore() = %v", diff)
}
got.ClearSession(w, r)
if diff := cmp.Diff(r.URL, &url.URL{Path: "/"}); diff != "" {
t.Errorf("NewQueryParamStore() = %v", diff)
}
})
}
}

View file

@ -1,46 +1,147 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
oidc "github.com/pomerium/go-oidc"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt"
)
const (
// DefaultLeeway defines the default leeway for matching NotBefore/Expiry claims.
DefaultLeeway = 1.0 * time.Minute
)
// timeNow is time.Now but pulled out as a variable for tests.
var timeNow = time.Now
// State is our object that keeps track of a user's session state
type State struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
RefreshDeadline time.Time `json:"refresh_deadline"`
// Public claim values (as specified in RFC 7519).
Issuer string `json:"iss,omitempty"`
Subject string `json:"sub,omitempty"`
Audience jwt.Audience `json:"aud,omitempty"`
Expiry *jwt.NumericDate `json:"exp,omitempty"`
NotBefore *jwt.NumericDate `json:"nbf,omitempty"`
IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
ID string `json:"jti,omitempty"`
// core pomerium identity claims ; not standard to RFC 7519
Email string `json:"email"`
User string `json:"user"`
Groups []string `json:"groups"`
Groups []string `json:"groups,omitempty"`
User string `json:"user,omitempty"` // google
ImpersonateEmail string
ImpersonateGroups []string
// commonly supported IdP information
// https://www.iana.org/assignments/jwt/jwt.xhtml#claims
Name string `json:"name,omitempty"` // google
GivenName string `json:"given_name,omitempty"` // google
FamilyName string `json:"family_name,omitempty"` // google
Picture string `json:"picture,omitempty"` // google
EmailVerified bool `json:"email_verified,omitempty"` // google
// Impersonate-able fields
ImpersonateEmail string `json:"impersonate_email,omitempty"`
ImpersonateGroups []string `json:"impersonate_groups,omitempty"`
// Programmatic whether this state is used for machine-to-machine
// programatic access.
Programmatic bool `json:"programatic"`
AccessToken *oauth2.Token `json:"access_token,omitempty"`
idToken *oidc.IDToken
}
// Valid returns an error if the users's session state is not valid.
func (s *State) Valid() error {
if s.Expired() {
return ErrExpired
// NewStateFromTokens returns a session state built from oidc and oauth2
// tokens as part of OpenID Connect flow with a new audience appended to the
// audience claim.
func NewStateFromTokens(idToken *oidc.IDToken, accessToken *oauth2.Token, audience string) (*State, error) {
if idToken == nil {
return nil, errors.New("sessions: oidc id token missing")
}
if accessToken == nil {
return nil, errors.New("sessions: oauth2 token missing")
}
s := &State{}
if err := idToken.Claims(s); err != nil {
return nil, fmt.Errorf("sessions: couldn't unmarshal extra claims %w", err)
}
s.Audience = []string{audience}
s.idToken = idToken
s.AccessToken = accessToken
return s, nil
}
// UpdateState updates the current state given a new identity (oidc) and authorization
// (oauth2) tokens following a oidc refresh. NB, unlike during authentication,
// refresh typically provides fewer claims in the token so we want to build from
// our previous state.
func (s *State) UpdateState(idToken *oidc.IDToken, accessToken *oauth2.Token) error {
if idToken == nil {
return errors.New("sessions: oidc id token missing")
}
if accessToken == nil {
return errors.New("sessions: oauth2 token missing")
}
audience := append(s.Audience[:0:0], s.Audience...)
s.AccessToken = accessToken
if err := idToken.Claims(s); err != nil {
return fmt.Errorf("sessions: update state failed %w", err)
}
s.Audience = audience
s.Expiry = jwt.NewNumericDate(accessToken.Expiry)
return nil
}
// ForceRefresh sets the refresh deadline to now.
func (s *State) ForceRefresh() {
s.RefreshDeadline = time.Now().Truncate(time.Second)
// NewSession updates issuer, audience, and issuance timestamps but keeps
// parent expiry.
func (s State) NewSession(issuer string, audience []string) *State {
s.IssuedAt = jwt.NewNumericDate(timeNow())
s.NotBefore = s.IssuedAt
s.Audience = audience
s.Issuer = issuer
return &s
}
// Expired returns true if the refresh period has expired
func (s *State) Expired() bool {
return s.RefreshDeadline.Before(time.Now())
// RouteSession creates a route session with access tokens stripped and a
// custom validity period.
func (s State) RouteSession(validity time.Duration) *State {
s.Expiry = jwt.NewNumericDate(timeNow().Add(validity))
s.AccessToken = nil
return &s
}
// Verify returns an error if the users's session state is not valid.
func (s *State) Verify(audience string) error {
if s.NotBefore != nil && timeNow().Add(DefaultLeeway).Before(s.NotBefore.Time()) {
return ErrNotValidYet
}
if s.Expiry != nil && timeNow().Add(-DefaultLeeway).After(s.Expiry.Time()) {
return ErrExpired
}
if s.IssuedAt != nil && timeNow().Add(DefaultLeeway).Before(s.IssuedAt.Time()) {
return ErrIssuedInTheFuture
}
// if we have an associated access token, check if that token has expired as well
if s.AccessToken != nil && timeNow().Add(-DefaultLeeway).After(s.AccessToken.Expiry) {
return ErrExpired
}
if len(s.Audience) != 0 {
if !s.Audience.Contains(audience) {
return ErrInvalidAudience
}
}
return nil
}
// Impersonating returns if the request is impersonating.
@ -65,79 +166,12 @@ func (s *State) RequestGroups() string {
return strings.Join(s.Groups, ",")
}
type idToken struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Expiry jsonTime `json:"exp"`
IssuedAt jsonTime `json:"iat"`
Nonce string `json:"nonce"`
AtHash string `json:"at_hash"`
}
// IssuedAt parses the IDToken's issue date and returns a valid go time.Time.
func (s *State) IssuedAt() (time.Time, error) {
payload, err := parseJWT(s.IDToken)
if err != nil {
return time.Time{}, fmt.Errorf("internal/sessions: malformed jwt: %v", err)
}
var token idToken
if err := json.Unmarshal(payload, &token); err != nil {
return time.Time{}, fmt.Errorf("internal/sessions: failed to unmarshal claims: %v", err)
}
return time.Time(token.IssuedAt), nil
}
// MarshalSession marshals the session state as JSON, encrypts the JSON using the
// given cipher, and base64-encodes the result
func MarshalSession(s *State, c cryptutil.SecureEncoder) (string, error) {
v, err := c.Marshal(s)
if err != nil {
return "", err
}
return v, nil
}
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
// byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct
func UnmarshalSession(value string, c cryptutil.SecureEncoder) (*State, error) {
s := &State{}
err := c.Unmarshal(value, s)
if err != nil {
return nil, err
}
return s, nil
}
func parseJWT(p string) ([]byte, error) {
parts := strings.Split(p, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("internal/sessions: malformed jwt, expected 3 parts got %d", len(parts))
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("internal/sessions: malformed jwt payload: %v", err)
}
return payload, nil
}
type jsonTime time.Time
func (j *jsonTime) UnmarshalJSON(b []byte) error {
var n json.Number
if err := json.Unmarshal(b, &n); err != nil {
return err
}
var unix int64
if t, err := n.Int64(); err == nil {
unix = t
// SetImpersonation sets impersonation user and groups.
func (s *State) SetImpersonation(email, groups string) {
s.ImpersonateEmail = email
if groups == "" {
s.ImpersonateGroups = nil
} else {
f, err := n.Float64()
if err != nil {
return err
s.ImpersonateGroups = strings.Split(groups, ",")
}
unix = int64(f)
}
*j = jsonTime(time.Unix(unix, 0))
return nil
}

View file

@ -1,90 +1,16 @@
package sessions
import (
"crypto/rand"
"fmt"
"reflect"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt"
)
func TestStateSerialization(t *testing.T) {
secret := cryptutil.NewKey()
cipher, err := cryptutil.NewAEADCipher(secret)
c := cryptutil.NewSecureJSONEncoder(cipher)
if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err)
}
want := &State{
AccessToken: "token1234",
RefreshToken: "refresh4321",
RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(),
Email: "user@domain.com",
User: "user",
}
ciphertext, err := MarshalSession(want, c)
if err != nil {
t.Fatalf("expected to be encode session: %v", err)
}
got, err := UnmarshalSession(ciphertext, c)
if err != nil {
t.Fatalf("expected to be decode session: %v", err)
}
if !reflect.DeepEqual(want, got) {
t.Logf("want: %#v", want)
t.Logf(" got: %#v", got)
t.Errorf("encoding and decoding session resulted in unexpected output")
}
}
func TestStateExpirations(t *testing.T) {
session := &State{
AccessToken: "token1234",
RefreshToken: "refresh4321",
RefreshDeadline: time.Now().Add(-1 * time.Hour),
Email: "user@domain.com",
User: "user",
}
if !session.Expired() {
t.Errorf("expected lifetime period to be expired")
}
}
func TestState_IssuedAt(t *testing.T) {
t.Parallel()
tests := []struct {
name string
IDToken string
want time.Time
wantErr bool
}{
{"simple parse", "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlkzYm1qV3R4US16OW1fM1RLb0dtRWciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY3MjY4NywiZXhwIjoxNTU4Njc2Mjg3fQ.a4g8W94E7iVJhiIUmsNMwJssfx3Evi8sXeiXgXMC7kHNvftQ2CFU_LJ-dqZ5Jf61OXcrp26r7lUcTNENXuen9tyUWAiHvxk6OHTxZusdywTCY5xowpSZBO9PDWYrmmdvfhRbaKO6QVAUMkbKr1Tr8xqfoaYVXNZhERXhcVReDznI0ccbwCGrNx5oeqiL4eRdZY9eqFXi4Yfee0mkef9oyVPc2HvnpwcpM0eckYa_l_ZQChGjXVGBFIus_Ao33GbWDuc9gs-_Vp2ev4KUT2qWb7AXMCGDLx0tWI9umm7mCBi_7xnaanGKUYcVwcSrv45arllAAwzuNxO0BVw3oRWa5Q", time.Unix(1558672687, 0), false},
{"bad jwt", "x.x.x-x-x", time.Time{}, true},
{"malformed jwt", "x", time.Time{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &State{IDToken: tt.IDToken}
got, err := s.IssuedAt()
if (err != nil) != tt.wantErr {
t.Errorf("State.IssuedAt() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("State.IssuedAt() = %v, want %v", got.Format(time.RFC3339), tt.want.Format(time.RFC3339))
}
})
}
}
func TestState_Impersonating(t *testing.T) {
t.Parallel()
tests := []struct {
@ -107,9 +33,8 @@ func TestState_Impersonating(t *testing.T) {
s := &State{
Email: tt.Email,
Groups: tt.Groups,
ImpersonateEmail: tt.ImpersonateEmail,
ImpersonateGroups: tt.ImpersonateGroups,
}
s.SetImpersonation(tt.ImpersonateEmail, strings.Join(tt.ImpersonateGroups, ","))
if got := s.Impersonating(); got != tt.want {
t.Errorf("State.Impersonating() = %v, want %v", got, tt.want)
}
@ -123,84 +48,80 @@ func TestState_Impersonating(t *testing.T) {
}
}
func TestMarshalSession(t *testing.T) {
secret := cryptutil.NewKey()
cipher, err := cryptutil.NewAEADCipher(secret)
if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err)
}
c := cryptutil.NewSecureJSONEncoder(cipher)
hugeString := make([]byte, 4097)
if _, err := rand.Read(hugeString); err != nil {
t.Fatal(err)
}
func TestState_Verify(t *testing.T) {
t.Parallel()
tests := []struct {
name string
s *State
Audience jwt.Audience
Expiry *jwt.NumericDate
NotBefore *jwt.NumericDate
IssuedAt *jwt.NumericDate
AccessToken *oauth2.Token
audience string
wantErr bool
}{
{"simple", &State{}, false},
{"too big", &State{AccessToken: fmt.Sprintf("%x", hugeString)}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
in, err := MarshalSession(tt.s, c)
if (err != nil) != tt.wantErr {
t.Errorf("MarshalSession() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil {
out, err := UnmarshalSession(in, c)
if err != nil {
t.Fatalf("expected to be decode session: %v", err)
}
if diff := cmp.Diff(tt.s, out); diff != "" {
t.Errorf("MarshalSession() = %s", diff)
}
}
})
}
}
func TestState_Valid(t *testing.T) {
tests := []struct {
name string
RefreshDeadline time.Time
wantErr bool
}{
{" good", time.Now().Add(10 * time.Second), false},
{" expired", time.Now().Add(-10 * time.Second), true},
{"good", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", false},
{"bad expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true},
{"bad audience", []string{"x", "y", "z"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true},
{"bad not before", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true},
{"bad issued at", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(time.Hour)), &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, "a", true},
{"bad access token expiry", []string{"a", "b", "c"}, jwt.NewNumericDate(time.Now().Add(time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), jwt.NewNumericDate(time.Now().Add(-time.Hour)), &oauth2.Token{Expiry: time.Now().Add(-time.Hour)}, "a", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &State{
RefreshDeadline: tt.RefreshDeadline,
Audience: tt.Audience,
Expiry: tt.Expiry,
NotBefore: tt.NotBefore,
IssuedAt: tt.IssuedAt,
AccessToken: tt.AccessToken,
}
if err := s.Valid(); (err != nil) != tt.wantErr {
t.Errorf("State.Valid() error = %v, wantErr %v", err, tt.wantErr)
if err := s.Verify(tt.audience); (err != nil) != tt.wantErr {
t.Errorf("State.Verify() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestState_ForceRefresh(t *testing.T) {
func TestState_RouteSession(t *testing.T) {
now := time.Now()
timeNow = func() time.Time {
return now
}
tests := []struct {
name string
RefreshDeadline time.Time
Issuer string
Audience jwt.Audience
Expiry *jwt.NumericDate
AccessToken *oauth2.Token
issuer string
audience []string
validity time.Duration
want *State
}{
{"good", time.Now().Truncate(time.Second)},
{"good", "authenticate.x.y.z", []string{"http.x.y.z"}, jwt.NewNumericDate(timeNow()), nil, "authenticate.a.b.c", []string{"http.a.b.c"}, 20 * time.Second, &State{Issuer: "authenticate.a.b.c", Audience: []string{"http.a.b.c"}, NotBefore: jwt.NewNumericDate(timeNow()), IssuedAt: jwt.NewNumericDate(timeNow()), Expiry: jwt.NewNumericDate(timeNow().Add(20 * time.Second))}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &State{
RefreshDeadline: tt.RefreshDeadline,
s := State{
Issuer: tt.Issuer,
Audience: tt.Audience,
Expiry: tt.Expiry,
AccessToken: tt.AccessToken,
}
s.ForceRefresh()
if s.RefreshDeadline != tt.RefreshDeadline {
t.Errorf("refresh deadline not updated")
cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(State{}),
}
got := s.NewSession(tt.issuer, tt.audience)
got = got.RouteSession(tt.validity)
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
t.Errorf("State.RouteSession() = %s", diff)
}
})
}
}

View file

@ -6,19 +6,30 @@ import (
)
var (
// ErrExpired is the error for an expired session.
ErrExpired = errors.New("internal/sessions: session is expired")
// ErrNoSessionFound is the error for when no session is found.
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
// ErrMalformed is the error for when a session is found but is malformed.
ErrMalformed = errors.New("internal/sessions: session is malformed")
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
ErrNotValidYet = errors.New("internal/sessions: validation failed, token not valid yet (nbf)")
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
ErrExpired = errors.New("internal/sessions: validation failed, token is expired (exp)")
// ErrIssuedInTheFuture indicates that the iat field is in the future.
ErrIssuedInTheFuture = errors.New("internal/sessions: validation field, token issued in the future (iat)")
// ErrInvalidAudience indicated invalid aud claim.
ErrInvalidAudience = errors.New("internal/sessions: validation failed, invalid audience claim (aud)")
)
// SessionStore has the functions for setting, getting, and clearing the Session cookie
type SessionStore interface {
ClearSession(http.ResponseWriter, *http.Request)
LoadSession(*http.Request) (*State, error)
SaveSession(http.ResponseWriter, *http.Request, *State) error
SessionLoader
SaveSession(http.ResponseWriter, *http.Request, interface{}) error
}
// SessionLoader is implemented by any struct that loads a pomerium session
@ -26,3 +37,19 @@ type SessionStore interface {
type SessionLoader interface {
LoadSession(*http.Request) (*State, error)
}
// Encoder can both Marshal and Unmarshal a struct into and from a set of bytes.
type Encoder interface {
Marshaler
Unmarshaler
}
// Marshaler encodes a struct into a set of bytes.
type Marshaler interface {
Marshal(interface{}) ([]byte, error)
}
// Unmarshaler decodes a set of bytes and returns a struct.
type Unmarshaler interface {
Unmarshal([]byte, interface{}) error
}

View file

@ -1,12 +0,0 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import "strings"
// ParentSubdomain returns the parent subdomain.
func ParentSubdomain(s string) string {
if strings.Count(s, ".") < 2 {
return ""
}
split := strings.SplitN(s, ".", 2)
return split[1]
}

View file

@ -1,23 +0,0 @@
package sessions
import "testing"
func Test_ParentSubdomain(t *testing.T) {
t.Parallel()
tests := []struct {
s string
want string
}{
{"httpbin.corp.example.com", "corp.example.com"},
{"some.httpbin.corp.example.com", "httpbin.corp.example.com"},
{"example.com", ""},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.s, func(t *testing.T) {
if got := ParentSubdomain(tt.s); got != tt.want {
t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -143,9 +143,12 @@ func New() *template.Template {
text-align: center;
width: 75px;
height: auto;
border-radius: 50%;
}
.logo {
padding-bottom: 20px;
padding-top: 20px;
width: 115px;
height: auto;
}
@ -161,6 +164,7 @@ func New() *template.Template {
p.message {
margin-top: 10px;
margin-bottom: 10px;
padding-bottom: 20px;
}
.field {
@ -300,37 +304,119 @@ func New() *template.Template {
<div id="main">
<div id="info-box">
<div class="card">
{{if .Session.Picture }}
<img class="icon" src="{{.Session.Picture}}" alt="user image">
{{else}}
<svg class="icon ok" xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24">
<path fill="none" d="M0 0h24v24H0V0z" />
<path d="M11 7h2v2h-2zm0 4h2v6h-2zm1-9C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm0 18c-4.41 0-8-3.59-8-8s3.59-8 8-8 8 3.59 8 8-3.59 8-8 8z" />
</svg>
<form method="POST" action="{{.SignoutURL}}">
{{end}}
<form method="POST" action="/.pomerium/sign_out">
<section>
<h2>Current user</h2>
<p class="message">Your current session details.</p>
<fieldset>
{{if .Session.Name}}
<label>
<span>Name</span>
<input name="Name" type="text" class="field" value="{{.Session.Name}}" disabled>
</label>
{{else}}
{{if .Session.GivenName}}
<label>
<span>Given Name</span>
<input name="GivenName" type="text" class="field" value="{{.Session.GivenName}}" disabled>
</label>
{{end}}
{{if .Session.FamilyName}}
<label>
<span>Family Name</span>
<input name="FamilyName" type="text" class="field" value="{{.Session.FamilyName}}" disabled>
</label>
{{end}}
{{end}}
{{if .Session.Subject}}
<label>
<span>UserID</span>
<input name="email" type="text" class="field" value="{{.Session.Subject}}" disabled>
</label>
{{end}}
{{if .Session.Email}}
<label>
<span>Email</span>
<input name="email" type="email" class="field" value="{{.Email}}" disabled>
<input name="email" type="email" class="field" value="{{.Session.Email}}" disabled>
</label>
{{end}}
{{if .Session.User}}
<label>
<span>User</span>
<input name="user" type="text" class="field" value="{{.User}}" disabled>
<input name="user" type="text" class="field" value="{{.Session.User}}" disabled>
</label>
{{end}}
{{if .Session.Groups}}
<label class="select">
<span>Groups</span>
<div id="group" class="field">
<select name="group">
{{range .Groups}}
{{range .Session.Groups}}
<option value="{{.}}">{{.}}</option>
{{end}}
</select>
</div>
</label>
{{end}}
{{if .Session.Expiry}}
<label>
<span>Expiry</span>
<input name="session expiration" type="text" class="field" value="{{.RefreshDeadline}}" disabled>
<input name="session expiration" type="text" class="field" value="{{.Session.Expiry.Time}}" disabled>
</label>
{{end}}
{{if .Session.IssuedAt}}
<label>
<span>Issued</span>
<input name="session expiration" type="text" class="field" value="{{.Session.IssuedAt.Time}}" disabled>
</label>
{{end}}
{{if .Session.Issuer}}
<label>
<span>Issuer</span>
<input name="session expiration" type="text" class="field" value=" {{ .Session.Issuer}}" disabled>
</label>
{{end}}
{{if .Session.Audience}}
<label class="select">
<span>Audiences</span>
<div id="group" class="field">
<select name="group">
{{range .Session.Audience}}
<option value="{{.}}">{{ printf "%.30s" . }}</option>
{{end}}
</select>
</div>
</label>
{{end}}
{{if .Session.ImpersonateEmail}}
<label>
<span>Impersonating Email</span>
<input name="session expiration" type="text" class="field" value="{{.Session.ImpersonateEmail}}" disabled>
</label>
{{end}}
{{if .Session.ImpersonateGroups}}
<label class="select">
<span>Impersonating Groups</span>
<div id="group" class="field">
<select name="group">
{{range .Session.ImpersonateGroups}}
<option value="{{.}}">{{.}}</option>
{{end}}
</select>
</div>
</label>
{{end}}
</fieldset>
</section>
<div class="flex">
@ -338,17 +424,7 @@ func New() *template.Template {
<button class="button full" type="submit">Sign Out</button>
</div>
</form>
<section>
<h2>Refresh Identity</h2>
<p class="message">Pomerium will automatically refresh your user session. However, if your group memberships have recently changed and haven't taken effect yet, you can refresh your session manually.</p>
<form method="POST" action="/.pomerium/refresh">
<div class="flex">
{{ .csrfField }}
<button class="button full" type="submit">Refresh</button>
</div>
</form>
</section>
{{if .IsAdmin}}
<form method="POST" action="/.pomerium/impersonate">
@ -358,11 +434,11 @@ func New() *template.Template {
<fieldset>
<label>
<span>Email</span>
<input name="email" type="email" class="field" value="{{.ImpersonateEmail}}" placeholder="user@example.com">
<input name="email" type="email" class="field" value="" placeholder="user@example.com">
</label>
<label>
<span>Group</span>
<input name="group" type="text" class="field" value="{{.ImpersonateGroup}}" placeholder="engineering">
<input name="group" type="text" class="field" value="" placeholder="engineering">
</label>
</fieldset>
</section>

View file

@ -1,16 +1,17 @@
package proxy // import "github.com/pomerium/pomerium/proxy"
import (
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/gorilla/mux"
"github.com/pomerium/csrf"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/templates"
"github.com/pomerium/pomerium/internal/urlutil"
@ -27,15 +28,24 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
// 3. Enforce CSRF protections for any non-idempotent http method
h.Use(csrf.Protect(
p.cookieSecret,
csrf.Path("/"),
csrf.Domain(p.cookieDomain),
csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieName)),
csrf.Secure(p.cookieOptions.Secure),
csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieOptions.Name)),
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
))
h.HandleFunc("/", p.UserDashboard).Methods(http.MethodGet)
h.HandleFunc("/impersonate", p.Impersonate).Methods(http.MethodPost)
h.HandleFunc("/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost)
h.HandleFunc("/refresh", p.ForceRefresh).Methods(http.MethodPost)
// Authenticate service callback handlers and middleware
c := r.PathPrefix(dashboardURL + "/callback").Subrouter()
// only accept payloads that have come from a trusted service (hmac)
c.Use(middleware.ValidateSignature(p.SharedKey))
c.HandleFunc("/", p.Callback).Queries("redirect_uri", "{redirect_uri}").Methods(http.MethodGet)
// Programmatic API handlers and middleware
a := r.PathPrefix(dashboardURL + "/api").Subrouter()
a.HandleFunc("/v1/login", p.ProgrammaticLogin).Queries("redirect_uri", "{redirect_uri}").Methods(http.MethodGet)
return r
}
@ -56,6 +66,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
redirectURL = uri
}
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL)
p.sessionStore.ClearSession(w, r)
http.Redirect(w, r, uri.String(), http.StatusFound)
}
@ -74,53 +85,14 @@ func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, err)
return
}
//todo(bdd): make sign out redirect a configuration option so that
// admins can set to whatever their corporate homepage is
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
signoutURL := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL)
templates.New().ExecuteTemplate(w, "dashboard.html", map[string]interface{}{
"Email": session.Email,
"User": session.User,
"Groups": session.Groups,
"RefreshDeadline": time.Until(session.RefreshDeadline).Round(time.Second).String(),
"SignoutURL": signoutURL.String(),
"Session": session,
"IsAdmin": isAdmin,
"ImpersonateEmail": session.ImpersonateEmail,
"ImpersonateGroup": strings.Join(session.ImpersonateGroups, ","),
"csrfField": csrf.TemplateField(r),
})
}
// ForceRefresh redeems and extends an existing authenticated oidc session with
// the underlying identity provider. All session details including groups,
// timeouts, will be renewed.
func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) {
session, err := sessions.FromContext(r.Context())
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
iss, err := session.IssuedAt()
if err != nil {
httputil.ErrorResponse(w, r, err)
return
}
// reject a refresh if it's been less than the refresh cooldown to prevent abuse
if time.Since(iss) < p.refreshCooldown {
errStr := fmt.Sprintf("Session must be %s old before refreshing", p.refreshCooldown)
httpErr := httputil.Error(errStr, http.StatusBadRequest, nil)
httputil.ErrorResponse(w, r, httpErr)
return
}
session.ForceRefresh()
if err = p.sessionStore.SaveSession(w, r, session); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
http.Redirect(w, r, dashboardURL, http.StatusFound)
}
// Impersonate takes the result of a form and adds user impersonation details
// to the user's current user sessions state if the user is currently an
// administrative user. Requests are redirected back to the user dashboard.
@ -138,101 +110,112 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
return
}
// OK to impersonation
session.ImpersonateEmail = r.FormValue("email")
session.ImpersonateGroups = strings.Split(r.FormValue("group"), ",")
groups := r.FormValue("group")
if groups != "" {
session.ImpersonateGroups = strings.Split(groups, ",")
}
if err := p.sessionStore.SaveSession(w, r, session); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
http.Redirect(w, r, dashboardURL, http.StatusFound)
redirectURL := urlutil.GetAbsoluteURL(r)
redirectURL.Path = dashboardURL // redirect back to the dashboard
q := redirectURL.Query()
q.Add("impersonate_email", r.FormValue("email"))
q.Add("impersonate_group", r.FormValue("group"))
redirectURL.RawQuery = q.Encode()
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, redirectURL).String()
http.Redirect(w, r, uri, http.StatusFound)
}
func (p *Proxy) registerFwdAuthHandlers() http.Handler {
r := httputil.NewRouter()
r.StrictSlash(true)
r.Use(sessions.RetrieveSession(p.sessionStore))
r.HandleFunc("/", p.VerifyAndSignin).Queries("uri", "{uri}").Methods(http.MethodGet)
r.HandleFunc("/verify", p.VerifyOnly).Queries("uri", "{uri}").Methods(http.MethodGet)
r.Handle("/", p.Verify(false)).Queries("uri", "{uri}").Methods(http.MethodGet)
r.Handle("/verify", p.Verify(true)).Queries("uri", "{uri}").Methods(http.MethodGet)
return r
}
// VerifyAndSignin checks a user's credentials for an arbitrary host. If the user
// Verify checks a user's credentials for an arbitrary host. If the user
// is properly authenticated and is authorized to access the supplied host,
// a `200` http status code is returned. If the user is not authenticated, they
// will be redirected to the authenticate service to sign in with their identity
// provider. If the user is unauthorized, a `401` error is returned.
func (p *Proxy) VerifyAndSignin(w http.ResponseWriter, r *http.Request) {
func (p *Proxy) Verify(verifyOnly bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
uri, err := urlutil.ParseAndValidateURL(r.FormValue("uri"))
if err != nil || uri.String() == "" {
httputil.ErrorResponse(w, r, httputil.Error("bad verification uri given", http.StatusBadRequest, nil))
httputil.ErrorResponse(w, r, httputil.Error("bad verification uri", http.StatusBadRequest, nil))
return
}
if err := p.authenticate(w, r); err != nil {
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
http.Redirect(w, r, uri.String(), http.StatusFound)
}
if err := p.authorize(r, uri); err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusUnauthorized, err))
if err := p.authenticate(verifyOnly, w, r); err != nil {
return
}
// check the queryparams to see if this check immediately followed
// authentication. If so, redirect back to the originally requested hostname.
if isCallback := r.URL.Query().Get(callbackQueryParam); isCallback == "true" {
q := uri.Query()
q.Del(callbackQueryParam)
uri.RawQuery = q.Encode()
http.Redirect(w, r, uri.String(), http.StatusFound)
if err := p.authorize(uri.Host, w, r); err != nil {
return
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, fmt.Sprintf("Access to %s is allowed.", uri.Host))
})
}
// VerifyOnly checks a user's credentials for an arbitrary host. If the user
// is properly authenticated and is authorized to access the supplied host,
// a `200` http status code is returned otherwise a `401` error is returned.
func (p *Proxy) VerifyOnly(w http.ResponseWriter, r *http.Request) {
uri, err := urlutil.ParseAndValidateURL(r.FormValue("uri"))
if err != nil || uri.String() == "" {
httputil.ErrorResponse(w, r, httputil.Error("bad verification uri given", http.StatusBadRequest, nil))
// Callback takes a `redirect_uri` query param that has been hmac'd by the
// authenticate service. Embedded in the `redirect_uri` are query-params
// that tell this handler how to set the per-route user session.
// Callback is responsible for redirecting the user back to the intended
// destination URL and path, as well as to clean up any additional query params
// added by the authenticate service.
func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) {
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
return
}
if err := p.authenticate(w, r); err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusUnauthorized, err))
q := redirectURL.Query()
// 1. extract the base64 encoded and encrypted JWT from redirect_uri's query params
encryptedJWT, err := base64.URLEncoding.DecodeString(q.Get("pomerium_jwt"))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
return
}
if err := p.authorize(r, uri); err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusUnauthorized, err))
q.Del("pomerium_jwt")
q.Del("impersonate_email")
q.Del("impersonate_group")
// 2. decrypt the JWT using the cipher using the _shared_ secret key
rawJWT, err := cryptutil.Decrypt(p.sharedCipher, encryptedJWT, nil)
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
return
}
// 3. Save the decrypted JWT to the session store directly as a string, without resigning
if err = p.sessionStore.SaveSession(w, r, rawJWT); err != nil {
httputil.ErrorResponse(w, r, err)
return
}
// if this is a programmatic request, don't strip the tokens before redirect
if redirectURL.Query().Get("pomerium_programmatic_destination_url") != "" {
q.Set("pomerium_jwt", string(rawJWT))
}
redirectURL.RawQuery = q.Encode()
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
}
// ProgrammaticLogin returns a signed url that can be used to login
// using the authenticate service.
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) {
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
return
}
q := redirectURL.Query()
q.Add("pomerium_programmatic_destination_url", urlutil.GetAbsoluteURL(r).String())
redirectURL.RawQuery = q.Encode()
response := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, redirectURL).String()
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(http.StatusOK)
}
func (p *Proxy) authorize(r *http.Request, uri *url.URL) error {
// attempt to retrieve the user session from the request context, validity
// of the identity session is asserted by the middleware chain
s, err := sessions.FromContext(r.Context())
if err != nil {
return err
}
// query the authorization service to see if the session's user has
// the appropriate authorization to access the given hostname
authorized, err := p.AuthorizeClient.Authorize(r.Context(), uri.Host, s)
if err != nil {
return err
} else if !authorized {
return fmt.Errorf("%s is not authorized for %s", s.RequestEmail(), uri.String())
}
return nil
w.Write([]byte(response))
}

View file

@ -1,4 +1,4 @@
package proxy // import "github.com/pomerium/pomerium/proxy"
package proxy
import (
"bytes"
@ -11,11 +11,14 @@ import (
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/proxy/clients"
"github.com/google/go-cmp/cmp"
"gopkg.in/square/go-jose.v2/jwt"
)
func TestProxy_RobotsTxt(t *testing.T) {
@ -62,17 +65,17 @@ func TestProxy_UserDashboard(t *testing.T) {
ctxError error
options config.Options
method string
cipher cryptutil.SecureEncoder
cipher sessions.Encoder
session sessions.SessionStore
authorizer clients.Authorizer
wantAdminForm bool
wantStatus int
}{
{"good", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusOK},
{"session context error", errors.New("error"), opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
{"want admin form good admin authorization", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
{"is admin but authorization fails", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
{"good", nil, opts, http.MethodGet, &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusOK},
{"session context error", errors.New("error"), opts, http.MethodGet, &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
{"want admin form good admin authorization", nil, opts, http.MethodGet, &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
{"is admin but authorization fails", nil, opts, http.MethodGet, &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
}
for _, tt := range tests {
@ -109,56 +112,6 @@ func TestProxy_UserDashboard(t *testing.T) {
}
}
func TestProxy_ForceRefresh(t *testing.T) {
opts := testOptions(t)
opts.RefreshCooldown = 0
timeSinceError := testOptions(t)
timeSinceError.RefreshCooldown = time.Duration(int(^uint(0) >> 1))
tests := []struct {
name string
ctxError error
options config.Options
method string
cipher cryptutil.SecureEncoder
session sessions.SessionStore
authorizer clients.Authorizer
wantStatus int
}{
{"good", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
{"cannot load session", errors.New("error"), opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
{"bad id token", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
{"issue date too soon", nil, timeSinceError, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusBadRequest},
{"refresh failure", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
{"can't save refreshed session", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(tt.options)
if err != nil {
t.Fatal(err)
}
p.encoder = tt.cipher
p.sessionStore = tt.session
p.AuthorizeClient = tt.authorizer
r := httptest.NewRequest(tt.method, "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
p.ForceRefresh(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", opts)
}
})
}
}
func TestProxy_Impersonate(t *testing.T) {
t.Parallel()
opts := testOptions(t)
@ -171,18 +124,17 @@ func TestProxy_Impersonate(t *testing.T) {
email string
groups string
csrf string
cipher cryptutil.SecureEncoder
cipher sessions.Encoder
sessionStore sessions.SessionStore
authorizer clients.Authorizer
wantStatus int
}{
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
{"save session failure", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &encoding.MockEncoder{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -276,25 +228,24 @@ func TestProxy_VerifyWithMiddleware(t *testing.T) {
path string
verifyURI string
cipher cryptutil.SecureEncoder
cipher sessions.Encoder
sessionStore sessions.SessionStore
authorizer clients.Authorizer
wantStatus int
wantBody string
}{
{"good", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
{"good verify only", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
{"bad naked domain uri given", opts, nil, http.MethodGet, "", "/", "a.naked.domain", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri given\"}\n"},
{"bad naked domain uri given verify only", opts, nil, http.MethodGet, "", "/verify", "a.naked.domain", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri given\"}\n"},
{"bad empty verification uri given", opts, nil, http.MethodGet, "", "/", " ", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri given\"}\n"},
{"bad empty verification uri given verify only", opts, nil, http.MethodGet, "", "/verify", " ", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri given\"}\n"},
{"good post auth redirect", opts, nil, http.MethodGet, callbackQueryParam, "/", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, "<a href=\"https://some.domain.example\">Found</a>.\n\n"},
{"not authorized", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"Unauthorized\"}\n"},
{"not authorized verify endpoint", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"Unauthorized\"}\n"},
{"not authorized expired, redirect to auth", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
{"not authorized expired, don't redirect!", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"Unauthorized\"}\n"},
{"not authorized because of error", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusUnauthorized, "{\"error\":\"Unauthorized\"}\n"},
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"Unauthorized\"}\n"},
{"good", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
{"good verify only", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
{"bad naked domain uri", opts, nil, http.MethodGet, "", "/", "a.naked.domain", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
{"bad naked domain uri verify only", opts, nil, http.MethodGet, "", "/verify", "a.naked.domain", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
{"bad empty verification uri", opts, nil, http.MethodGet, "", "/", " ", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
{"bad empty verification uri verify only", opts, nil, http.MethodGet, "", "/verify", " ", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
{"not authorized", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"},
{"not authorized verify endpoint", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"},
{"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, "", "/", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
{"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, "", "/verify", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"},
{"not authorized because of error", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"error\":\"authz error\"}\n"},
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -346,3 +297,133 @@ func TestProxy_VerifyWithMiddleware(t *testing.T) {
})
}
}
func TestProxy_Callback(t *testing.T) {
t.Parallel()
opts := testOptions(t)
tests := []struct {
name string
options config.Options
method string
scheme string
host string
path string
qp map[string]string
cipher sessions.Encoder
sessionStore sessions.SessionStore
authorizer clients.Authorizer
wantStatus int
wantBody string
}{
{"good", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_programmatic_destination_url": "ok", "pomerium_jwt": "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &encoding.MockEncoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_jwt": "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &encoding.MockEncoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad save session", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_jwt": "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &encoding.MockEncoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusInternalServerError, ""},
{"bad base64", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_jwt": "^"}, &encoding.MockEncoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, &encoding.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(tt.options)
if err != nil {
t.Fatal(err)
}
p.encoder = tt.cipher
p.sessionStore = tt.sessionStore
p.AuthorizeClient = tt.authorizer
p.UpdateOptions(tt.options)
redirectURI := &url.URL{Scheme: tt.scheme, Host: tt.host, Path: tt.path}
queryString := redirectURI.Query()
for k, v := range tt.qp {
queryString.Set(k, v)
}
redirectURI.RawQuery = queryString.Encode()
uri := &url.URL{Path: "/"}
if tt.qp != nil {
qu := uri.Query()
qu.Set("redirect_uri", redirectURI.String())
uri.RawQuery = qu.Encode()
}
r := httptest.NewRequest(tt.method, uri.String(), nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
p.Callback(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", w.Body.String())
}
if tt.wantBody != "" {
body := w.Body.String()
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
t.Errorf("wrong body\n%s", diff)
}
}
})
}
}
func TestProxy_ProgrammaticLogin(t *testing.T) {
t.Parallel()
opts := testOptions(t)
tests := []struct {
name string
options config.Options
method string
scheme string
host string
path string
qp map[string]string
wantStatus int
wantBody string
}{
{"good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"redirect_uri": "http://localhost"}, http.StatusOK, ""},
{"router miss, bad redirect_uri query", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"bad_redirect_uri": "http://localhost"}, http.StatusNotFound, ""},
{"bad redirect_uri missing scheme", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"redirect_uri": "localhost"}, http.StatusBadRequest, "{\"error\":\"malformed redirect_uri\"}\n"},
{"bad http method", opts, http.MethodPost, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"redirect_uri": "http://localhost"}, http.StatusMethodNotAllowed, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(tt.options)
if err != nil {
t.Fatal(err)
}
redirectURI := &url.URL{Scheme: tt.scheme, Host: tt.host, Path: tt.path}
queryString := redirectURI.Query()
for k, v := range tt.qp {
queryString.Set(k, v)
}
redirectURI.RawQuery = queryString.Encode()
r := httptest.NewRequest(tt.method, redirectURI.String(), nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
router := httputil.NewRouter()
router = p.registerDashboardHandlers(router)
router.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", w.Body.String())
}
if tt.wantBody != "" {
body := w.Body.String()
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
t.Errorf("wrong body\n%s", diff)
}
}
})
}
}

View file

@ -3,9 +3,8 @@ package proxy // import "github.com/pomerium/pomerium/proxy"
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
@ -30,34 +29,35 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession")
defer span.End()
if err := p.authenticate(w, r); err != nil {
if err := p.authenticate(false, w, r.WithContext(ctx)); err != nil {
p.sessionStore.ClearSession(w, r)
log.FromRequest(r).Debug().Err(err).Msg("proxy: authenticate session")
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
http.Redirect(w, r, uri.String(), http.StatusFound)
return
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request) error {
// authenticate authenticates a user and sets an appropriate response type,
// redirect to authenticate or error handler depending on if err on failure is set.
func (p *Proxy) authenticate(errOnFailure bool, w http.ResponseWriter, r *http.Request) error {
s, err := sessions.FromContext(r.Context())
if err != nil {
if errOnFailure || (s != nil && s.Programmatic) {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
return err
}
if s == nil {
return fmt.Errorf("empty session state")
}
if err := s.Valid(); err != nil {
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
http.Redirect(w, r, uri.String(), http.StatusFound)
return err
}
// add pomerium's headers to the downstream request
r.Header.Set(HeaderUserID, s.User)
r.Header.Set(HeaderUserID, s.Subject)
r.Header.Set(HeaderEmail, s.RequestEmail())
r.Header.Set(HeaderGroups, s.RequestGroups())
// and upstream
w.Header().Set(HeaderUserID, s.User)
w.Header().Set(HeaderUserID, s.Subject)
w.Header().Set(HeaderEmail, s.RequestEmail())
w.Header().Set(HeaderGroups, s.RequestGroups())
return nil
@ -69,27 +69,35 @@ func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthorizeSession")
defer span.End()
s, err := sessions.FromContext(r.Context())
if err != nil || s == nil {
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err))
return
}
authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s)
if err != nil {
httputil.ErrorResponse(w, r.WithContext(ctx), err)
return
} else if !authorized {
errMsg := fmt.Sprintf("%s is not authorized for this route", s.RequestEmail())
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error(errMsg, http.StatusForbidden, nil))
if err := p.authorize(r.Host, w, r.WithContext(ctx)); err != nil {
log.FromRequest(r).Debug().Err(err).Msg("proxy: AuthorizeSession")
return
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func (p *Proxy) authorize(host string, w http.ResponseWriter, r *http.Request) error {
s, err := sessions.FromContext(r.Context())
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusUnauthorized, err))
return err
}
authorized, err := p.AuthorizeClient.Authorize(r.Context(), host, s)
if err != nil {
httputil.ErrorResponse(w, r, err)
return err
} else if !authorized {
err = fmt.Errorf("%s is not authorized for %s", s.RequestEmail(), host)
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
return err
}
return nil
}
// SignRequest is middleware that signs a JWT that contains a user's id,
// email, and group. Session state is retrieved from the users's request context
func (p *Proxy) SignRequest(signer cryptutil.JWTSigner) func(next http.Handler) http.Handler {
func (p *Proxy) SignRequest(signer sessions.Marshaler) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "proxy.SignRequest")
@ -99,12 +107,13 @@ func (p *Proxy) SignRequest(signer cryptutil.JWTSigner) func(next http.Handler)
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err))
return
}
jwt, err := signer.SignJWT(s.User, s.Email, strings.Join(s.Groups, ","))
newSession := s.NewSession(r.Host, []string{r.Host})
jwt, err := signer.Marshal(newSession.RouteSession(time.Minute))
if err != nil {
log.FromRequest(r).Warn().Err(err).Msg("proxy: failed signing jwt")
log.FromRequest(r).Error().Err(err).Msg("proxy: failed signing jwt")
} else {
r.Header.Set(HeaderJWT, jwt)
w.Header().Set(HeaderJWT, jwt)
r.Header.Set(HeaderJWT, string(jwt))
w.Header().Set(HeaderJWT, string(jwt))
}
next.ServeHTTP(w, r.WithContext(ctx))
})

View file

@ -1,7 +1,6 @@
package proxy
import (
"encoding/base64"
"errors"
"fmt"
"net/http"
@ -14,6 +13,7 @@ import (
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/proxy/clients"
"gopkg.in/square/go-jose.v2/jwt"
)
func TestProxy_AuthenticateSession(t *testing.T) {
@ -27,15 +27,18 @@ func TestProxy_AuthenticateSession(t *testing.T) {
tests := []struct {
name string
errOnFailure bool
session sessions.SessionStore
ctxError error
provider identity.Authenticator
wantStatus int
}{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusFound},
{"good", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, nil, identity.MockProvider{}, http.StatusOK},
{"invalid session", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"expired", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusFound},
{"expired and programmatic", false, &sessions.MockSessionStore{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusUnauthorized},
{"invalid session and programmatic", false, &sessions.MockSessionStore{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -81,10 +84,10 @@ func TestProxy_AuthorizeSession(t *testing.T) {
wantStatus int
}{
{"user is authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK},
{"user is not authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusForbidden},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusForbidden},
{"authz client error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError},
{"user is authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK},
{"user is not authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusUnauthorized},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized},
{"authz client error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -119,8 +122,9 @@ type mockJWTSigner struct {
// Sign implements the JWTSigner interface from the cryptutil package, but just
// base64's the inputs instead for stesting.
func (s *mockJWTSigner) SignJWT(user, email, groups string) (string, error) {
return base64.StdEncoding.EncodeToString([]byte(fmt.Sprint(user, email, groups))), s.SignError
func (s *mockJWTSigner) Marshal(v interface{}) ([]byte, error) {
return []byte("ok"), s.SignError
}
func TestProxy_SignRequest(t *testing.T) {
@ -142,7 +146,7 @@ func TestProxy_SignRequest(t *testing.T) {
wantStatus int
wantHeaders string
}{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "dGVzdA=="},
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "ok"},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, errors.New("err"), http.StatusForbidden, ""},
{"signature failure, warn but ok", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, errors.New("err"), nil, http.StatusOK, ""},
}

View file

@ -1,6 +1,7 @@
package proxy // import "github.com/pomerium/pomerium/proxy"
import (
"crypto/cipher"
"crypto/tls"
"encoding/base64"
"fmt"
@ -10,8 +11,10 @@ import (
"time"
"github.com/gorilla/mux"
"github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
@ -30,8 +33,6 @@ const (
signinURL = "/.pomerium/sign_in"
// signoutURL is the path to authenticate's sign out endpoint
signoutURL = "/.pomerium/sign_out"
callbackQueryParam = "pomerium-auth-callback"
)
// ValidateOptions checks that proper configuration settings are set to create
@ -54,7 +55,7 @@ func ValidateOptions(o config.Options) error {
}
if len(o.SigningKey) != 0 {
if _, err := cryptutil.NewES256Signer(o.SigningKey, ""); err != nil {
if _, err := jws.NewES256Signer(o.SigningKey, ""); err != nil {
return fmt.Errorf("proxy: invalid 'SIGNING_KEY': %v", err)
}
}
@ -65,6 +66,8 @@ func ValidateOptions(o config.Options) error {
type Proxy struct {
// SharedKey used to mutually authenticate service communication
SharedKey string
sharedCipher cipher.AEAD
authenticateURL *url.URL
authenticateSigninURL *url.URL
authenticateSignoutURL *url.URL
@ -72,9 +75,8 @@ type Proxy struct {
AuthorizeClient clients.Authorizer
encoder cryptutil.SecureEncoder
cookieName string
cookieDomain string
encoder sessions.Encoder
cookieOptions *sessions.CookieOptions
cookieSecret []byte
defaultUpstreamTimeout time.Duration
refreshCooldown time.Duration
@ -92,50 +94,48 @@ func New(opts config.Options) (*Proxy, error) {
return nil, err
}
// errors checked in ValidateOptions
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(opts.SharedKey)
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
cipher, _ := cryptutil.NewAEADCipherFromBase64(opts.CookieSecret)
encoder := cryptutil.NewSecureJSONEncoder(cipher)
if opts.CookieDomain == "" {
opts.CookieDomain = sessions.ParentSubdomain(opts.AuthenticateURL.String())
}
cookieStore, err := sessions.NewCookieStore(
&sessions.CookieStoreOptions{
Name: opts.CookieName,
CookieDomain: opts.CookieDomain,
CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly,
CookieExpire: opts.CookieExpire,
Encoder: encoder,
})
// used to load and verify JWT tokens signed by the authenticate service
encoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.AuthenticateURL.Host)
if err != nil {
return nil, err
}
cookieOptions := &sessions.CookieOptions{
Name: opts.CookieName,
Domain: opts.CookieDomain,
Secure: opts.CookieSecure,
HTTPOnly: opts.CookieHTTPOnly,
Expire: opts.CookieExpire,
}
cookieStore, err := sessions.NewCookieLoader(cookieOptions, encoder)
if err != nil {
return nil, err
}
p := &Proxy{
SharedKey: opts.SharedKey,
sharedCipher: sharedCipher,
encoder: encoder,
cookieSecret: decodedCookieSecret,
cookieDomain: opts.CookieDomain,
cookieName: opts.CookieName,
cookieOptions: cookieOptions,
defaultUpstreamTimeout: opts.DefaultUpstreamTimeout,
refreshCooldown: opts.RefreshCooldown,
sessionStore: cookieStore,
sessionLoaders: []sessions.SessionLoader{
cookieStore,
sessions.NewHeaderStore(encoder),
sessions.NewQueryParamStore(encoder)},
sessions.NewHeaderStore(encoder, "Pomerium"),
sessions.NewQueryParamStore(encoder, "pomerium_session")},
signingKey: opts.SigningKey,
templates: templates.New(),
}
// errors checked in ValidateOptions
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL})
@ -238,14 +238,14 @@ func (p *Proxy) reverseProxyHandler(r *mux.Router, policy *config.Policy) (*mux.
// 4. Retrieve the user session and add it to the request context
rp.Use(sessions.RetrieveSession(p.sessionLoaders...))
// 5. Strip the user session cookie from the downstream request
rp.Use(middleware.StripCookie(p.cookieName))
rp.Use(middleware.StripCookie(p.cookieOptions.Name))
// 6. AuthN - Verify the user is authenticated. Set email, group, & id headers
rp.Use(p.AuthenticateSession)
// 7. AuthZ - Verify the user is authorized for route
rp.Use(p.AuthorizeSession)
// Optional: Add a signed JWT attesting to the user's id, email, and group
if len(p.signingKey) != 0 {
signer, err := cryptutil.NewES256Signer(p.signingKey, policy.Source.Host)
signer, err := jws.NewES256Signer(p.signingKey, policy.Destination.Host)
if err != nil {
return nil, err
}

View file

@ -172,7 +172,10 @@ func Test_UpdateOptions(t *testing.T) {
corsPreflight.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", CORSAllowPreflight: true}}
disableAuth := testOptions(t)
disableAuth.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", AllowPublicUnauthenticatedAccess: true}}
fwdAuth := testOptions(t)
fwdAuth.ForwardAuthURL = &url.URL{Scheme: "https", Host: "corp.example.example"}
reqHeaders := testOptions(t)
reqHeaders.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", SetRequestHeaders: map[string]string{"x": "y"}}}
tests := []struct {
name string
originalOptions config.Options
@ -198,6 +201,8 @@ func Test_UpdateOptions(t *testing.T) {
{"no websockets, custom timeout", good, customTimeout, "", "https://corp.example.example", false, true},
{"enable cors preflight", good, corsPreflight, "", "https://corp.example.example", false, true},
{"disable auth", good, disableAuth, "", "https://corp.example.example", false, true},
{"enable forward auth", good, fwdAuth, "", "https://corp.example.example", false, true},
{"set request headers", good, reqHeaders, "", "https://corp.example.example", false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -1,79 +1,136 @@
from __future__ import absolute_import, division, print_function
import argparse
import http.server
import json
import sys
import urllib.parse
import webbrowser
from urllib.parse import urlparse
import requests
done = False
parser = argparse.ArgumentParser()
parser.add_argument('--openid-configuration',
default="https://accounts.google.com/.well-known/openid-configuration")
parser.add_argument('--client-id')
parser.add_argument('--client-secret')
parser.add_argument('--pomerium-client-id')
parser.add_argument('--code')
parser.add_argument('--pomerium-token-url',
default="https://authenticate.corp.beyondperimeter.com/api/v1/token")
parser.add_argument('--pomerium-token')
parser.add_argument('--pomerium-url', default="https://httpbin.corp.beyondperimeter.com/get")
parser.add_argument("--login", action="store_true")
parser.add_argument(
"--dst", default="https://httpbin.imac.bdd.io/headers",
)
parser.add_argument(
"--refresh-endpoint", default="https://authenticate.imac.bdd.io/api/v1/refresh",
)
parser.add_argument("--server", default="localhost", type=str)
parser.add_argument("--port", default=8000, type=int)
parser.add_argument(
"--cred", default="pomerium-cred.json",
)
args = parser.parse_args()
class PomeriumSession:
def __init__(self, jwt, refresh_token):
self.jwt = jwt
self.refresh_token = refresh_token
def to_json(self):
return json.dumps(self.__dict__, indent=2)
@classmethod
def from_json_file(cls, fn):
with open(fn) as f:
data = json.load(f)
return cls(**data)
class Callback(http.server.BaseHTTPRequestHandler):
def log_message(self, format, *args):
# silence http server logs for now
return
def do_GET(self):
global args
global done
self.send_response(200)
self.end_headers()
response = b"OK"
if "pomerium" in self.path:
path = urllib.parse.urlparse(self.path).query
path_qp = urllib.parse.parse_qs(path)
session = PomeriumSession(
path_qp.get("pomerium_jwt")[0],
path_qp.get("pomerium_refresh_token")[0],
)
done = True
response = session.to_json().encode()
with open(args.cred, "w", encoding="utf-8") as f:
f.write(session.to_json())
print("=> pomerium json credential saved to:\n{}".format(f.name))
self.wfile.write(response)
def main():
args = parser.parse_args()
code = args.code
pomerium_token = args.pomerium_token
oidc_document = requests.get(args.openid_configuration).json()
token_url = oidc_document['token_endpoint']
print(token_url)
sign_in_url = oidc_document['authorization_endpoint']
global args
if not code and not pomerium_token:
if not args.client_id:
print("client-id is required")
sys.exit(1)
dst = urllib.parse.urlparse(args.dst)
try:
cred = PomeriumSession.from_json_file(args.cred)
except:
print("=> no credential found, let's login")
args.login = True
sign_in_url = "{}?response_type=code&scope=openid%20email&access_type=offline&redirect_uri=urn:ietf:wg:oauth:2.0:oob&client_id={}".format(
oidc_document['authorization_endpoint'], args.client_id)
print("Access code not set, so we'll do the process interactively!")
print("Go to the url : {}".format(sign_in_url))
code = input("Complete the login and enter your code:")
print(code)
# initial login to make sure we have our credential
if args.login:
dst = urllib.parse.urlparse(args.dst)
query_params = {"redirect_uri": "http://{}:{}".format(args.server, args.port)}
enc_query_params = urllib.parse.urlencode(query_params)
dst_login = "{}://{}{}?{}".format(
dst.scheme, dst.hostname, "/.pomerium/api/v1/login", enc_query_params,
)
response = requests.get(dst_login)
print("=> Your browser has been opened to visit:\n{}".format(response.text))
webbrowser.open(response.text)
if not pomerium_token:
req = requests.post(
token_url, {
'client_id': args.client_id,
'client_secret': args.client_secret,
'code': code,
'redirect_uri': 'urn:ietf:wg:oauth:2.0:oob',
'grant_type': 'authorization_code'
})
with http.server.HTTPServer((args.server, args.port), Callback) as httpd:
while not done:
httpd.handle_request()
refresh_token = req.json()['refresh_token']
print("refresh token: {}".format(refresh_token))
print("create a new id_token with our pomerium app as the audience")
req = requests.post(
token_url, {
'refresh_token': refresh_token,
'client_id': args.client_id,
'client_secret': args.client_secret,
'audience': args.pomerium_client_id,
'grant_type': 'refresh_token'
})
id_token = req.json()['id_token']
print("pomerium id_token: {}".format(id_token))
print("exchange our identity providers id token for a pomerium bearer token")
req = requests.post(args.pomerium_token_url, {'id_token': id_token})
pomerium_token = req.json()['Token']
print("pomerium bearer token is: {}".format(pomerium_token))
req = requests.get(args.pomerium_url, headers={'Authorization': 'Bearer ' + pomerium_token})
json_formatted = json.dumps(req.json(), indent=1)
print(json_formatted)
cred = PomeriumSession.from_json_file(args.cred)
response = requests.get(
args.dst,
headers={
"Authorization": "Pomerium {}".format(cred.jwt),
"Content-type": "application/json",
"Accept": "application/json",
},
)
print(
"==> request\n{}\n==> response.status_code\n{}\n==>response.text\n{}\n".format(
args.dst, response.status_code, response.text
)
)
# if response.status_code == 200:
if response.status_code == 401:
# user our refresh token to get a new cred
print("==> got a 401, let's try to refresh that credential")
response = requests.get(
args.refresh_endpoint,
headers={
"Authorization": "Pomerium {}".format(cred.refresh_token),
"Content-type": "application/json",
"Accept": "application/json",
},
)
print(
"==>request\n{}\n ==> response.status_code\n{}\nresponse.text==>\n{}\n".format(
args.refresh_endpoint, response.status_code, response.text
)
)
# update our cred!
with open(args.cred, "w", encoding="utf-8") as f:
f.write(response.text)
print("=> pomerium json credential saved to:\n{}".format(f.name))
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -1,53 +0,0 @@
#!/bin/bash
# Create a new OAUTH2 provider DISTINCT from your pomerium configuration
# Select type as "OTHER"
CLIENT_ID='REPLACE-ME.apps.googleusercontent.com'
CLIENT_SECRET='REPLACE-ME'
SIGNIN_URL='https://accounts.google.com/o/oauth2/v2/auth?client_id='$CLIENT_ID'&response_type=code&scope=openid%20email&access_type=offline&redirect_uri=urn:ietf:wg:oauth:2.0:oob'
# This would be your pomerium client id
POMERIUM_CLIENT_ID='REPLACE-ME.apps.googleusercontent.com'
echo "Follow the following URL to get an offline auth code from your IdP"
echo $SIGNIN_URL
read -p 'Enter the authorization code as a result of logging in: ' CODE
echo $CODE
echo "Exchange our authorization code to get a refresh_token"
echo "refresh_tokens can be used to generate indefinite access tokens / id_tokens"
curl \
-d client_id=$CLIENT_ID \
-d client_secret=$CLIENT_SECRET \
-d code=$CODE \
-d redirect_uri=urn:ietf:wg:oauth:2.0:oob \
-d grant_type=authorization_code \
https://www.googleapis.com/oauth2/v4/token
read -p 'Enter the refresh token result:' REFRESH_TOKEN
echo $REFRESH_TOKEN
echo "Use our refresh_token to create a new id_token with an audience of pomerium's oauth client"
curl \
-d client_id=$CLIENT_ID \
-d client_secret=$CLIENT_SECRET \
-d refresh_token=$REFRESH_TOKEN \
-d grant_type=refresh_token \
-d audience=$POMERIUM_CLIENT_ID \
https://www.googleapis.com/oauth2/v4/token
echo "now we have an id_token with an audience that matches that of our pomerium app"
read -p 'Enter the resulting id_token:' ID_TOKEN
echo $ID_TOKEN
curl -X POST \
-d id_token=$ID_TOKEN \
https://authenticate.corp.beyondperimeter.com/api/v1/token
read -p 'Enter the resulting Token:' POMERIUM_ACCESS_TOKEN
echo $POMERIUM_ACCESS_TOKEN
echo "we have our bearer token that can be used with pomerium now"
curl \
-H "Authorization: Bearer ${POMERIUM_ACCESS_TOKEN}" \
"https://httpbin.corp.beyondperimeter.com/"