authenticate: encrypt & mac oauth2 callback state

- cryptutil: add hmac & tests
- cryptutil: rename cipher / encoders to be more clear
- cryptutil: simplify SecureEncoder interface
- cryptutil: renamed NewCipherFromBase64 to NewAEADCipherFromBase64
- cryptutil: move key & random generators to helpers

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2019-09-19 08:56:48 -07:00
parent 3a806c6dfc
commit 7c755d833f
No known key found for this signature in database
GPG key ID: AEE4CF12FE86D07E
26 changed files with 539 additions and 464 deletions

View file

@ -1,6 +1,7 @@
package authenticate // import "github.com/pomerium/pomerium/authenticate" package authenticate // import "github.com/pomerium/pomerium/authenticate"
import ( import (
"crypto/cipher"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
@ -20,10 +21,10 @@ const callbackPath = "/oauth2/callback"
// ValidateOptions checks that configuration are complete and valid. // ValidateOptions checks that configuration are complete and valid.
// Returns on first error found. // Returns on first error found.
func ValidateOptions(o config.Options) error { func ValidateOptions(o config.Options) error {
if _, err := cryptutil.NewCipherFromBase64(o.SharedKey); err != nil { if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil {
return fmt.Errorf("authenticate: 'SHARED_SECRET' invalid: %v", err) return fmt.Errorf("authenticate: 'SHARED_SECRET' invalid: %v", err)
} }
if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil { if _, err := cryptutil.NewAEADCipherFromBase64(o.CookieSecret); err != nil {
return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %v", err) return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %v", err)
} }
if err := urlutil.ValidateURL(o.AuthenticateURL); err != nil { if err := urlutil.ValidateURL(o.AuthenticateURL); err != nil {
@ -48,7 +49,8 @@ type Authenticate struct {
cookieSecret []byte cookieSecret []byte
templates *template.Template templates *template.Template
sessionStore sessions.SessionStore sessionStore sessions.SessionStore
cipher cryptutil.Cipher cipher cipher.AEAD
encoder cryptutil.SecureEncoder
provider identity.Authenticator provider identity.Authenticator
} }
@ -58,7 +60,8 @@ func New(opts config.Options) (*Authenticate, error) {
return nil, err return nil, err
} }
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret) decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
cipher, err := cryptutil.NewCipher(decodedCookieSecret) cipher, err := cryptutil.NewAEADCipher(decodedCookieSecret)
encoder := cryptutil.NewSecureJSONEncoder(cipher)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -72,7 +75,7 @@ func New(opts config.Options) (*Authenticate, error) {
CookieSecure: opts.CookieSecure, CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly, CookieHTTPOnly: opts.CookieHTTPOnly,
CookieExpire: opts.CookieExpire, CookieExpire: opts.CookieExpire,
CookieCipher: cipher, Encoder: encoder,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -100,6 +103,7 @@ func New(opts config.Options) (*Authenticate, error) {
templates: templates.New(), templates: templates.New(),
sessionStore: cookieStore, sessionStore: cookieStore,
cipher: cipher, cipher: cipher,
encoder: encoder,
provider: provider, provider: provider,
cookieSecret: decodedCookieSecret, cookieSecret: decodedCookieSecret,
cookieName: opts.CookieName, cookieName: opts.CookieName,

View file

@ -12,6 +12,7 @@ import (
"github.com/pomerium/csrf" "github.com/pomerium/csrf"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
@ -38,8 +39,8 @@ func (a *Authenticate) Handler() http.Handler {
a.cookieSecret, a.cookieSecret,
csrf.Path("/"), csrf.Path("/"),
csrf.Domain(a.cookieDomain), csrf.Domain(a.cookieDomain),
csrf.UnsafePaths([]string{"/oauth2/callback"}), // enforce CSRF on "safe" handler csrf.UnsafePaths([]string{callbackPath}), // enforce CSRF on "safe" handler
csrf.FormValueName("state"), // rfc6749 section-10.12 csrf.FormValueName("state"), // rfc6749 section-10.12
csrf.CookieName(fmt.Sprintf("%s_csrf", a.cookieName)), csrf.CookieName(fmt.Sprintf("%s_csrf", a.cookieName)),
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)), csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
)) ))
@ -137,15 +138,21 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
} }
// redirectToIdentityProvider starts the authenticate process by redirecting the // redirectToIdentityProvider starts the authenticate process by redirecting the
// user to their respective identity provider. // user to their respective identity provider. This function also builds the
// // 'state' parameter which is encrypted and includes authenticating data
// for validation.
// 'state' is : nonce|timestamp|redirect_url|encrypt(redirect_url)+mac(nonce,ts))
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest // https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest
// https://tools.ietf.org/html/rfc6749#section-4.2.1 // https://tools.ietf.org/html/rfc6749#section-4.2.1
func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http.Request) {
redirectURL := a.RedirectURL.ResolveReference(r.URL) redirectURL := a.RedirectURL.ResolveReference(r.URL)
nonce := csrf.Token(r) nonce := csrf.Token(r)
state := fmt.Sprintf("%v:%v", nonce, redirectURL.String()) now := time.Now().Unix()
encodedState := base64.URLEncoding.EncodeToString([]byte(state)) b := []byte(fmt.Sprintf("%s|%d|", nonce, now))
enc := cryptutil.Encrypt(a.cipher, []byte(redirectURL.String()), b)
b = append(b, enc...)
encodedState := base64.URLEncoding.EncodeToString(b)
http.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound) http.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound)
} }
@ -187,19 +194,33 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
if err != nil { if err != nil {
return nil, httputil.Error("malformed state", http.StatusBadRequest, err) return nil, httputil.Error("malformed state", http.StatusBadRequest, err)
} }
// split state into its it's components (nonce:redirect_uri)
statePayload := strings.SplitN(string(bytes), ":", 2) // split state into its it's components, e.g.
if len(statePayload) != 2 { // (nonce|timestamp|redirect_url|encrypted_data(redirect_url)+mac(nonce,ts))
return nil, fmt.Errorf("state malformed, size: %d", len(statePayload)) statePayload := strings.SplitN(string(bytes), "|", 3)
} if len(statePayload) != 3 {
// parse redirect_uri; ignore csrf nonce (validity asserted by middleware) return nil, httputil.Error("'state' is malformed", http.StatusBadRequest,
redirectURL, err := urlutil.ParseAndValidateURL(statePayload[1]) fmt.Errorf("state malformed, size: %d", len(statePayload)))
if err != nil {
return nil, httputil.Error("invalid redirect uri", http.StatusBadRequest, err)
} }
// todo(bdd): if we want to be _extra_ sure, we can validate that the // verify that the returned timestamp is valid (replay attack)
// redirectURL hmac is valid. But the nonce should cover the integrity... if err := cryptutil.ValidTimestamp(statePayload[1]); err != nil {
return nil, httputil.Error(err.Error(), http.StatusBadRequest, err)
}
// Use our AEAD construct to enforce secrecy and authenticity,:
// mac: to validate the nonce again, and above timestamp
// decrypt: to prevent leaking 'redirect_uri' to IdP or logs)
b := []byte(fmt.Sprint(statePayload[0], "|", statePayload[1], "|"))
redirectString, err := cryptutil.Decrypt(a.cipher, []byte(statePayload[2]), b)
if err != nil {
return nil, httputil.Error("'state' has invalid hmac", http.StatusBadRequest, err)
}
redirectURL, err := urlutil.ParseAndValidateURL(string(redirectString))
if err != nil {
return nil, httputil.Error("'state' has invalid redirect uri", http.StatusBadRequest, err)
}
// OK. Looks good so let's persist our user session // OK. Looks good so let's persist our user session
if err := a.sessionStore.SaveSession(w, r, session); err != nil { if err := a.sessionStore.SaveSession(w, r, session); err != nil {
@ -222,7 +243,7 @@ func (a *Authenticate) ExchangeToken(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, err) httputil.ErrorResponse(w, r, err)
return return
} }
encToken, err := sessions.MarshalSession(session, a.cipher) encToken, err := sessions.MarshalSession(session, a.encoder)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return return

View file

@ -15,6 +15,7 @@ import (
"github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/templates" "github.com/pomerium/pomerium/internal/templates"
"golang.org/x/crypto/chacha20poly1305"
) )
func testAuthenticate() *Authenticate { func testAuthenticate() *Authenticate {
@ -73,22 +74,22 @@ func TestAuthenticate_SignIn(t *testing.T) {
session sessions.SessionStore session sessions.SessionStore
restStore sessions.SessionStore restStore sessions.SessionStore
provider identity.MockProvider provider identity.MockProvider
cipher cryptutil.Cipher encoder cryptutil.SecureEncoder
wantCode int wantCode int
}{ }{
{"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound}, {"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusFound},
{"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockCipher{}, http.StatusFound}, {"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockEncoder{}, http.StatusFound},
{"session expired good refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound}, {"session expired good refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockEncoder{}, http.StatusFound},
{"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusFound}, // mocking hmac is meh {"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockEncoder{}, http.StatusFound}, // mocking hmac is meh
{"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound}, {"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockEncoder{}, http.StatusFound},
// {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError}, // {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusInternalServerError},
{"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound}, {"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusFound},
{"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound}, {"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusFound},
{"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, {"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusBadRequest},
// actually caught by go's handler, but we should keep the test. // actually caught by go's handler, but we should keep the test.
{"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest}, {"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusBadRequest},
{"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusFound}, {"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{MarshalError: errors.New("error")}, http.StatusFound},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -97,7 +98,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
provider: tt.provider, provider: tt.provider,
RedirectURL: uriParseHelper("https://some.example"), RedirectURL: uriParseHelper("https://some.example"),
SharedKey: "secret", SharedKey: "secret",
cipher: tt.cipher, encoder: tt.encoder,
} }
uri := &url.URL{Host: "corp.some.example", Scheme: "https", Path: "/"} uri := &url.URL{Host: "corp.some.example", Scheme: "https", Path: "/"}
uri.RawQuery = fmt.Sprintf("%s&redirect_uri=%s", tt.state, tt.redirectURI) uri.RawQuery = fmt.Sprintf("%s&redirect_uri=%s", tt.state, tt.redirectURI)
@ -178,14 +179,19 @@ func TestAuthenticate_SignOut(t *testing.T) {
func TestAuthenticate_OAuthCallback(t *testing.T) { func TestAuthenticate_OAuthCallback(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
name string name string
method string method string
// url params ts int64
paramErr string stateOvveride string
code string extraMac string
state string extraState string
paramErr string
code string
redirectURI string
authenticateURL string authenticateURL string
session sessions.SessionStore session sessions.SessionStore
provider identity.MockProvider provider identity.MockProvider
@ -193,30 +199,52 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
want string want string
wantCode int wantCode int
}{ }{
{"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusFound}, {"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusFound},
{"failed authenticate", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError}, {"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
{"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusInternalServerError}, {"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusInternalServerError},
{"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest}, {"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
{"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest}, {"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
{"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest}, {"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
{"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad base64 state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")) + "%", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"too many state delimeters", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io:wait")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"too few state delimeters", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusInternalServerError}, {"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
authURL, _ := url.Parse(tt.authenticateURL) authURL, _ := url.Parse(tt.authenticateURL)
a := &Authenticate{ a := &Authenticate{
RedirectURL: authURL, RedirectURL: authURL,
sessionStore: tt.session, sessionStore: tt.session,
provider: tt.provider, provider: tt.provider,
cipher: aead,
} }
u, _ := url.Parse("/oauthGet") u, _ := url.Parse("/oauthGet")
params, _ := url.ParseQuery(u.RawQuery) params, _ := url.ParseQuery(u.RawQuery)
params.Add("error", tt.paramErr) params.Add("error", tt.paramErr)
params.Add("code", tt.code) params.Add("code", tt.code)
params.Add("state", tt.state) nonce := cryptutil.NewBase64Key() // mock csrf
// (nonce|timestamp|redirect_url|encrypt(redirect_url),mac(nonce,ts))
b := []byte(fmt.Sprintf("%s|%d|%s", nonce, tt.ts, tt.extraMac))
enc := cryptutil.Encrypt(a.cipher, []byte(tt.redirectURI), b)
b = append(b, enc...)
encodedState := base64.URLEncoding.EncodeToString(b)
if tt.extraState != "" {
encodedState += tt.extraState
}
if tt.stateOvveride != "" {
encodedState = tt.stateOvveride
}
params.Add("state", encodedState)
u.RawQuery = params.Encode() u.RawQuery = params.Encode()
@ -240,22 +268,27 @@ func TestAuthenticate_ExchangeToken(t *testing.T) {
method string method string
idToken string idToken string
restStore sessions.SessionStore restStore sessions.SessionStore
cipher cryptutil.Cipher encoder cryptutil.SecureEncoder
provider identity.MockProvider provider identity.MockProvider
want string want string
}{ }{
{"good", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""}, {"good", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""},
{"could not exchange identity for session", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionError: errors.New("error")}, ""}, {"could not exchange identity for session", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionError: errors.New("error")}, ""},
{"missing token", http.MethodPost, "", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "missing id token"}, {"missing token", http.MethodPost, "", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "missing id token"},
{"malformed form", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""}, {"malformed form", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""},
{"can't marshal token", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{MarshalError: errors.New("can't marshal token")}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "can't marshal token"}, {"can't marshal token", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{MarshalError: errors.New("can't marshal token")}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "can't marshal token"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
a := &Authenticate{ a := &Authenticate{
cipher: tt.cipher, encoder: tt.encoder,
provider: tt.provider, provider: tt.provider,
sessionStore: tt.restStore, sessionStore: tt.restStore,
cipher: aead,
} }
form := url.Values{} form := url.Values{}
if tt.idToken != "" { if tt.idToken != "" {
@ -282,6 +315,7 @@ func TestAuthenticate_ExchangeToken(t *testing.T) {
} }
func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Content-Type-Options", "nosniff")
@ -305,13 +339,17 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
a := Authenticate{ a := Authenticate{
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="), cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"), RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
sessionStore: tt.session, sessionStore: tt.session,
provider: tt.provider, provider: tt.provider,
cipher: aead,
} }
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
state, _ := tt.session.LoadSession(r) state, _ := tt.session.LoadSession(r)

View file

@ -10,6 +10,7 @@
### Security ### Security
- The user's original intended location before completing the authentication process is now encrypted and kept confidential from the identity provider. [GH-316](https://github.com/pomerium/pomerium/pull/316)
- Under certain circumstances, where debug logging was enabled, pomerium's shared secret could be leaked to http access logs as a query param. - Under certain circumstances, where debug logging was enabled, pomerium's shared secret could be leaked to http access logs as a query param.
### Fixed ### Fixed

2
go.mod
View file

@ -14,7 +14,7 @@ require (
github.com/magiconair/properties v1.8.1 // indirect github.com/magiconair/properties v1.8.1 // indirect
github.com/mitchellh/hashstructure v1.0.0 github.com/mitchellh/hashstructure v1.0.0
github.com/pelletier/go-toml v1.4.0 // indirect github.com/pelletier/go-toml v1.4.0 // indirect
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30 github.com/pomerium/csrf v1.6.2-0.20190918035251-f3318380bad3
github.com/pomerium/go-oidc v2.0.0+incompatible github.com/pomerium/go-oidc v2.0.0+incompatible
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/prometheus/client_golang v0.9.3 github.com/prometheus/client_golang v0.9.3

3
go.sum
View file

@ -126,6 +126,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30 h1:jggCv6hZvcxjGa3gqkYY2EUuOkITI9Znugz/f3QJfRQ= github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30 h1:jggCv6hZvcxjGa3gqkYY2EUuOkITI9Znugz/f3QJfRQ=
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0= github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0=
github.com/pomerium/csrf v1.6.2-0.20190918035251-f3318380bad3 h1:FmzFXnCAepHZwl6QPhTFqBHcbcGevdiEQjutK+M5bj4=
github.com/pomerium/csrf v1.6.2-0.20190918035251-f3318380bad3/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0=
github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o= github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o=
github.com/pomerium/go-oidc v2.0.0+incompatible/go.mod h1:DRsGVw6MOgxbfq4Y57jKOE8lbEfayxeiY0A8/4vxjBM= github.com/pomerium/go-oidc v2.0.0+incompatible/go.mod h1:DRsGVw6MOgxbfq4Y57jKOE8lbEfayxeiY0A8/4vxjBM=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU= github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
@ -196,6 +198,7 @@ golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmV
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8 h1:1wopBVtVdWnn03fZelqdXTqk7U7zPQCb+T4rbU9ZEoU= golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8 h1:1wopBVtVdWnn03fZelqdXTqk7U7zPQCb+T4rbU9ZEoU=
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 h1:0hQKqeLdqlt5iIwVOBErRisrHJAN57yOiPRQItI20fU=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=

View file

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"crypto/cipher" "crypto/cipher"
"crypto/rand"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -13,122 +12,47 @@ import (
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
) )
// DefaultKeySize is the default key size in bytes. // NewAEADCipher takes secret key and returns a new XChacha20poly1305 cipher.
const DefaultKeySize = 32 func NewAEADCipher(secret []byte) (cipher.AEAD, error) {
if len(secret) != 32 {
// NewKey generates a random 32-byte key. return nil, fmt.Errorf("cryptutil: got %d bytes but want 32", len(secret))
//
// Panics if source of randomness fails.
func NewKey() []byte {
return randomBytes(DefaultKeySize)
}
// NewBase64Key generates a random base64 encoded 32-byte key.
//
// Panics if source of randomness fails.
func NewBase64Key() string {
return NewRandomStringN(DefaultKeySize)
}
// NewRandomStringN returns base64 encoded random string of a given num of bytes.
//
// Panics if source of randomness fails.
func NewRandomStringN(c int) string {
return base64.StdEncoding.EncodeToString(randomBytes(c))
}
func randomBytes(c int) []byte {
if c < 0 {
c = DefaultKeySize
} }
b := make([]byte, c) return chacha20poly1305.NewX(secret)
if _, err := rand.Read(b); err != nil {
panic(err)
}
return b
} }
// Cipher provides methods to encrypt and decrypt values. // NewAEADCipherFromBase64 takes a base64 encoded secret key and returns a new XChacha20poly1305 cipher.
type Cipher interface { func NewAEADCipherFromBase64(s string) (cipher.AEAD, error) {
Encrypt([]byte) ([]byte, error)
Decrypt([]byte) ([]byte, error)
Marshal(interface{}) (string, error)
Unmarshal(string, interface{}) error
}
// XChaCha20Cipher provides methods to encrypt and decrypt values.
// Using an AEAD is a cipher providing authenticated encryption with associated data.
// For a description of the methodology, see https://en.wikipedia.org/wiki/Authenticated_encryption
type XChaCha20Cipher struct {
aead cipher.AEAD
}
// NewCipher takes secret key and returns a new XChacha20poly1305 cipher.
func NewCipher(secret []byte) (*XChaCha20Cipher, error) {
aead, err := chacha20poly1305.NewX(secret)
if err != nil {
return nil, err
}
return &XChaCha20Cipher{
aead: aead,
}, nil
}
// NewCipherFromBase64 takes a base64 encoded secret key and returns a new XChacha20poly1305 cipher.
func NewCipherFromBase64(s string) (*XChaCha20Cipher, error) {
decoded, err := base64.StdEncoding.DecodeString(s) decoded, err := base64.StdEncoding.DecodeString(s)
if err != nil { if err != nil {
return nil, fmt.Errorf("cryptutil: invalid base64: %v", err) return nil, fmt.Errorf("cryptutil: invalid base64: %v", err)
} }
if len(decoded) != 32 { return NewAEADCipher(decoded)
return nil, fmt.Errorf("cryptutil: got %d bytes but want 32", len(decoded))
}
return NewCipher(decoded)
} }
// GenerateNonce generates a random nonce. // SecureEncoder provides and interface for to encrypt and decrypting structures .
// Panics if source of randomness fails. type SecureEncoder interface {
func (c *XChaCha20Cipher) GenerateNonce() []byte { Marshal(interface{}) (string, error)
return randomBytes(c.aead.NonceSize()) Unmarshal(string, interface{}) error
} }
// Encrypt a value using XChaCha20-Poly1305 // SecureJSONEncoder implements SecureEncoder for JSON using an AEAD cipher.
func (c *XChaCha20Cipher) Encrypt(plaintext []byte) (joined []byte, err error) { //
defer func() { // See https://en.wikipedia.org/wiki/Authenticated_encryption
if r := recover(); r != nil { type SecureJSONEncoder struct {
err = fmt.Errorf("cryptutil: error encrypting bytes: %v", r) aead cipher.AEAD
}
}()
nonce := c.GenerateNonce()
ciphertext := c.aead.Seal(nil, nonce, plaintext, nil)
// we return the nonce as part of the returned value
joined = append(ciphertext, nonce...)
return joined, nil
} }
// Decrypt a value using XChaCha20-Poly1305 // NewSecureJSONEncoder takes a base64 encoded secret key and returns a new XChacha20poly1305 cipher.
func (c *XChaCha20Cipher) Decrypt(joined []byte) ([]byte, error) { func NewSecureJSONEncoder(aead cipher.AEAD) SecureEncoder {
if len(joined) <= c.aead.NonceSize() { return &SecureJSONEncoder{aead: aead}
return nil, fmt.Errorf("cryptutil: invalid input size: %d", len(joined))
}
// grab out the nonce
pivot := len(joined) - c.aead.NonceSize()
ciphertext := joined[:pivot]
nonce := joined[pivot:]
plaintext, err := c.aead.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
return plaintext, nil
} }
// Marshal marshals the interface state as JSON, encrypts the JSON using the cipher // Marshal marshals the interface state as JSON, encrypts the JSON using the cipher
// and base64 encodes the binary value as a string and returns the result // and base64 encodes the binary value as a string and returns the result
func (c *XChaCha20Cipher) Marshal(s interface{}) (string, error) { //
// can panic if source of random entropy is exhausted generating a nonce.
func (c *SecureJSONEncoder) Marshal(s interface{}) (string, error) {
// encode json value // encode json value
plaintext, err := json.Marshal(s) plaintext, err := json.Marshal(s)
if err != nil { if err != nil {
@ -140,10 +64,8 @@ func (c *XChaCha20Cipher) Marshal(s interface{}) (string, error) {
return "", err return "", err
} }
// encrypt the compressed JSON bytes // encrypt the compressed JSON bytes
ciphertext, err := c.Encrypt(compressed) ciphertext := Encrypt(c.aead, compressed, nil)
if err != nil {
return "", err
}
// base64-encode the result // base64-encode the result
encoded := base64.RawURLEncoding.EncodeToString(ciphertext) encoded := base64.RawURLEncoding.EncodeToString(ciphertext)
return encoded, nil return encoded, nil
@ -151,14 +73,14 @@ func (c *XChaCha20Cipher) Marshal(s interface{}) (string, error) {
// Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the // Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the
// byte slice the passed cipher, and unmarshals the resulting JSON into the struct pointer passed // byte slice the passed cipher, and unmarshals the resulting JSON into the struct pointer passed
func (c *XChaCha20Cipher) Unmarshal(value string, s interface{}) error { func (c *SecureJSONEncoder) Unmarshal(value string, s interface{}) error {
// convert base64 string value to bytes // convert base64 string value to bytes
ciphertext, err := base64.RawURLEncoding.DecodeString(value) ciphertext, err := base64.RawURLEncoding.DecodeString(value)
if err != nil { if err != nil {
return err return err
} }
// decrypt the bytes // decrypt the bytes
compressed, err := c.Decrypt(ciphertext) compressed, err := Decrypt(c.aead, ciphertext, nil)
if err != nil { if err != nil {
return err return err
} }
@ -172,10 +94,10 @@ func (c *XChaCha20Cipher) Unmarshal(value string, s interface{}) error {
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
// compress gzips a set of bytes
func compress(data []byte) ([]byte, error) { func compress(data []byte) ([]byte, error) {
var buf bytes.Buffer var buf bytes.Buffer
writer, err := gzip.NewWriterLevel(&buf, gzip.DefaultCompression) writer, err := gzip.NewWriterLevel(&buf, gzip.DefaultCompression)
@ -194,6 +116,7 @@ func compress(data []byte) ([]byte, error) {
return buf.Bytes(), nil return buf.Bytes(), nil
} }
// decompress un-gzips a set of bytes
func decompress(data []byte) ([]byte, error) { func decompress(data []byte) ([]byte, error) {
reader, err := gzip.NewReader(bytes.NewReader(data)) reader, err := gzip.NewReader(bytes.NewReader(data))
if err != nil { if err != nil {
@ -206,3 +129,27 @@ func decompress(data []byte) ([]byte, error) {
} }
return buf.Bytes(), nil return buf.Bytes(), nil
} }
// Encrypt encrypts a value with optional associated data
//
// Panics if source of randomness fails.
func Encrypt(a cipher.AEAD, data, ad []byte) []byte {
iv := randomBytes(a.NonceSize())
ciphertext := a.Seal(nil, iv, data, ad)
return append(ciphertext, iv...)
}
// Decrypt a value with optional associated data
func Decrypt(a cipher.AEAD, data, ad []byte) ([]byte, error) {
if len(data) <= a.NonceSize() {
return nil, fmt.Errorf("cryptutil: invalid input size: %d", len(data))
}
size := len(data) - a.NonceSize()
ciphertext := data[:size]
nonce := data[size:]
plaintext, err := a.Open(nil, nonce, ciphertext, ad)
if err != nil {
return nil, err
}
return plaintext, nil
}

View file

@ -1,12 +1,8 @@
package cryptutil package cryptutil
import ( import (
"crypto/rand"
"crypto/sha1"
"encoding/base64" "encoding/base64"
"fmt"
"reflect" "reflect"
"sync"
"testing" "testing"
) )
@ -14,27 +10,24 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
plaintext := []byte("my plain text value") plaintext := []byte("my plain text value")
key := NewKey() key := NewKey()
c, err := NewCipher(key) c, err := NewAEADCipher(key)
if err != nil { if err != nil {
t.Fatalf("unexpected err: %v", err) t.Fatalf("unexpected err: %v", err)
} }
ciphertext, err := c.Encrypt(plaintext) ciphertext := Encrypt(c, plaintext, nil)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if reflect.DeepEqual(plaintext, ciphertext) { if reflect.DeepEqual(plaintext, ciphertext) {
t.Fatalf("plaintext is not encrypted plaintext:%v ciphertext:%x", plaintext, ciphertext) t.Fatalf("plaintext is not encrypted plaintext:%v ciphertext:%x", plaintext, ciphertext)
} }
got, err := c.Decrypt(ciphertext) got, err := Decrypt(c, ciphertext, nil)
if err != nil { if err != nil {
t.Fatalf("unexpected err decrypting: %v", err) t.Fatalf("unexpected err decrypting: %v", err)
} }
// if less than 32 bytes, fail // if less than 32 bytes, fail
_, err = c.Decrypt([]byte("oh")) _, err = Decrypt(c, []byte("oh"), nil)
if err == nil { if err == nil {
t.Fatalf("should fail if <32 bytes output: %v", err) t.Fatalf("should fail if <32 bytes output: %v", err)
} }
@ -49,10 +42,11 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
func TestMarshalAndUnmarshalStruct(t *testing.T) { func TestMarshalAndUnmarshalStruct(t *testing.T) {
key := NewKey() key := NewKey()
c, err := NewCipher(key) a, err := NewAEADCipher(key)
if err != nil { if err != nil {
t.Fatalf("unexpected err: %v", err) t.Fatalf("unexpected err: %v", err)
} }
c := SecureJSONEncoder{aead: a}
type TC struct { type TC struct {
Field string `json:"field"` Field string `json:"field"`
@ -101,102 +95,7 @@ func TestMarshalAndUnmarshalStruct(t *testing.T) {
} }
} }
func TestCipherDataRace(t *testing.T) { func TestSecureJSONEncoder_Marshal(t *testing.T) {
cipher, err := NewCipher(NewKey())
if err != nil {
t.Fatalf("unexpected generating cipher err: %v", err)
}
type TC struct {
Field string `json:"field"`
}
wg := &sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
go func(c *XChaCha20Cipher, wg *sync.WaitGroup) {
defer wg.Done()
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
t.Fatalf("unexecpted error reading random bytes: %v", err)
}
sha := fmt.Sprintf("%x", sha1.New().Sum(b))
tc := &TC{
Field: sha,
}
value1, err := c.Marshal(tc)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
value2, err := c.Marshal(tc)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if value1 == value2 {
t.Fatalf("expected marshaled values to not be equal %v != %v", value1, value2)
}
got1 := &TC{}
err = c.Unmarshal(value1, got1)
if err != nil {
t.Fatalf("unexpected err unmarshalling struct: %v", err)
}
if !reflect.DeepEqual(got1, tc) {
t.Logf("want: %#v", tc)
t.Logf(" got: %#v", got1)
t.Fatalf("expected structs to be equal")
}
got2 := &TC{}
err = c.Unmarshal(value2, got2)
if err != nil {
t.Fatalf("unexpected err unmarshalling struct: %v", err)
}
if !reflect.DeepEqual(got1, got2) {
t.Logf("got2: %#v", got2)
t.Logf("got1: %#v", got1)
t.Fatalf("expected structs to be equal")
}
}(cipher, wg)
}
wg.Wait()
}
func TestGenerateRandomString(t *testing.T) {
t.Parallel()
tests := []struct {
name string
c int
want int
}{
{"simple", 32, 32},
{"zero", 0, 0},
{"negative", -1, 32},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := NewRandomStringN(tt.c)
b, err := base64.StdEncoding.DecodeString(o)
if err != nil {
t.Error(err)
}
got := len(b)
if got != tt.want {
t.Errorf("NewRandomStringN() = %d, want %d", got, tt.want)
}
})
}
}
func TestXChaCha20Cipher_Marshal(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
name string name string
@ -225,20 +124,22 @@ func TestXChaCha20Cipher_Marshal(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c, err := NewCipher(NewKey()) c, err := NewAEADCipher(NewKey())
if err != nil { if err != nil {
t.Fatalf("unexpected err: %v", err) t.Fatalf("unexpected err: %v", err)
} }
_, err = c.Marshal(tt.s) e := SecureJSONEncoder{aead: c}
_, err = e.Marshal(tt.s)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("XChaCha20Cipher.Marshal() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("SecureJSONEncoder.Marshal() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
}) })
} }
} }
func TestNewCipher(t *testing.T) { func TestNewAEADCipher(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
name string name string
@ -251,16 +152,16 @@ func TestNewCipher(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := NewCipher(tt.secret) _, err := NewAEADCipher(tt.secret)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("NewCipher() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("NewAEADCipher() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
}) })
} }
} }
func TestNewCipherFromBase64(t *testing.T) { func TestNewAEADCipherFromBase64(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
name string name string
@ -274,34 +175,11 @@ func TestNewCipherFromBase64(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := NewCipherFromBase64(tt.s) _, err := NewAEADCipherFromBase64(tt.s)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("NewCipherFromBase64() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("NewAEADCipherFromBase64() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
}) })
} }
} }
func TestNewBase64Key(t *testing.T) {
t.Parallel()
tests := []struct {
name string
want int
}{
{"simple", 32},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := NewBase64Key()
b, err := base64.StdEncoding.DecodeString(o)
if err != nil {
t.Error(err)
}
got := len(b)
if got != tt.want {
t.Errorf("NewBase64Key() = %d, want %d", got, tt.want)
}
})
}
}

View file

@ -0,0 +1,45 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
import (
"crypto/rand"
"encoding/base64"
)
// DefaultKeySize is the default key size in bytes.
const DefaultKeySize = 32
// NewKey generates a random 32-byte (256 bit) key.
//
// Panics if source of randomness fails.
func NewKey() []byte {
return randomBytes(DefaultKeySize)
}
// NewBase64Key generates a random base64 encoded 32-byte key.
//
// Panics if source of randomness fails.
func NewBase64Key() string {
return NewRandomStringN(DefaultKeySize)
}
// NewRandomStringN returns base64 encoded random string of a given num of bytes.
//
// Panics if source of randomness fails.
func NewRandomStringN(c int) string {
return base64.StdEncoding.EncodeToString(randomBytes(c))
}
// randomBytes generates C number of random bytes suitable for cryptographic
// operations.
//
// Panics if source of randomness fails.
func randomBytes(c int) []byte {
if c < 0 {
c = DefaultKeySize
}
b := make([]byte, c)
if _, err := rand.Read(b); err != nil {
panic(err)
}
return b
}

View file

@ -0,0 +1,55 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
import (
"encoding/base64"
"testing"
)
func TestGenerateRandomString(t *testing.T) {
t.Parallel()
tests := []struct {
name string
c int
want int
}{
{"simple", 32, 32},
{"zero", 0, 0},
{"negative", -1, 32},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := NewRandomStringN(tt.c)
b, err := base64.StdEncoding.DecodeString(o)
if err != nil {
t.Error(err)
}
got := len(b)
if got != tt.want {
t.Errorf("NewRandomStringN() = %d, want %d", got, tt.want)
}
})
}
}
func TestNewBase64Key(t *testing.T) {
t.Parallel()
tests := []struct {
name string
want int
}{
{"simple", 32},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := NewBase64Key()
b, err := base64.StdEncoding.DecodeString(o)
if err != nil {
t.Error(err)
}
got := len(b)
if got != tt.want {
t.Errorf("NewBase64Key() = %d, want %d", got, tt.want)
}
})
}
}

View file

@ -0,0 +1,50 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
import (
"crypto/hmac"
"crypto/sha512"
"errors"
"strconv"
"time"
)
var (
errTimestampMalformed = errors.New("internal/cryptutil: timestamp malformed")
errTimestampExpired = errors.New("internal/cryptutil: timestamp expired")
errTimestampTooSoon = errors.New("internal/cryptutil: timestamp too soon")
)
// GenerateHMAC produces a symmetric signature using a shared secret key.
func GenerateHMAC(data []byte, key string) []byte {
h := hmac.New(sha512.New512_256, []byte(key))
h.Write(data)
return h.Sum(nil)
}
// CheckHMAC securely checks the supplied MAC against a message using the
// shared secret key.
func CheckHMAC(data, suppliedMAC []byte, key string) bool {
expectedMAC := GenerateHMAC(data, key)
return hmac.Equal(expectedMAC, suppliedMAC)
}
// ValidTimestamp is a helper function often used in conjunction with an HMAC
// function to verify that the timestamp (in unix seconds) is within leeway
// period.
// todo(bdd) : should leeway be configurable?
func ValidTimestamp(ts string) error {
var timeStamp int64
var err error
if timeStamp, err = strconv.ParseInt(ts, 10, 64); err != nil {
return errTimestampMalformed
}
// unix time in seconds
tm := time.Unix(timeStamp, 0)
if time.Since(tm) > DefaultLeeway {
return errTimestampExpired
}
if time.Until(tm) > DefaultLeeway {
return errTimestampTooSoon
}
return nil
}

View file

@ -0,0 +1,70 @@
package cryptutil
import (
"bytes"
"encoding/hex"
"fmt"
"testing"
"time"
)
func TestHMAC(t *testing.T) {
// https://groups.google.com/d/msg/sci.crypt/OolWgsgQD-8/jHciyWkaL0gJ
hmacTests := []struct {
key string
data string
digest string
}{
{
key: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
data: "4869205468657265", // "Hi There"
digest: "9f9126c3d9c3c330d760425ca8a217e31feae31bfe70196ff81642b868402eab",
},
{
key: "4a656665", // "Jefe"
data: "7768617420646f2079612077616e7420666f72206e6f7468696e673f", // "what do ya want for nothing?"
digest: "6df7b24630d5ccb2ee335407081a87188c221489768fa2020513b2d593359456",
},
}
for idx, tt := range hmacTests {
keySlice, _ := hex.DecodeString(tt.key)
dataBytes, _ := hex.DecodeString(tt.data)
expectedDigest, _ := hex.DecodeString(tt.digest)
keyBytes := &[32]byte{}
copy(keyBytes[:], keySlice)
macDigest := GenerateHMAC(dataBytes, string(keyBytes[:]))
if !bytes.Equal(macDigest, expectedDigest) {
t.Errorf("test %d generated unexpected mac", idx)
}
if !CheckHMAC(dataBytes, macDigest, string(keyBytes[:])) {
t.Errorf("test %d generated unexpected mac", idx)
}
}
}
func TestValidTimestamp(t *testing.T) {
t.Parallel()
tests := []struct {
name string
ts string
wantErr bool
}{
{"good - now", fmt.Sprint(time.Now().Unix()), false},
{"good - now - 200ms", fmt.Sprint(time.Now().Add(-200 * time.Millisecond).Unix()), false},
{"good - now + 200ms", fmt.Sprint(time.Now().Add(200 * time.Millisecond).Unix()), false},
{"bad - now + 10m", fmt.Sprint(time.Now().Add(10 * time.Minute).Unix()), true},
{"bad - now - 10m", fmt.Sprint(time.Now().Add(-10 * time.Minute).Unix()), true},
{"malformed - non int", fmt.Sprint("pomerium"), true},
{"malformed - negative number", fmt.Sprint("-1"), true},
{"malformed - huge number", fmt.Sprintf("%d", 10*10000000000), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := ValidTimestamp(tt.ts); (err != nil) != tt.wantErr {
t.Errorf("ValidTimestamp() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -1,28 +1,18 @@
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil" package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
// MockCipher MockCSRFStore is a mock implementation of Cipher. // MockEncoder MockCSRFStore is a mock implementation of Cipher.
type MockCipher struct { type MockEncoder struct {
EncryptResponse []byte
EncryptError error
DecryptResponse []byte
DecryptError error
MarshalResponse string MarshalResponse string
MarshalError error MarshalError error
UnmarshalError error UnmarshalError error
} }
// Encrypt is a mock implementation of MockCipher. // Marshal is a mock implementation of MockEncoder.
func (mc MockCipher) Encrypt(b []byte) ([]byte, error) { return mc.EncryptResponse, mc.EncryptError } func (mc MockEncoder) Marshal(i interface{}) (string, error) {
// Decrypt is a mock implementation of MockCipher.
func (mc MockCipher) Decrypt(b []byte) ([]byte, error) { return mc.DecryptResponse, mc.DecryptError }
// Marshal is a mock implementation of MockCipher.
func (mc MockCipher) Marshal(i interface{}) (string, error) {
return mc.MarshalResponse, mc.MarshalError return mc.MarshalResponse, mc.MarshalError
} }
// Unmarshal is a mock implementation of MockCipher. // Unmarshal is a mock implementation of MockEncoder.
func (mc MockCipher) Unmarshal(s string, i interface{}) error { func (mc MockEncoder) Unmarshal(s string, i interface{}) error {
return mc.UnmarshalError return mc.UnmarshalError
} }

View file

@ -5,31 +5,13 @@ import (
"testing" "testing"
) )
func TestMockCipher_Unmarshal(t *testing.T) { func TestMockEncoder(t *testing.T) {
e := errors.New("err") e := errors.New("err")
mc := MockCipher{ mc := MockEncoder{
EncryptResponse: []byte("EncryptResponse"),
EncryptError: e,
DecryptResponse: []byte("DecryptResponse"),
DecryptError: e,
MarshalResponse: "MarshalResponse", MarshalResponse: "MarshalResponse",
MarshalError: e, MarshalError: e,
UnmarshalError: e, UnmarshalError: e,
} }
b, err := mc.Encrypt([]byte("test"))
if string(b) != "EncryptResponse" {
t.Error("unexpected encrypt response")
}
if err != e {
t.Error("unexpected encrypt error")
}
b, err = mc.Decrypt([]byte("test"))
if string(b) != "DecryptResponse" {
t.Error("unexpected Decrypt response")
}
if err != e {
t.Error("unexpected Decrypt error")
}
s, err := mc.Marshal("test") s, err := mc.Marshal("test")
if err != e { if err != e {
t.Error("unexpected Marshal error") t.Error("unexpected Marshal error")

View file

@ -9,6 +9,11 @@ import (
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
) )
const (
// DefaultLeeway defines the default leeway for matching NotBefore/Expiry claims.
DefaultLeeway = 5.0 * time.Minute
)
// JWTSigner implements JWT signing according to JSON Web Token (JWT) RFC7519 // JWTSigner implements JWT signing according to JSON Web Token (JWT) RFC7519
// https://tools.ietf.org/html/rfc7519 // https://tools.ietf.org/html/rfc7519
type JWTSigner interface { type JWTSigner interface {
@ -92,8 +97,8 @@ func (s *ES256Signer) SignJWT(user, email, groups string) (string, error) {
s.Groups = groups s.Groups = groups
now := time.Now() now := time.Now()
s.IssuedAt = *jwt.NewNumericDate(now) s.IssuedAt = *jwt.NewNumericDate(now)
s.Expiry = *jwt.NewNumericDate(now.Add(jwt.DefaultLeeway)) s.Expiry = *jwt.NewNumericDate(now.Add(DefaultLeeway))
s.NotBefore = *jwt.NewNumericDate(now.Add(-1 * jwt.DefaultLeeway)) s.NotBefore = *jwt.NewNumericDate(now.Add(-1 * DefaultLeeway))
rawJWT, err := jwt.Signed(s.signer).Claims(s).CompactSerialize() rawJWT, err := jwt.Signed(s.signer).Claims(s).CompactSerialize()
if err != nil { if err != nil {
return "", fmt.Errorf("cryptutil: sign failed %v", err) return "", fmt.Errorf("cryptutil: sign failed %v", err)

View file

@ -1,11 +1,13 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil" package httputil
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/gorilla/mux"
) )
func TestCSRFFailureHandler(t *testing.T) { func TestCSRFFailureHandler(t *testing.T) {
@ -35,3 +37,19 @@ func TestCSRFFailureHandler(t *testing.T) {
}) })
} }
} }
func TestNewRouter(t *testing.T) {
tests := []struct {
name string
want *mux.Router
}{
{"this is a gorilla router right?", mux.NewRouter()},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NewRouter(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewRouter() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,14 +1,11 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware" package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import ( import (
"crypto/hmac"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"strings" "strings"
"time"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
@ -183,22 +180,8 @@ func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool {
if err != nil { if err != nil {
return false return false
} }
i, err := strconv.ParseInt(timestamp, 10, 64) if err := cryptutil.ValidTimestamp(timestamp); err != nil {
if err != nil {
return false return false
} }
tm := time.Unix(i, 0) return cryptutil.CheckHMAC([]byte(fmt.Sprint(redirectURI, timestamp)), requestSig, secret)
ttl := 5 * time.Minute
if time.Since(tm) > ttl {
return false
}
localSig := redirectURLSignature(redirectURI, tm, secret)
return hmac.Equal(requestSig, localSig)
}
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix()))
h := cryptutil.Hash(secret, data)
return h
} }

View file

@ -8,8 +8,15 @@ import (
"net/url" "net/url"
"testing" "testing"
"time" "time"
"github.com/pomerium/pomerium/internal/cryptutil"
) )
func hmacHelperFunc(rawRedirect string, timestamp time.Time, secret string) []byte {
data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix()))
return cryptutil.GenerateHMAC(data, secret)
}
func Test_SameDomain(t *testing.T) { func Test_SameDomain(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
@ -45,7 +52,7 @@ func Test_ValidSignature(t *testing.T) {
goodURL := "https://example.com/redirect" goodURL := "https://example.com/redirect"
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A=" secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
now := fmt.Sprint(time.Now().Unix()) now := fmt.Sprint(time.Now().Unix())
rawSig := redirectURLSignature(goodURL, time.Now(), secretA) rawSig := hmacHelperFunc(goodURL, time.Now(), secretA)
sig := base64.URLEncoding.EncodeToString(rawSig) sig := base64.URLEncoding.EncodeToString(rawSig)
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix()) staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
@ -73,27 +80,6 @@ func Test_ValidSignature(t *testing.T) {
} }
} }
func Test_redirectURLSignature(t *testing.T) {
tests := []struct {
name string
rawRedirect string
timestamp time.Time
secret string
want string
}{
{"good signature", "https://example.com/redirect", time.Unix(1546797901, 0), "K3yqsJPahIzu5CdfCVJlIK4N8Dc135-27Tg1ROuQdhc=", "XeVJC2Iysq7mRUwOL3FX_5vx1d_kZV2HONHNig9fcKk="},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := redirectURLSignature(tt.rawRedirect, tt.timestamp, tt.secret)
out := base64.URLEncoding.EncodeToString(got)
if out != tt.want {
t.Errorf("redirectURLSignature() = %v, want %v", tt.want, out)
}
})
}
}
func TestSetHeaders(t *testing.T) { func TestSetHeaders(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -209,7 +195,7 @@ func TestValidateSignature(t *testing.T) {
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A=" secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
now := fmt.Sprint(time.Now().Unix()) now := fmt.Sprint(time.Now().Unix())
goodURL := "https://example.com/redirect" goodURL := "https://example.com/redirect"
rawSig := redirectURLSignature(goodURL, time.Now(), secretA) rawSig := hmacHelperFunc(goodURL, time.Now(), secretA)
sig := base64.URLEncoding.EncodeToString(rawSig) sig := base64.URLEncoding.EncodeToString(rawSig)
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix()) staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())

View file

@ -33,7 +33,7 @@ const MaxNumChunks = 5
// CookieStore represents all the cookie related configurations // CookieStore represents all the cookie related configurations
type CookieStore struct { type CookieStore struct {
Name string Name string
CookieCipher cryptutil.Cipher Encoder cryptutil.SecureEncoder
CookieExpire time.Duration CookieExpire time.Duration
CookieRefresh time.Duration CookieRefresh time.Duration
CookieSecure bool CookieSecure bool
@ -50,7 +50,7 @@ type CookieStoreOptions struct {
CookieDomain string CookieDomain string
BearerTokenHeader string BearerTokenHeader string
CookieExpire time.Duration CookieExpire time.Duration
CookieCipher cryptutil.Cipher Encoder cryptutil.SecureEncoder
} }
// NewCookieStore returns a new session with ciphers for each of the cookie secrets // NewCookieStore returns a new session with ciphers for each of the cookie secrets
@ -58,7 +58,7 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
if opts.Name == "" { if opts.Name == "" {
return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty") return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
} }
if opts.CookieCipher == nil { if opts.Encoder == nil {
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil") return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
} }
if opts.BearerTokenHeader == "" { if opts.BearerTokenHeader == "" {
@ -71,7 +71,7 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
CookieHTTPOnly: opts.CookieHTTPOnly, CookieHTTPOnly: opts.CookieHTTPOnly,
CookieDomain: opts.CookieDomain, CookieDomain: opts.CookieDomain,
CookieExpire: opts.CookieExpire, CookieExpire: opts.CookieExpire,
CookieCipher: opts.CookieCipher, Encoder: opts.Encoder,
BearerTokenHeader: opts.BearerTokenHeader, BearerTokenHeader: opts.BearerTokenHeader,
}, nil }, nil
} }
@ -188,7 +188,7 @@ func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) {
if cipherText == "" { if cipherText == "" {
return nil, ErrEmptySession return nil, ErrEmptySession
} }
session, err := UnmarshalSession(cipherText, cs.CookieCipher) session, err := UnmarshalSession(cipherText, cs.Encoder)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -197,7 +197,7 @@ func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) {
// SaveSession saves a session state to a request sessions. // SaveSession saves a session state to a request sessions.
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error { func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error {
value, err := MarshalSession(s, cs.CookieCipher) value, err := MarshalSession(s, cs.Encoder)
if err != nil { if err != nil {
return err return err
} }

View file

@ -15,33 +15,21 @@ import (
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
) )
type mockCipher struct{} type MockEncoder struct{}
func (a mockCipher) Encrypt(s []byte) ([]byte, error) { func (a MockEncoder) Marshal(s interface{}) (string, error) { return "", errors.New("error") }
if string(s) == "error" { func (a MockEncoder) Unmarshal(s string, i interface{}) error {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Marshal(s interface{}) (string, error) { return "", errors.New("error") }
func (a mockCipher) Unmarshal(s string, i interface{}) error {
if s == "unmarshal error" || s == "error" { if s == "unmarshal error" || s == "error" {
return errors.New("error") return errors.New("error")
} }
return nil return nil
} }
func TestNewCookieStore(t *testing.T) { func TestNewCookieStore(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.NewKey()) cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
encoder := cryptutil.NewSecureJSONEncoder(cipher)
tests := []struct { tests := []struct {
name string name string
opts *CookieStoreOptions opts *CookieStoreOptions
@ -55,7 +43,7 @@ func TestNewCookieStore(t *testing.T) {
CookieHTTPOnly: true, CookieHTTPOnly: true,
CookieDomain: "pomerium.io", CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second, CookieExpire: 10 * time.Second,
CookieCipher: cipher, Encoder: encoder,
BearerTokenHeader: "Authorization", BearerTokenHeader: "Authorization",
}, },
&CookieStore{ &CookieStore{
@ -64,7 +52,7 @@ func TestNewCookieStore(t *testing.T) {
CookieHTTPOnly: true, CookieHTTPOnly: true,
CookieDomain: "pomerium.io", CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second, CookieExpire: 10 * time.Second,
CookieCipher: cipher, Encoder: encoder,
BearerTokenHeader: "Authorization", BearerTokenHeader: "Authorization",
}, },
false}, false},
@ -75,7 +63,7 @@ func TestNewCookieStore(t *testing.T) {
CookieHTTPOnly: true, CookieHTTPOnly: true,
CookieDomain: "pomerium.io", CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second, CookieExpire: 10 * time.Second,
CookieCipher: cipher, Encoder: encoder,
BearerTokenHeader: "Authorization", BearerTokenHeader: "Authorization",
}, },
nil, nil,
@ -87,7 +75,7 @@ func TestNewCookieStore(t *testing.T) {
CookieHTTPOnly: true, CookieHTTPOnly: true,
CookieDomain: "pomerium.io", CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second, CookieExpire: 10 * time.Second,
CookieCipher: nil, Encoder: nil,
}, },
nil, nil,
true}, true},
@ -100,7 +88,7 @@ func TestNewCookieStore(t *testing.T) {
return return
} }
cmpOpts := []cmp.Option{ cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(cryptutil.XChaCha20Cipher{}), cmpopts.IgnoreUnexported(cryptutil.SecureJSONEncoder{}),
} }
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" { if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
@ -111,7 +99,8 @@ func TestNewCookieStore(t *testing.T) {
} }
func TestCookieStore_makeCookie(t *testing.T) { func TestCookieStore_makeCookie(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.NewKey()) cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -145,7 +134,7 @@ func TestCookieStore_makeCookie(t *testing.T) {
CookieHTTPOnly: true, CookieHTTPOnly: true,
CookieDomain: tt.cookieDomain, CookieDomain: tt.cookieDomain,
CookieExpire: 10 * time.Second, CookieExpire: 10 * time.Second,
CookieCipher: cipher}) Encoder: cryptutil.NewSecureJSONEncoder(cipher)})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -161,10 +150,12 @@ func TestCookieStore_makeCookie(t *testing.T) {
} }
func TestCookieStore_SaveSession(t *testing.T) { func TestCookieStore_SaveSession(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.NewKey()) c, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
cipher := cryptutil.NewSecureJSONEncoder(c)
hugeString := make([]byte, 4097) hugeString := make([]byte, 4097)
if _, err := rand.Read(hugeString); err != nil { if _, err := rand.Read(hugeString); err != nil {
t.Fatal(err) t.Fatal(err)
@ -172,12 +163,12 @@ func TestCookieStore_SaveSession(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
State *State State *State
cipher cryptutil.Cipher cipher cryptutil.SecureEncoder
wantErr bool wantErr bool
wantLoadErr bool wantLoadErr bool
}{ }{
{"good", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false}, {"good", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
{"bad cipher", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, mockCipher{}, true, true}, {"bad cipher", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, MockEncoder{}, true, true},
{"huge cookie", &State{AccessToken: fmt.Sprintf("%x", hugeString), RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false}, {"huge cookie", &State{AccessToken: fmt.Sprintf("%x", hugeString), RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
} }
for _, tt := range tests { for _, tt := range tests {
@ -188,7 +179,7 @@ func TestCookieStore_SaveSession(t *testing.T) {
CookieHTTPOnly: true, CookieHTTPOnly: true,
CookieDomain: "pomerium.io", CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second, CookieExpire: 10 * time.Second,
CookieCipher: tt.cipher} Encoder: tt.cipher}
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()

View file

@ -81,11 +81,13 @@ func TestVerifier(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewCipherFromBase64(cryptutil.NewBase64Key()) cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := cryptutil.NewSecureJSONEncoder(cipher)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
encSession, err := MarshalSession(&tt.state, cipher) encSession, err := MarshalSession(&tt.state, encoder)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -96,8 +98,8 @@ func TestVerifier(t *testing.T) {
} }
cs, err := NewCookieStore(&CookieStoreOptions{ cs, err := NewCookieStore(&CookieStoreOptions{
Name: "_pomerium", Name: "_pomerium",
CookieCipher: cipher, Encoder: encoder,
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View file

@ -89,7 +89,7 @@ func (s *State) IssuedAt() (time.Time, error) {
// MarshalSession marshals the session state as JSON, encrypts the JSON using the // MarshalSession marshals the session state as JSON, encrypts the JSON using the
// given cipher, and base64-encodes the result // given cipher, and base64-encodes the result
func MarshalSession(s *State, c cryptutil.Cipher) (string, error) { func MarshalSession(s *State, c cryptutil.SecureEncoder) (string, error) {
v, err := c.Marshal(s) v, err := c.Marshal(s)
if err != nil { if err != nil {
return "", err return "", err
@ -99,7 +99,7 @@ func MarshalSession(s *State, c cryptutil.Cipher) (string, error) {
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the // UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
// byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct // byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct
func UnmarshalSession(value string, c cryptutil.Cipher) (*State, error) { func UnmarshalSession(value string, c cryptutil.SecureEncoder) (*State, error) {
s := &State{} s := &State{}
err := c.Unmarshal(value, s) err := c.Unmarshal(value, s)
if err != nil { if err != nil {

View file

@ -13,7 +13,8 @@ import (
func TestStateSerialization(t *testing.T) { func TestStateSerialization(t *testing.T) {
secret := cryptutil.NewKey() secret := cryptutil.NewKey()
c, err := cryptutil.NewCipher(secret) cipher, err := cryptutil.NewAEADCipher(secret)
c := cryptutil.NewSecureJSONEncoder(cipher)
if err != nil { if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err) t.Fatalf("expected to be able to create cipher: %v", err)
} }
@ -124,10 +125,12 @@ func TestState_Impersonating(t *testing.T) {
func TestMarshalSession(t *testing.T) { func TestMarshalSession(t *testing.T) {
secret := cryptutil.NewKey() secret := cryptutil.NewKey()
c, err := cryptutil.NewCipher(secret) cipher, err := cryptutil.NewAEADCipher(secret)
if err != nil { if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err) t.Fatalf("expected to be able to create cipher: %v", err)
} }
c := cryptutil.NewSecureJSONEncoder(cipher)
hugeString := make([]byte, 4097) hugeString := make([]byte, 4097)
if _, err := rand.Read(hugeString); err != nil { if _, err := rand.Read(hugeString); err != nil {
t.Fatal(err) t.Fatal(err)

View file

@ -81,9 +81,9 @@ func timestamp() int64 {
// SignedRedirectURL takes a destination URL and adds redirect_uri to it's // SignedRedirectURL takes a destination URL and adds redirect_uri to it's
// query params, along with a timestamp and an keyed signature. // query params, along with a timestamp and an keyed signature.
func SignedRedirectURL(key string, destination, urlToSign *url.URL) *url.URL { func SignedRedirectURL(key string, destination, u *url.URL) *url.URL {
now := timestamp() now := timestamp()
rawURL := urlToSign.String() rawURL := u.String()
params, _ := url.ParseQuery(destination.RawQuery) // handled by incoming mux params, _ := url.ParseQuery(destination.RawQuery) // handled by incoming mux
params.Set("redirect_uri", rawURL) params.Set("redirect_uri", rawURL)
params.Set("ts", fmt.Sprint(now)) params.Set("ts", fmt.Sprint(now))
@ -95,7 +95,7 @@ func SignedRedirectURL(key string, destination, urlToSign *url.URL) *url.URL {
// hmacURL takes a redirect url string and timestamp and returns the base64 // hmacURL takes a redirect url string and timestamp and returns the base64
// encoded HMAC result. // encoded HMAC result.
func hmacURL(key, data string, timestamp int64) string { func hmacURL(key, data string, timestamp int64) string {
h := cryptutil.Hash(key, []byte(fmt.Sprint(data, timestamp))) h := cryptutil.GenerateHMAC([]byte(fmt.Sprint(data, timestamp)), key)
return base64.URLEncoding.EncodeToString(h) return base64.URLEncoding.EncodeToString(h)
} }

View file

@ -123,7 +123,7 @@ func TestProxy_router(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
p.cipher = &cryptutil.MockCipher{MarshalResponse: "foo"} p.encoder = &cryptutil.MockEncoder{MarshalResponse: "foo"}
req := httptest.NewRequest(http.MethodGet, tt.host, nil) req := httptest.NewRequest(http.MethodGet, tt.host, nil)
_, ok := p.router(req) _, ok := p.router(req)
@ -203,7 +203,7 @@ func TestProxy_Proxy(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
p.cipher = &cryptutil.MockCipher{MarshalResponse: "foo"} p.encoder = &cryptutil.MockEncoder{MarshalResponse: "foo"}
p.sessionStore = tt.session p.sessionStore = tt.session
p.AuthorizeClient = tt.authorizer p.AuthorizeClient = tt.authorizer
r := httptest.NewRequest(tt.method, tt.host, nil) r := httptest.NewRequest(tt.method, tt.host, nil)
@ -231,17 +231,17 @@ func TestProxy_UserDashboard(t *testing.T) {
ctxError error ctxError error
options config.Options options config.Options
method string method string
cipher cryptutil.Cipher cipher cryptutil.SecureEncoder
session sessions.SessionStore session sessions.SessionStore
authorizer clients.Authorizer authorizer clients.Authorizer
wantAdminForm bool wantAdminForm bool
wantStatus int wantStatus int
}{ }{
{"good", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusOK}, {"good", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusOK},
{"session context error", errors.New("error"), opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusInternalServerError}, {"session context error", errors.New("error"), opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
{"want admin form good admin authorization", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK}, {"want admin form good admin authorization", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
{"is admin but authorization fails", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError}, {"is admin but authorization fails", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
} }
for _, tt := range tests { for _, tt := range tests {
@ -250,7 +250,7 @@ func TestProxy_UserDashboard(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
p.cipher = tt.cipher p.encoder = tt.cipher
p.sessionStore = tt.session p.sessionStore = tt.session
p.AuthorizeClient = tt.authorizer p.AuthorizeClient = tt.authorizer
@ -289,17 +289,17 @@ func TestProxy_ForceRefresh(t *testing.T) {
ctxError error ctxError error
options config.Options options config.Options
method string method string
cipher cryptutil.Cipher cipher cryptutil.SecureEncoder
session sessions.SessionStore session sessions.SessionStore
authorizer clients.Authorizer authorizer clients.Authorizer
wantStatus int wantStatus int
}{ }{
{"good", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound}, {"good", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
{"cannot load session", errors.New("error"), opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError}, {"cannot load session", errors.New("error"), opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
{"bad id token", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthorize{}, http.StatusInternalServerError}, {"bad id token", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
{"issue date too soon", nil, timeSinceError, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusBadRequest}, {"issue date too soon", nil, timeSinceError, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusBadRequest},
{"refresh failure", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound}, {"refresh failure", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
{"can't save refreshed session", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError}, {"can't save refreshed session", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -307,7 +307,7 @@ func TestProxy_ForceRefresh(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
p.cipher = tt.cipher p.encoder = tt.cipher
p.sessionStore = tt.session p.sessionStore = tt.session
p.AuthorizeClient = tt.authorizer p.AuthorizeClient = tt.authorizer
@ -340,18 +340,18 @@ func TestProxy_Impersonate(t *testing.T) {
email string email string
groups string groups string
csrf string csrf string
cipher cryptutil.Cipher cipher cryptutil.SecureEncoder
sessionStore sessions.SessionStore sessionStore sessions.SessionStore
authorizer clients.Authorizer authorizer clients.Authorizer
wantStatus int wantStatus int
}{ }{
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, {"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, {"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, {"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden}, {"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden}, {"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
{"save session failure", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, {"save session failure", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, {"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -359,7 +359,7 @@ func TestProxy_Impersonate(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
p.cipher = tt.cipher p.encoder = tt.cipher
p.sessionStore = tt.sessionStore p.sessionStore = tt.sessionStore
p.AuthorizeClient = tt.authorizer p.AuthorizeClient = tt.authorizer
postForm := url.Values{} postForm := url.Values{}

View file

@ -43,11 +43,11 @@ const (
// ValidateOptions checks that proper configuration settings are set to create // ValidateOptions checks that proper configuration settings are set to create
// a proper Proxy instance // a proper Proxy instance
func ValidateOptions(o config.Options) error { func ValidateOptions(o config.Options) error {
if _, err := cryptutil.NewCipherFromBase64(o.SharedKey); err != nil { if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil {
return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %v", err) return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %v", err)
} }
if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil { if _, err := cryptutil.NewAEADCipherFromBase64(o.CookieSecret); err != nil {
return fmt.Errorf("proxy: invalid 'COOKIE_SECRET': %v", err) return fmt.Errorf("proxy: invalid 'COOKIE_SECRET': %v", err)
} }
@ -78,7 +78,8 @@ type Proxy struct {
AuthorizeClient clients.Authorizer AuthorizeClient clients.Authorizer
cipher cryptutil.Cipher // cipher cipher.AEAD
encoder cryptutil.SecureEncoder
cookieName string cookieName string
cookieDomain string cookieDomain string
cookieSecret []byte cookieSecret []byte
@ -105,10 +106,12 @@ func New(opts config.Options) (*Proxy, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
cipher, err := cryptutil.NewCipherFromBase64(opts.CookieSecret) cipher, err := cryptutil.NewAEADCipherFromBase64(opts.CookieSecret)
if err != nil { if err != nil {
return nil, err return nil, err
} }
encoder := cryptutil.NewSecureJSONEncoder(cipher)
if opts.CookieDomain == "" { if opts.CookieDomain == "" {
opts.CookieDomain = sessions.ParentSubdomain(opts.AuthenticateURL.String()) opts.CookieDomain = sessions.ParentSubdomain(opts.AuthenticateURL.String())
} }
@ -120,7 +123,7 @@ func New(opts config.Options) (*Proxy, error) {
CookieSecure: opts.CookieSecure, CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly, CookieHTTPOnly: opts.CookieHTTPOnly,
CookieExpire: opts.CookieExpire, CookieExpire: opts.CookieExpire,
CookieCipher: cipher, Encoder: encoder,
}) })
if err != nil { if err != nil {
@ -130,7 +133,7 @@ func New(opts config.Options) (*Proxy, error) {
SharedKey: opts.SharedKey, SharedKey: opts.SharedKey,
routeConfigs: make(map[string]*routeConfig), routeConfigs: make(map[string]*routeConfig),
cipher: cipher, encoder: encoder,
cookieSecret: decodedCookieSecret, cookieSecret: decodedCookieSecret,
cookieDomain: opts.CookieDomain, cookieDomain: opts.CookieDomain,
cookieName: opts.CookieName, cookieName: opts.CookieName,