mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-03 11:22:45 +02:00
authenticate: encrypt & mac oauth2 callback state
- cryptutil: add hmac & tests - cryptutil: rename cipher / encoders to be more clear - cryptutil: simplify SecureEncoder interface - cryptutil: renamed NewCipherFromBase64 to NewAEADCipherFromBase64 - cryptutil: move key & random generators to helpers Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
3a806c6dfc
commit
7c755d833f
26 changed files with 539 additions and 464 deletions
|
@ -1,6 +1,7 @@
|
||||||
package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/cipher"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -20,10 +21,10 @@ const callbackPath = "/oauth2/callback"
|
||||||
// ValidateOptions checks that configuration are complete and valid.
|
// ValidateOptions checks that configuration are complete and valid.
|
||||||
// Returns on first error found.
|
// Returns on first error found.
|
||||||
func ValidateOptions(o config.Options) error {
|
func ValidateOptions(o config.Options) error {
|
||||||
if _, err := cryptutil.NewCipherFromBase64(o.SharedKey); err != nil {
|
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil {
|
||||||
return fmt.Errorf("authenticate: 'SHARED_SECRET' invalid: %v", err)
|
return fmt.Errorf("authenticate: 'SHARED_SECRET' invalid: %v", err)
|
||||||
}
|
}
|
||||||
if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil {
|
if _, err := cryptutil.NewAEADCipherFromBase64(o.CookieSecret); err != nil {
|
||||||
return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %v", err)
|
return fmt.Errorf("authenticate: 'COOKIE_SECRET' invalid %v", err)
|
||||||
}
|
}
|
||||||
if err := urlutil.ValidateURL(o.AuthenticateURL); err != nil {
|
if err := urlutil.ValidateURL(o.AuthenticateURL); err != nil {
|
||||||
|
@ -48,7 +49,8 @@ type Authenticate struct {
|
||||||
cookieSecret []byte
|
cookieSecret []byte
|
||||||
templates *template.Template
|
templates *template.Template
|
||||||
sessionStore sessions.SessionStore
|
sessionStore sessions.SessionStore
|
||||||
cipher cryptutil.Cipher
|
cipher cipher.AEAD
|
||||||
|
encoder cryptutil.SecureEncoder
|
||||||
provider identity.Authenticator
|
provider identity.Authenticator
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,7 +60,8 @@ func New(opts config.Options) (*Authenticate, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||||
cipher, err := cryptutil.NewCipher(decodedCookieSecret)
|
cipher, err := cryptutil.NewAEADCipher(decodedCookieSecret)
|
||||||
|
encoder := cryptutil.NewSecureJSONEncoder(cipher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -72,7 +75,7 @@ func New(opts config.Options) (*Authenticate, error) {
|
||||||
CookieSecure: opts.CookieSecure,
|
CookieSecure: opts.CookieSecure,
|
||||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||||
CookieExpire: opts.CookieExpire,
|
CookieExpire: opts.CookieExpire,
|
||||||
CookieCipher: cipher,
|
Encoder: encoder,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -100,6 +103,7 @@ func New(opts config.Options) (*Authenticate, error) {
|
||||||
templates: templates.New(),
|
templates: templates.New(),
|
||||||
sessionStore: cookieStore,
|
sessionStore: cookieStore,
|
||||||
cipher: cipher,
|
cipher: cipher,
|
||||||
|
encoder: encoder,
|
||||||
provider: provider,
|
provider: provider,
|
||||||
cookieSecret: decodedCookieSecret,
|
cookieSecret: decodedCookieSecret,
|
||||||
cookieName: opts.CookieName,
|
cookieName: opts.CookieName,
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/csrf"
|
"github.com/pomerium/csrf"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
@ -38,8 +39,8 @@ func (a *Authenticate) Handler() http.Handler {
|
||||||
a.cookieSecret,
|
a.cookieSecret,
|
||||||
csrf.Path("/"),
|
csrf.Path("/"),
|
||||||
csrf.Domain(a.cookieDomain),
|
csrf.Domain(a.cookieDomain),
|
||||||
csrf.UnsafePaths([]string{"/oauth2/callback"}), // enforce CSRF on "safe" handler
|
csrf.UnsafePaths([]string{callbackPath}), // enforce CSRF on "safe" handler
|
||||||
csrf.FormValueName("state"), // rfc6749 section-10.12
|
csrf.FormValueName("state"), // rfc6749 section-10.12
|
||||||
csrf.CookieName(fmt.Sprintf("%s_csrf", a.cookieName)),
|
csrf.CookieName(fmt.Sprintf("%s_csrf", a.cookieName)),
|
||||||
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
|
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
|
||||||
))
|
))
|
||||||
|
@ -137,15 +138,21 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// redirectToIdentityProvider starts the authenticate process by redirecting the
|
// redirectToIdentityProvider starts the authenticate process by redirecting the
|
||||||
// user to their respective identity provider.
|
// user to their respective identity provider. This function also builds the
|
||||||
//
|
// 'state' parameter which is encrypted and includes authenticating data
|
||||||
|
// for validation.
|
||||||
|
// 'state' is : nonce|timestamp|redirect_url|encrypt(redirect_url)+mac(nonce,ts))
|
||||||
|
|
||||||
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest
|
// https://openid.net/specs/openid-connect-core-1_0-final.html#AuthRequest
|
||||||
// https://tools.ietf.org/html/rfc6749#section-4.2.1
|
// https://tools.ietf.org/html/rfc6749#section-4.2.1
|
||||||
func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
func (a *Authenticate) redirectToIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||||
redirectURL := a.RedirectURL.ResolveReference(r.URL)
|
redirectURL := a.RedirectURL.ResolveReference(r.URL)
|
||||||
nonce := csrf.Token(r)
|
nonce := csrf.Token(r)
|
||||||
state := fmt.Sprintf("%v:%v", nonce, redirectURL.String())
|
now := time.Now().Unix()
|
||||||
encodedState := base64.URLEncoding.EncodeToString([]byte(state))
|
b := []byte(fmt.Sprintf("%s|%d|", nonce, now))
|
||||||
|
enc := cryptutil.Encrypt(a.cipher, []byte(redirectURL.String()), b)
|
||||||
|
b = append(b, enc...)
|
||||||
|
encodedState := base64.URLEncoding.EncodeToString(b)
|
||||||
http.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound)
|
http.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -187,19 +194,33 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, httputil.Error("malformed state", http.StatusBadRequest, err)
|
return nil, httputil.Error("malformed state", http.StatusBadRequest, err)
|
||||||
}
|
}
|
||||||
// split state into its it's components (nonce:redirect_uri)
|
|
||||||
statePayload := strings.SplitN(string(bytes), ":", 2)
|
// split state into its it's components, e.g.
|
||||||
if len(statePayload) != 2 {
|
// (nonce|timestamp|redirect_url|encrypted_data(redirect_url)+mac(nonce,ts))
|
||||||
return nil, fmt.Errorf("state malformed, size: %d", len(statePayload))
|
statePayload := strings.SplitN(string(bytes), "|", 3)
|
||||||
}
|
if len(statePayload) != 3 {
|
||||||
// parse redirect_uri; ignore csrf nonce (validity asserted by middleware)
|
return nil, httputil.Error("'state' is malformed", http.StatusBadRequest,
|
||||||
redirectURL, err := urlutil.ParseAndValidateURL(statePayload[1])
|
fmt.Errorf("state malformed, size: %d", len(statePayload)))
|
||||||
if err != nil {
|
|
||||||
return nil, httputil.Error("invalid redirect uri", http.StatusBadRequest, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo(bdd): if we want to be _extra_ sure, we can validate that the
|
// verify that the returned timestamp is valid (replay attack)
|
||||||
// redirectURL hmac is valid. But the nonce should cover the integrity...
|
if err := cryptutil.ValidTimestamp(statePayload[1]); err != nil {
|
||||||
|
return nil, httputil.Error(err.Error(), http.StatusBadRequest, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use our AEAD construct to enforce secrecy and authenticity,:
|
||||||
|
// mac: to validate the nonce again, and above timestamp
|
||||||
|
// decrypt: to prevent leaking 'redirect_uri' to IdP or logs)
|
||||||
|
b := []byte(fmt.Sprint(statePayload[0], "|", statePayload[1], "|"))
|
||||||
|
redirectString, err := cryptutil.Decrypt(a.cipher, []byte(statePayload[2]), b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, httputil.Error("'state' has invalid hmac", http.StatusBadRequest, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURL, err := urlutil.ParseAndValidateURL(string(redirectString))
|
||||||
|
if err != nil {
|
||||||
|
return nil, httputil.Error("'state' has invalid redirect uri", http.StatusBadRequest, err)
|
||||||
|
}
|
||||||
|
|
||||||
// OK. Looks good so let's persist our user session
|
// OK. Looks good so let's persist our user session
|
||||||
if err := a.sessionStore.SaveSession(w, r, session); err != nil {
|
if err := a.sessionStore.SaveSession(w, r, session); err != nil {
|
||||||
|
@ -222,7 +243,7 @@ func (a *Authenticate) ExchangeToken(w http.ResponseWriter, r *http.Request) {
|
||||||
httputil.ErrorResponse(w, r, err)
|
httputil.ErrorResponse(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
encToken, err := sessions.MarshalSession(session, a.cipher)
|
encToken, err := sessions.MarshalSession(session, a.encoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
|
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
|
||||||
return
|
return
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/templates"
|
"github.com/pomerium/pomerium/internal/templates"
|
||||||
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testAuthenticate() *Authenticate {
|
func testAuthenticate() *Authenticate {
|
||||||
|
@ -73,22 +74,22 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
session sessions.SessionStore
|
session sessions.SessionStore
|
||||||
restStore sessions.SessionStore
|
restStore sessions.SessionStore
|
||||||
provider identity.MockProvider
|
provider identity.MockProvider
|
||||||
cipher cryptutil.Cipher
|
encoder cryptutil.SecureEncoder
|
||||||
wantCode int
|
wantCode int
|
||||||
}{
|
}{
|
||||||
{"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
|
{"good", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusFound},
|
||||||
{"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockCipher{}, http.StatusFound},
|
{"session not valid", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: false}, &cryptutil.MockEncoder{}, http.StatusFound},
|
||||||
{"session expired good refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound},
|
{"session expired good refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockEncoder{}, http.StatusFound},
|
||||||
{"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockCipher{}, http.StatusFound}, // mocking hmac is meh
|
{"session expired bad refresh", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshError: errors.New("error")}, &cryptutil.MockEncoder{}, http.StatusFound}, // mocking hmac is meh
|
||||||
{"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockCipher{}, http.StatusFound},
|
{"session expired bad refresh save", "state=example", "https://some.example", &sessions.MockSessionStore{SaveError: errors.New("ruh roh"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, identity.MockProvider{ValidateResponse: true, RefreshResponse: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, &cryptutil.MockEncoder{}, http.StatusFound},
|
||||||
|
|
||||||
// {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusInternalServerError},
|
// {"no cookie found trying to load", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: http.ErrNoCookie, Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusInternalServerError},
|
||||||
{"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
|
{"unexpected error trying to load session", "state=example", "https://some.example", &sessions.MockSessionStore{LoadError: errors.New("error"), Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusFound},
|
||||||
{"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusFound},
|
{"empty state", "state=", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusFound},
|
||||||
{"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
{"malformed redirect uri", "state=example", "https://accounts.google.^", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusBadRequest},
|
||||||
// actually caught by go's handler, but we should keep the test.
|
// actually caught by go's handler, but we should keep the test.
|
||||||
{"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{}, http.StatusBadRequest},
|
{"bad redirect uri query", "state=nonce", "%gh&%ij", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{}, http.StatusBadRequest},
|
||||||
{"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockCipher{MarshalError: errors.New("error")}, http.StatusFound},
|
{"marshal session failure", "state=example", "https://some.example", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second)}}, identity.MockProvider{ValidateResponse: true}, &cryptutil.MockEncoder{MarshalError: errors.New("error")}, http.StatusFound},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -97,7 +98,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
provider: tt.provider,
|
provider: tt.provider,
|
||||||
RedirectURL: uriParseHelper("https://some.example"),
|
RedirectURL: uriParseHelper("https://some.example"),
|
||||||
SharedKey: "secret",
|
SharedKey: "secret",
|
||||||
cipher: tt.cipher,
|
encoder: tt.encoder,
|
||||||
}
|
}
|
||||||
uri := &url.URL{Host: "corp.some.example", Scheme: "https", Path: "/"}
|
uri := &url.URL{Host: "corp.some.example", Scheme: "https", Path: "/"}
|
||||||
uri.RawQuery = fmt.Sprintf("%s&redirect_uri=%s", tt.state, tt.redirectURI)
|
uri.RawQuery = fmt.Sprintf("%s&redirect_uri=%s", tt.state, tt.redirectURI)
|
||||||
|
@ -178,14 +179,19 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
|
|
||||||
func TestAuthenticate_OAuthCallback(t *testing.T) {
|
func TestAuthenticate_OAuthCallback(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
method string
|
method string
|
||||||
|
|
||||||
// url params
|
ts int64
|
||||||
paramErr string
|
stateOvveride string
|
||||||
code string
|
extraMac string
|
||||||
state string
|
extraState string
|
||||||
|
paramErr string
|
||||||
|
code string
|
||||||
|
redirectURI string
|
||||||
|
|
||||||
authenticateURL string
|
authenticateURL string
|
||||||
session sessions.SessionStore
|
session sessions.SessionStore
|
||||||
provider identity.MockProvider
|
provider identity.MockProvider
|
||||||
|
@ -193,30 +199,52 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
||||||
want string
|
want string
|
||||||
wantCode int
|
wantCode int
|
||||||
}{
|
}{
|
||||||
{"good", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusFound},
|
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusFound},
|
||||||
{"failed authenticate", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
|
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
|
||||||
{"failed save session", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusInternalServerError},
|
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusInternalServerError},
|
||||||
{"provider returned error", http.MethodGet, "idp error", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
|
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
|
||||||
{"empty code", http.MethodGet, "", "", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
|
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
|
||||||
{"invalid redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:corp.pomerium.io")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
|
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "", http.StatusBadRequest},
|
||||||
{"bad redirect uri", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:http://^^^")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"bad base64 state", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io")) + "%", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"too many state delimeters", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce:https://corp.pomerium.io:wait")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"too few state delimeters", http.MethodGet, "", "code", base64.URLEncoding.EncodeToString([]byte("nonce")), "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusInternalServerError},
|
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
|
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
|
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
|
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", Email: "blah@blah.com", RefreshDeadline: time.Now().Add(10 * time.Second)}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
authURL, _ := url.Parse(tt.authenticateURL)
|
authURL, _ := url.Parse(tt.authenticateURL)
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
RedirectURL: authURL,
|
RedirectURL: authURL,
|
||||||
sessionStore: tt.session,
|
sessionStore: tt.session,
|
||||||
provider: tt.provider,
|
provider: tt.provider,
|
||||||
|
cipher: aead,
|
||||||
}
|
}
|
||||||
u, _ := url.Parse("/oauthGet")
|
u, _ := url.Parse("/oauthGet")
|
||||||
params, _ := url.ParseQuery(u.RawQuery)
|
params, _ := url.ParseQuery(u.RawQuery)
|
||||||
params.Add("error", tt.paramErr)
|
params.Add("error", tt.paramErr)
|
||||||
params.Add("code", tt.code)
|
params.Add("code", tt.code)
|
||||||
params.Add("state", tt.state)
|
nonce := cryptutil.NewBase64Key() // mock csrf
|
||||||
|
|
||||||
|
// (nonce|timestamp|redirect_url|encrypt(redirect_url),mac(nonce,ts))
|
||||||
|
b := []byte(fmt.Sprintf("%s|%d|%s", nonce, tt.ts, tt.extraMac))
|
||||||
|
|
||||||
|
enc := cryptutil.Encrypt(a.cipher, []byte(tt.redirectURI), b)
|
||||||
|
b = append(b, enc...)
|
||||||
|
encodedState := base64.URLEncoding.EncodeToString(b)
|
||||||
|
if tt.extraState != "" {
|
||||||
|
encodedState += tt.extraState
|
||||||
|
}
|
||||||
|
if tt.stateOvveride != "" {
|
||||||
|
encodedState = tt.stateOvveride
|
||||||
|
}
|
||||||
|
params.Add("state", encodedState)
|
||||||
|
|
||||||
u.RawQuery = params.Encode()
|
u.RawQuery = params.Encode()
|
||||||
|
|
||||||
|
@ -240,22 +268,27 @@ func TestAuthenticate_ExchangeToken(t *testing.T) {
|
||||||
method string
|
method string
|
||||||
idToken string
|
idToken string
|
||||||
restStore sessions.SessionStore
|
restStore sessions.SessionStore
|
||||||
cipher cryptutil.Cipher
|
encoder cryptutil.SecureEncoder
|
||||||
provider identity.MockProvider
|
provider identity.MockProvider
|
||||||
want string
|
want string
|
||||||
}{
|
}{
|
||||||
{"good", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""},
|
{"good", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""},
|
||||||
{"could not exchange identity for session", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionError: errors.New("error")}, ""},
|
{"could not exchange identity for session", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionError: errors.New("error")}, ""},
|
||||||
{"missing token", http.MethodPost, "", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "missing id token"},
|
{"missing token", http.MethodPost, "", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "missing id token"},
|
||||||
{"malformed form", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""},
|
{"malformed form", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, ""},
|
||||||
{"can't marshal token", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockCipher{MarshalError: errors.New("can't marshal token")}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "can't marshal token"},
|
{"can't marshal token", http.MethodPost, "token", &sessions.MockSessionStore{}, &cryptutil.MockEncoder{MarshalError: errors.New("can't marshal token")}, identity.MockProvider{IDTokenToSessionResponse: sessions.State{IDToken: "ok"}}, "can't marshal token"},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
cipher: tt.cipher,
|
encoder: tt.encoder,
|
||||||
provider: tt.provider,
|
provider: tt.provider,
|
||||||
sessionStore: tt.restStore,
|
sessionStore: tt.restStore,
|
||||||
|
cipher: aead,
|
||||||
}
|
}
|
||||||
form := url.Values{}
|
form := url.Values{}
|
||||||
if tt.idToken != "" {
|
if tt.idToken != "" {
|
||||||
|
@ -282,6 +315,7 @@ func TestAuthenticate_ExchangeToken(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||||
|
@ -305,13 +339,17 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
a := Authenticate{
|
a := Authenticate{
|
||||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||||
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
||||||
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
|
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
|
||||||
sessionStore: tt.session,
|
sessionStore: tt.session,
|
||||||
provider: tt.provider,
|
provider: tt.provider,
|
||||||
|
cipher: aead,
|
||||||
}
|
}
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
state, _ := tt.session.LoadSession(r)
|
state, _ := tt.session.LoadSession(r)
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
|
|
||||||
### Security
|
### Security
|
||||||
|
|
||||||
|
- The user's original intended location before completing the authentication process is now encrypted and kept confidential from the identity provider. [GH-316](https://github.com/pomerium/pomerium/pull/316)
|
||||||
- Under certain circumstances, where debug logging was enabled, pomerium's shared secret could be leaked to http access logs as a query param.
|
- Under certain circumstances, where debug logging was enabled, pomerium's shared secret could be leaked to http access logs as a query param.
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -14,7 +14,7 @@ require (
|
||||||
github.com/magiconair/properties v1.8.1 // indirect
|
github.com/magiconair/properties v1.8.1 // indirect
|
||||||
github.com/mitchellh/hashstructure v1.0.0
|
github.com/mitchellh/hashstructure v1.0.0
|
||||||
github.com/pelletier/go-toml v1.4.0 // indirect
|
github.com/pelletier/go-toml v1.4.0 // indirect
|
||||||
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30
|
github.com/pomerium/csrf v1.6.2-0.20190918035251-f3318380bad3
|
||||||
github.com/pomerium/go-oidc v2.0.0+incompatible
|
github.com/pomerium/go-oidc v2.0.0+incompatible
|
||||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||||
github.com/prometheus/client_golang v0.9.3
|
github.com/prometheus/client_golang v0.9.3
|
||||||
|
|
3
go.sum
3
go.sum
|
@ -126,6 +126,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30 h1:jggCv6hZvcxjGa3gqkYY2EUuOkITI9Znugz/f3QJfRQ=
|
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30 h1:jggCv6hZvcxjGa3gqkYY2EUuOkITI9Znugz/f3QJfRQ=
|
||||||
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0=
|
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0=
|
||||||
|
github.com/pomerium/csrf v1.6.2-0.20190918035251-f3318380bad3 h1:FmzFXnCAepHZwl6QPhTFqBHcbcGevdiEQjutK+M5bj4=
|
||||||
|
github.com/pomerium/csrf v1.6.2-0.20190918035251-f3318380bad3/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0=
|
||||||
github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o=
|
github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o=
|
||||||
github.com/pomerium/go-oidc v2.0.0+incompatible/go.mod h1:DRsGVw6MOgxbfq4Y57jKOE8lbEfayxeiY0A8/4vxjBM=
|
github.com/pomerium/go-oidc v2.0.0+incompatible/go.mod h1:DRsGVw6MOgxbfq4Y57jKOE8lbEfayxeiY0A8/4vxjBM=
|
||||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
|
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
|
||||||
|
@ -196,6 +198,7 @@ golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmV
|
||||||
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8 h1:1wopBVtVdWnn03fZelqdXTqk7U7zPQCb+T4rbU9ZEoU=
|
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8 h1:1wopBVtVdWnn03fZelqdXTqk7U7zPQCb+T4rbU9ZEoU=
|
||||||
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
|
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 h1:0hQKqeLdqlt5iIwVOBErRisrHJAN57yOiPRQItI20fU=
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||||
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -13,122 +12,47 @@ import (
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultKeySize is the default key size in bytes.
|
// NewAEADCipher takes secret key and returns a new XChacha20poly1305 cipher.
|
||||||
const DefaultKeySize = 32
|
func NewAEADCipher(secret []byte) (cipher.AEAD, error) {
|
||||||
|
if len(secret) != 32 {
|
||||||
// NewKey generates a random 32-byte key.
|
return nil, fmt.Errorf("cryptutil: got %d bytes but want 32", len(secret))
|
||||||
//
|
|
||||||
// Panics if source of randomness fails.
|
|
||||||
func NewKey() []byte {
|
|
||||||
return randomBytes(DefaultKeySize)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBase64Key generates a random base64 encoded 32-byte key.
|
|
||||||
//
|
|
||||||
// Panics if source of randomness fails.
|
|
||||||
func NewBase64Key() string {
|
|
||||||
return NewRandomStringN(DefaultKeySize)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRandomStringN returns base64 encoded random string of a given num of bytes.
|
|
||||||
//
|
|
||||||
// Panics if source of randomness fails.
|
|
||||||
func NewRandomStringN(c int) string {
|
|
||||||
return base64.StdEncoding.EncodeToString(randomBytes(c))
|
|
||||||
}
|
|
||||||
|
|
||||||
func randomBytes(c int) []byte {
|
|
||||||
if c < 0 {
|
|
||||||
c = DefaultKeySize
|
|
||||||
}
|
}
|
||||||
b := make([]byte, c)
|
return chacha20poly1305.NewX(secret)
|
||||||
if _, err := rand.Read(b); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cipher provides methods to encrypt and decrypt values.
|
// NewAEADCipherFromBase64 takes a base64 encoded secret key and returns a new XChacha20poly1305 cipher.
|
||||||
type Cipher interface {
|
func NewAEADCipherFromBase64(s string) (cipher.AEAD, error) {
|
||||||
Encrypt([]byte) ([]byte, error)
|
|
||||||
Decrypt([]byte) ([]byte, error)
|
|
||||||
Marshal(interface{}) (string, error)
|
|
||||||
Unmarshal(string, interface{}) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// XChaCha20Cipher provides methods to encrypt and decrypt values.
|
|
||||||
// Using an AEAD is a cipher providing authenticated encryption with associated data.
|
|
||||||
// For a description of the methodology, see https://en.wikipedia.org/wiki/Authenticated_encryption
|
|
||||||
type XChaCha20Cipher struct {
|
|
||||||
aead cipher.AEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewCipher takes secret key and returns a new XChacha20poly1305 cipher.
|
|
||||||
func NewCipher(secret []byte) (*XChaCha20Cipher, error) {
|
|
||||||
aead, err := chacha20poly1305.NewX(secret)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &XChaCha20Cipher{
|
|
||||||
aead: aead,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewCipherFromBase64 takes a base64 encoded secret key and returns a new XChacha20poly1305 cipher.
|
|
||||||
func NewCipherFromBase64(s string) (*XChaCha20Cipher, error) {
|
|
||||||
decoded, err := base64.StdEncoding.DecodeString(s)
|
decoded, err := base64.StdEncoding.DecodeString(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cryptutil: invalid base64: %v", err)
|
return nil, fmt.Errorf("cryptutil: invalid base64: %v", err)
|
||||||
}
|
}
|
||||||
if len(decoded) != 32 {
|
return NewAEADCipher(decoded)
|
||||||
return nil, fmt.Errorf("cryptutil: got %d bytes but want 32", len(decoded))
|
|
||||||
}
|
|
||||||
return NewCipher(decoded)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateNonce generates a random nonce.
|
// SecureEncoder provides and interface for to encrypt and decrypting structures .
|
||||||
// Panics if source of randomness fails.
|
type SecureEncoder interface {
|
||||||
func (c *XChaCha20Cipher) GenerateNonce() []byte {
|
Marshal(interface{}) (string, error)
|
||||||
return randomBytes(c.aead.NonceSize())
|
Unmarshal(string, interface{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encrypt a value using XChaCha20-Poly1305
|
// SecureJSONEncoder implements SecureEncoder for JSON using an AEAD cipher.
|
||||||
func (c *XChaCha20Cipher) Encrypt(plaintext []byte) (joined []byte, err error) {
|
//
|
||||||
defer func() {
|
// See https://en.wikipedia.org/wiki/Authenticated_encryption
|
||||||
if r := recover(); r != nil {
|
type SecureJSONEncoder struct {
|
||||||
err = fmt.Errorf("cryptutil: error encrypting bytes: %v", r)
|
aead cipher.AEAD
|
||||||
}
|
|
||||||
}()
|
|
||||||
nonce := c.GenerateNonce()
|
|
||||||
|
|
||||||
ciphertext := c.aead.Seal(nil, nonce, plaintext, nil)
|
|
||||||
|
|
||||||
// we return the nonce as part of the returned value
|
|
||||||
joined = append(ciphertext, nonce...)
|
|
||||||
return joined, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrypt a value using XChaCha20-Poly1305
|
// NewSecureJSONEncoder takes a base64 encoded secret key and returns a new XChacha20poly1305 cipher.
|
||||||
func (c *XChaCha20Cipher) Decrypt(joined []byte) ([]byte, error) {
|
func NewSecureJSONEncoder(aead cipher.AEAD) SecureEncoder {
|
||||||
if len(joined) <= c.aead.NonceSize() {
|
return &SecureJSONEncoder{aead: aead}
|
||||||
return nil, fmt.Errorf("cryptutil: invalid input size: %d", len(joined))
|
|
||||||
}
|
|
||||||
// grab out the nonce
|
|
||||||
pivot := len(joined) - c.aead.NonceSize()
|
|
||||||
ciphertext := joined[:pivot]
|
|
||||||
nonce := joined[pivot:]
|
|
||||||
|
|
||||||
plaintext, err := c.aead.Open(nil, nonce, ciphertext, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return plaintext, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Marshal marshals the interface state as JSON, encrypts the JSON using the cipher
|
// Marshal marshals the interface state as JSON, encrypts the JSON using the cipher
|
||||||
// and base64 encodes the binary value as a string and returns the result
|
// and base64 encodes the binary value as a string and returns the result
|
||||||
func (c *XChaCha20Cipher) Marshal(s interface{}) (string, error) {
|
//
|
||||||
|
// can panic if source of random entropy is exhausted generating a nonce.
|
||||||
|
func (c *SecureJSONEncoder) Marshal(s interface{}) (string, error) {
|
||||||
// encode json value
|
// encode json value
|
||||||
plaintext, err := json.Marshal(s)
|
plaintext, err := json.Marshal(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -140,10 +64,8 @@ func (c *XChaCha20Cipher) Marshal(s interface{}) (string, error) {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
// encrypt the compressed JSON bytes
|
// encrypt the compressed JSON bytes
|
||||||
ciphertext, err := c.Encrypt(compressed)
|
ciphertext := Encrypt(c.aead, compressed, nil)
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
// base64-encode the result
|
// base64-encode the result
|
||||||
encoded := base64.RawURLEncoding.EncodeToString(ciphertext)
|
encoded := base64.RawURLEncoding.EncodeToString(ciphertext)
|
||||||
return encoded, nil
|
return encoded, nil
|
||||||
|
@ -151,14 +73,14 @@ func (c *XChaCha20Cipher) Marshal(s interface{}) (string, error) {
|
||||||
|
|
||||||
// Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
// Unmarshal takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
||||||
// byte slice the passed cipher, and unmarshals the resulting JSON into the struct pointer passed
|
// byte slice the passed cipher, and unmarshals the resulting JSON into the struct pointer passed
|
||||||
func (c *XChaCha20Cipher) Unmarshal(value string, s interface{}) error {
|
func (c *SecureJSONEncoder) Unmarshal(value string, s interface{}) error {
|
||||||
// convert base64 string value to bytes
|
// convert base64 string value to bytes
|
||||||
ciphertext, err := base64.RawURLEncoding.DecodeString(value)
|
ciphertext, err := base64.RawURLEncoding.DecodeString(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// decrypt the bytes
|
// decrypt the bytes
|
||||||
compressed, err := c.Decrypt(ciphertext)
|
compressed, err := Decrypt(c.aead, ciphertext, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -172,10 +94,10 @@ func (c *XChaCha20Cipher) Unmarshal(value string, s interface{}) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// compress gzips a set of bytes
|
||||||
func compress(data []byte) ([]byte, error) {
|
func compress(data []byte) ([]byte, error) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
writer, err := gzip.NewWriterLevel(&buf, gzip.DefaultCompression)
|
writer, err := gzip.NewWriterLevel(&buf, gzip.DefaultCompression)
|
||||||
|
@ -194,6 +116,7 @@ func compress(data []byte) ([]byte, error) {
|
||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// decompress un-gzips a set of bytes
|
||||||
func decompress(data []byte) ([]byte, error) {
|
func decompress(data []byte) ([]byte, error) {
|
||||||
reader, err := gzip.NewReader(bytes.NewReader(data))
|
reader, err := gzip.NewReader(bytes.NewReader(data))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -206,3 +129,27 @@ func decompress(data []byte) ([]byte, error) {
|
||||||
}
|
}
|
||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Encrypt encrypts a value with optional associated data
|
||||||
|
//
|
||||||
|
// Panics if source of randomness fails.
|
||||||
|
func Encrypt(a cipher.AEAD, data, ad []byte) []byte {
|
||||||
|
iv := randomBytes(a.NonceSize())
|
||||||
|
ciphertext := a.Seal(nil, iv, data, ad)
|
||||||
|
return append(ciphertext, iv...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt a value with optional associated data
|
||||||
|
func Decrypt(a cipher.AEAD, data, ad []byte) ([]byte, error) {
|
||||||
|
if len(data) <= a.NonceSize() {
|
||||||
|
return nil, fmt.Errorf("cryptutil: invalid input size: %d", len(data))
|
||||||
|
}
|
||||||
|
size := len(data) - a.NonceSize()
|
||||||
|
ciphertext := data[:size]
|
||||||
|
nonce := data[size:]
|
||||||
|
plaintext, err := a.Open(nil, nonce, ciphertext, ad)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return plaintext, nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,12 +1,8 @@
|
||||||
package cryptutil
|
package cryptutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha1"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,27 +10,24 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
||||||
plaintext := []byte("my plain text value")
|
plaintext := []byte("my plain text value")
|
||||||
|
|
||||||
key := NewKey()
|
key := NewKey()
|
||||||
c, err := NewCipher(key)
|
c, err := NewAEADCipher(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected err: %v", err)
|
t.Fatalf("unexpected err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ciphertext, err := c.Encrypt(plaintext)
|
ciphertext := Encrypt(c, plaintext, nil)
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if reflect.DeepEqual(plaintext, ciphertext) {
|
if reflect.DeepEqual(plaintext, ciphertext) {
|
||||||
t.Fatalf("plaintext is not encrypted plaintext:%v ciphertext:%x", plaintext, ciphertext)
|
t.Fatalf("plaintext is not encrypted plaintext:%v ciphertext:%x", plaintext, ciphertext)
|
||||||
}
|
}
|
||||||
|
|
||||||
got, err := c.Decrypt(ciphertext)
|
got, err := Decrypt(c, ciphertext, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected err decrypting: %v", err)
|
t.Fatalf("unexpected err decrypting: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// if less than 32 bytes, fail
|
// if less than 32 bytes, fail
|
||||||
_, err = c.Decrypt([]byte("oh"))
|
_, err = Decrypt(c, []byte("oh"), nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("should fail if <32 bytes output: %v", err)
|
t.Fatalf("should fail if <32 bytes output: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -49,10 +42,11 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
||||||
func TestMarshalAndUnmarshalStruct(t *testing.T) {
|
func TestMarshalAndUnmarshalStruct(t *testing.T) {
|
||||||
key := NewKey()
|
key := NewKey()
|
||||||
|
|
||||||
c, err := NewCipher(key)
|
a, err := NewAEADCipher(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected err: %v", err)
|
t.Fatalf("unexpected err: %v", err)
|
||||||
}
|
}
|
||||||
|
c := SecureJSONEncoder{aead: a}
|
||||||
|
|
||||||
type TC struct {
|
type TC struct {
|
||||||
Field string `json:"field"`
|
Field string `json:"field"`
|
||||||
|
@ -101,102 +95,7 @@ func TestMarshalAndUnmarshalStruct(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCipherDataRace(t *testing.T) {
|
func TestSecureJSONEncoder_Marshal(t *testing.T) {
|
||||||
cipher, err := NewCipher(NewKey())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected generating cipher err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type TC struct {
|
|
||||||
Field string `json:"field"`
|
|
||||||
}
|
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(c *XChaCha20Cipher, wg *sync.WaitGroup) {
|
|
||||||
defer wg.Done()
|
|
||||||
b := make([]byte, 32)
|
|
||||||
_, err := rand.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexecpted error reading random bytes: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sha := fmt.Sprintf("%x", sha1.New().Sum(b))
|
|
||||||
tc := &TC{
|
|
||||||
Field: sha,
|
|
||||||
}
|
|
||||||
|
|
||||||
value1, err := c.Marshal(tc)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
value2, err := c.Marshal(tc)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if value1 == value2 {
|
|
||||||
t.Fatalf("expected marshaled values to not be equal %v != %v", value1, value2)
|
|
||||||
}
|
|
||||||
|
|
||||||
got1 := &TC{}
|
|
||||||
err = c.Unmarshal(value1, got1)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(got1, tc) {
|
|
||||||
t.Logf("want: %#v", tc)
|
|
||||||
t.Logf(" got: %#v", got1)
|
|
||||||
t.Fatalf("expected structs to be equal")
|
|
||||||
}
|
|
||||||
|
|
||||||
got2 := &TC{}
|
|
||||||
err = c.Unmarshal(value2, got2)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected err unmarshalling struct: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(got1, got2) {
|
|
||||||
t.Logf("got2: %#v", got2)
|
|
||||||
t.Logf("got1: %#v", got1)
|
|
||||||
t.Fatalf("expected structs to be equal")
|
|
||||||
}
|
|
||||||
|
|
||||||
}(cipher, wg)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateRandomString(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
c int
|
|
||||||
want int
|
|
||||||
}{
|
|
||||||
{"simple", 32, 32},
|
|
||||||
{"zero", 0, 0},
|
|
||||||
{"negative", -1, 32},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
o := NewRandomStringN(tt.c)
|
|
||||||
b, err := base64.StdEncoding.DecodeString(o)
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
got := len(b)
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("NewRandomStringN() = %d, want %d", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestXChaCha20Cipher_Marshal(t *testing.T) {
|
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -225,20 +124,22 @@ func TestXChaCha20Cipher_Marshal(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
||||||
c, err := NewCipher(NewKey())
|
c, err := NewAEADCipher(NewKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected err: %v", err)
|
t.Fatalf("unexpected err: %v", err)
|
||||||
}
|
}
|
||||||
_, err = c.Marshal(tt.s)
|
e := SecureJSONEncoder{aead: c}
|
||||||
|
|
||||||
|
_, err = e.Marshal(tt.s)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("XChaCha20Cipher.Marshal() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("SecureJSONEncoder.Marshal() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewCipher(t *testing.T) {
|
func TestNewAEADCipher(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -251,16 +152,16 @@ func TestNewCipher(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
_, err := NewCipher(tt.secret)
|
_, err := NewAEADCipher(tt.secret)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("NewCipher() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("NewAEADCipher() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewCipherFromBase64(t *testing.T) {
|
func TestNewAEADCipherFromBase64(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -274,34 +175,11 @@ func TestNewCipherFromBase64(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
_, err := NewCipherFromBase64(tt.s)
|
_, err := NewAEADCipherFromBase64(tt.s)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("NewCipherFromBase64() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("NewAEADCipherFromBase64() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewBase64Key(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
want int
|
|
||||||
}{
|
|
||||||
{"simple", 32},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
o := NewBase64Key()
|
|
||||||
b, err := base64.StdEncoding.DecodeString(o)
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
got := len(b)
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("NewBase64Key() = %d, want %d", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
45
internal/cryptutil/helpers.go
Normal file
45
internal/cryptutil/helpers.go
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultKeySize is the default key size in bytes.
|
||||||
|
const DefaultKeySize = 32
|
||||||
|
|
||||||
|
// NewKey generates a random 32-byte (256 bit) key.
|
||||||
|
//
|
||||||
|
// Panics if source of randomness fails.
|
||||||
|
func NewKey() []byte {
|
||||||
|
return randomBytes(DefaultKeySize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBase64Key generates a random base64 encoded 32-byte key.
|
||||||
|
//
|
||||||
|
// Panics if source of randomness fails.
|
||||||
|
func NewBase64Key() string {
|
||||||
|
return NewRandomStringN(DefaultKeySize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRandomStringN returns base64 encoded random string of a given num of bytes.
|
||||||
|
//
|
||||||
|
// Panics if source of randomness fails.
|
||||||
|
func NewRandomStringN(c int) string {
|
||||||
|
return base64.StdEncoding.EncodeToString(randomBytes(c))
|
||||||
|
}
|
||||||
|
|
||||||
|
// randomBytes generates C number of random bytes suitable for cryptographic
|
||||||
|
// operations.
|
||||||
|
//
|
||||||
|
// Panics if source of randomness fails.
|
||||||
|
func randomBytes(c int) []byte {
|
||||||
|
if c < 0 {
|
||||||
|
c = DefaultKeySize
|
||||||
|
}
|
||||||
|
b := make([]byte, c)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
55
internal/cryptutil/helpers_test.go
Normal file
55
internal/cryptutil/helpers_test.go
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateRandomString(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
c int
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"simple", 32, 32},
|
||||||
|
{"zero", 0, 0},
|
||||||
|
{"negative", -1, 32},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
o := NewRandomStringN(tt.c)
|
||||||
|
b, err := base64.StdEncoding.DecodeString(o)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
got := len(b)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("NewRandomStringN() = %d, want %d", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewBase64Key(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"simple", 32},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
o := NewBase64Key()
|
||||||
|
b, err := base64.StdEncoding.DecodeString(o)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
got := len(b)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("NewBase64Key() = %d, want %d", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
50
internal/cryptutil/hmac.go
Normal file
50
internal/cryptutil/hmac.go
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha512"
|
||||||
|
"errors"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errTimestampMalformed = errors.New("internal/cryptutil: timestamp malformed")
|
||||||
|
errTimestampExpired = errors.New("internal/cryptutil: timestamp expired")
|
||||||
|
errTimestampTooSoon = errors.New("internal/cryptutil: timestamp too soon")
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateHMAC produces a symmetric signature using a shared secret key.
|
||||||
|
func GenerateHMAC(data []byte, key string) []byte {
|
||||||
|
h := hmac.New(sha512.New512_256, []byte(key))
|
||||||
|
h.Write(data)
|
||||||
|
return h.Sum(nil)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckHMAC securely checks the supplied MAC against a message using the
|
||||||
|
// shared secret key.
|
||||||
|
func CheckHMAC(data, suppliedMAC []byte, key string) bool {
|
||||||
|
expectedMAC := GenerateHMAC(data, key)
|
||||||
|
return hmac.Equal(expectedMAC, suppliedMAC)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidTimestamp is a helper function often used in conjunction with an HMAC
|
||||||
|
// function to verify that the timestamp (in unix seconds) is within leeway
|
||||||
|
// period.
|
||||||
|
// todo(bdd) : should leeway be configurable?
|
||||||
|
func ValidTimestamp(ts string) error {
|
||||||
|
var timeStamp int64
|
||||||
|
var err error
|
||||||
|
if timeStamp, err = strconv.ParseInt(ts, 10, 64); err != nil {
|
||||||
|
return errTimestampMalformed
|
||||||
|
}
|
||||||
|
// unix time in seconds
|
||||||
|
tm := time.Unix(timeStamp, 0)
|
||||||
|
if time.Since(tm) > DefaultLeeway {
|
||||||
|
return errTimestampExpired
|
||||||
|
}
|
||||||
|
if time.Until(tm) > DefaultLeeway {
|
||||||
|
return errTimestampTooSoon
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
70
internal/cryptutil/hmac_test.go
Normal file
70
internal/cryptutil/hmac_test.go
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
package cryptutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHMAC(t *testing.T) {
|
||||||
|
// https://groups.google.com/d/msg/sci.crypt/OolWgsgQD-8/jHciyWkaL0gJ
|
||||||
|
hmacTests := []struct {
|
||||||
|
key string
|
||||||
|
data string
|
||||||
|
digest string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
key: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
|
||||||
|
data: "4869205468657265", // "Hi There"
|
||||||
|
digest: "9f9126c3d9c3c330d760425ca8a217e31feae31bfe70196ff81642b868402eab",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: "4a656665", // "Jefe"
|
||||||
|
data: "7768617420646f2079612077616e7420666f72206e6f7468696e673f", // "what do ya want for nothing?"
|
||||||
|
digest: "6df7b24630d5ccb2ee335407081a87188c221489768fa2020513b2d593359456",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for idx, tt := range hmacTests {
|
||||||
|
keySlice, _ := hex.DecodeString(tt.key)
|
||||||
|
dataBytes, _ := hex.DecodeString(tt.data)
|
||||||
|
expectedDigest, _ := hex.DecodeString(tt.digest)
|
||||||
|
|
||||||
|
keyBytes := &[32]byte{}
|
||||||
|
copy(keyBytes[:], keySlice)
|
||||||
|
|
||||||
|
macDigest := GenerateHMAC(dataBytes, string(keyBytes[:]))
|
||||||
|
if !bytes.Equal(macDigest, expectedDigest) {
|
||||||
|
t.Errorf("test %d generated unexpected mac", idx)
|
||||||
|
}
|
||||||
|
if !CheckHMAC(dataBytes, macDigest, string(keyBytes[:])) {
|
||||||
|
t.Errorf("test %d generated unexpected mac", idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidTimestamp(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ts string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"good - now", fmt.Sprint(time.Now().Unix()), false},
|
||||||
|
{"good - now - 200ms", fmt.Sprint(time.Now().Add(-200 * time.Millisecond).Unix()), false},
|
||||||
|
{"good - now + 200ms", fmt.Sprint(time.Now().Add(200 * time.Millisecond).Unix()), false},
|
||||||
|
{"bad - now + 10m", fmt.Sprint(time.Now().Add(10 * time.Minute).Unix()), true},
|
||||||
|
{"bad - now - 10m", fmt.Sprint(time.Now().Add(-10 * time.Minute).Unix()), true},
|
||||||
|
{"malformed - non int", fmt.Sprint("pomerium"), true},
|
||||||
|
{"malformed - negative number", fmt.Sprint("-1"), true},
|
||||||
|
{"malformed - huge number", fmt.Sprintf("%d", 10*10000000000), true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := ValidTimestamp(tt.ts); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("ValidTimestamp() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,28 +1,18 @@
|
||||||
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
package cryptutil // import "github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
|
||||||
// MockCipher MockCSRFStore is a mock implementation of Cipher.
|
// MockEncoder MockCSRFStore is a mock implementation of Cipher.
|
||||||
type MockCipher struct {
|
type MockEncoder struct {
|
||||||
EncryptResponse []byte
|
|
||||||
EncryptError error
|
|
||||||
DecryptResponse []byte
|
|
||||||
DecryptError error
|
|
||||||
MarshalResponse string
|
MarshalResponse string
|
||||||
MarshalError error
|
MarshalError error
|
||||||
UnmarshalError error
|
UnmarshalError error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encrypt is a mock implementation of MockCipher.
|
// Marshal is a mock implementation of MockEncoder.
|
||||||
func (mc MockCipher) Encrypt(b []byte) ([]byte, error) { return mc.EncryptResponse, mc.EncryptError }
|
func (mc MockEncoder) Marshal(i interface{}) (string, error) {
|
||||||
|
|
||||||
// Decrypt is a mock implementation of MockCipher.
|
|
||||||
func (mc MockCipher) Decrypt(b []byte) ([]byte, error) { return mc.DecryptResponse, mc.DecryptError }
|
|
||||||
|
|
||||||
// Marshal is a mock implementation of MockCipher.
|
|
||||||
func (mc MockCipher) Marshal(i interface{}) (string, error) {
|
|
||||||
return mc.MarshalResponse, mc.MarshalError
|
return mc.MarshalResponse, mc.MarshalError
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unmarshal is a mock implementation of MockCipher.
|
// Unmarshal is a mock implementation of MockEncoder.
|
||||||
func (mc MockCipher) Unmarshal(s string, i interface{}) error {
|
func (mc MockEncoder) Unmarshal(s string, i interface{}) error {
|
||||||
return mc.UnmarshalError
|
return mc.UnmarshalError
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,31 +5,13 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMockCipher_Unmarshal(t *testing.T) {
|
func TestMockEncoder(t *testing.T) {
|
||||||
e := errors.New("err")
|
e := errors.New("err")
|
||||||
mc := MockCipher{
|
mc := MockEncoder{
|
||||||
EncryptResponse: []byte("EncryptResponse"),
|
|
||||||
EncryptError: e,
|
|
||||||
DecryptResponse: []byte("DecryptResponse"),
|
|
||||||
DecryptError: e,
|
|
||||||
MarshalResponse: "MarshalResponse",
|
MarshalResponse: "MarshalResponse",
|
||||||
MarshalError: e,
|
MarshalError: e,
|
||||||
UnmarshalError: e,
|
UnmarshalError: e,
|
||||||
}
|
}
|
||||||
b, err := mc.Encrypt([]byte("test"))
|
|
||||||
if string(b) != "EncryptResponse" {
|
|
||||||
t.Error("unexpected encrypt response")
|
|
||||||
}
|
|
||||||
if err != e {
|
|
||||||
t.Error("unexpected encrypt error")
|
|
||||||
}
|
|
||||||
b, err = mc.Decrypt([]byte("test"))
|
|
||||||
if string(b) != "DecryptResponse" {
|
|
||||||
t.Error("unexpected Decrypt response")
|
|
||||||
}
|
|
||||||
if err != e {
|
|
||||||
t.Error("unexpected Decrypt error")
|
|
||||||
}
|
|
||||||
s, err := mc.Marshal("test")
|
s, err := mc.Marshal("test")
|
||||||
if err != e {
|
if err != e {
|
||||||
t.Error("unexpected Marshal error")
|
t.Error("unexpected Marshal error")
|
||||||
|
|
|
@ -9,6 +9,11 @@ import (
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultLeeway defines the default leeway for matching NotBefore/Expiry claims.
|
||||||
|
DefaultLeeway = 5.0 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
// JWTSigner implements JWT signing according to JSON Web Token (JWT) RFC7519
|
// JWTSigner implements JWT signing according to JSON Web Token (JWT) RFC7519
|
||||||
// https://tools.ietf.org/html/rfc7519
|
// https://tools.ietf.org/html/rfc7519
|
||||||
type JWTSigner interface {
|
type JWTSigner interface {
|
||||||
|
@ -92,8 +97,8 @@ func (s *ES256Signer) SignJWT(user, email, groups string) (string, error) {
|
||||||
s.Groups = groups
|
s.Groups = groups
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
s.IssuedAt = *jwt.NewNumericDate(now)
|
s.IssuedAt = *jwt.NewNumericDate(now)
|
||||||
s.Expiry = *jwt.NewNumericDate(now.Add(jwt.DefaultLeeway))
|
s.Expiry = *jwt.NewNumericDate(now.Add(DefaultLeeway))
|
||||||
s.NotBefore = *jwt.NewNumericDate(now.Add(-1 * jwt.DefaultLeeway))
|
s.NotBefore = *jwt.NewNumericDate(now.Add(-1 * DefaultLeeway))
|
||||||
rawJWT, err := jwt.Signed(s.signer).Claims(s).CompactSerialize()
|
rawJWT, err := jwt.Signed(s.signer).Claims(s).CompactSerialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("cryptutil: sign failed %v", err)
|
return "", fmt.Errorf("cryptutil: sign failed %v", err)
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
package httputil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCSRFFailureHandler(t *testing.T) {
|
func TestCSRFFailureHandler(t *testing.T) {
|
||||||
|
@ -35,3 +37,19 @@ func TestCSRFFailureHandler(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewRouter(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
want *mux.Router
|
||||||
|
}{
|
||||||
|
{"this is a gorilla router right?", mux.NewRouter()},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := NewRouter(); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("NewRouter() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,14 +1,11 @@
|
||||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/hmac"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
@ -183,22 +180,8 @@ func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
i, err := strconv.ParseInt(timestamp, 10, 64)
|
if err := cryptutil.ValidTimestamp(timestamp); err != nil {
|
||||||
if err != nil {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
tm := time.Unix(i, 0)
|
return cryptutil.CheckHMAC([]byte(fmt.Sprint(redirectURI, timestamp)), requestSig, secret)
|
||||||
ttl := 5 * time.Minute
|
|
||||||
if time.Since(tm) > ttl {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
localSig := redirectURLSignature(redirectURI, tm, secret)
|
|
||||||
|
|
||||||
return hmac.Equal(requestSig, localSig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
|
|
||||||
data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix()))
|
|
||||||
h := cryptutil.Hash(secret, data)
|
|
||||||
return h
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,8 +8,15 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func hmacHelperFunc(rawRedirect string, timestamp time.Time, secret string) []byte {
|
||||||
|
data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix()))
|
||||||
|
return cryptutil.GenerateHMAC(data, secret)
|
||||||
|
}
|
||||||
|
|
||||||
func Test_SameDomain(t *testing.T) {
|
func Test_SameDomain(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
@ -45,7 +52,7 @@ func Test_ValidSignature(t *testing.T) {
|
||||||
goodURL := "https://example.com/redirect"
|
goodURL := "https://example.com/redirect"
|
||||||
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
|
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
|
||||||
now := fmt.Sprint(time.Now().Unix())
|
now := fmt.Sprint(time.Now().Unix())
|
||||||
rawSig := redirectURLSignature(goodURL, time.Now(), secretA)
|
rawSig := hmacHelperFunc(goodURL, time.Now(), secretA)
|
||||||
sig := base64.URLEncoding.EncodeToString(rawSig)
|
sig := base64.URLEncoding.EncodeToString(rawSig)
|
||||||
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
|
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
|
||||||
|
|
||||||
|
@ -73,27 +80,6 @@ func Test_ValidSignature(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_redirectURLSignature(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
rawRedirect string
|
|
||||||
timestamp time.Time
|
|
||||||
secret string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"good signature", "https://example.com/redirect", time.Unix(1546797901, 0), "K3yqsJPahIzu5CdfCVJlIK4N8Dc135-27Tg1ROuQdhc=", "XeVJC2Iysq7mRUwOL3FX_5vx1d_kZV2HONHNig9fcKk="},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := redirectURLSignature(tt.rawRedirect, tt.timestamp, tt.secret)
|
|
||||||
out := base64.URLEncoding.EncodeToString(got)
|
|
||||||
if out != tt.want {
|
|
||||||
t.Errorf("redirectURLSignature() = %v, want %v", tt.want, out)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSetHeaders(t *testing.T) {
|
func TestSetHeaders(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -209,7 +195,7 @@ func TestValidateSignature(t *testing.T) {
|
||||||
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
|
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
|
||||||
now := fmt.Sprint(time.Now().Unix())
|
now := fmt.Sprint(time.Now().Unix())
|
||||||
goodURL := "https://example.com/redirect"
|
goodURL := "https://example.com/redirect"
|
||||||
rawSig := redirectURLSignature(goodURL, time.Now(), secretA)
|
rawSig := hmacHelperFunc(goodURL, time.Now(), secretA)
|
||||||
sig := base64.URLEncoding.EncodeToString(rawSig)
|
sig := base64.URLEncoding.EncodeToString(rawSig)
|
||||||
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
|
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ const MaxNumChunks = 5
|
||||||
// CookieStore represents all the cookie related configurations
|
// CookieStore represents all the cookie related configurations
|
||||||
type CookieStore struct {
|
type CookieStore struct {
|
||||||
Name string
|
Name string
|
||||||
CookieCipher cryptutil.Cipher
|
Encoder cryptutil.SecureEncoder
|
||||||
CookieExpire time.Duration
|
CookieExpire time.Duration
|
||||||
CookieRefresh time.Duration
|
CookieRefresh time.Duration
|
||||||
CookieSecure bool
|
CookieSecure bool
|
||||||
|
@ -50,7 +50,7 @@ type CookieStoreOptions struct {
|
||||||
CookieDomain string
|
CookieDomain string
|
||||||
BearerTokenHeader string
|
BearerTokenHeader string
|
||||||
CookieExpire time.Duration
|
CookieExpire time.Duration
|
||||||
CookieCipher cryptutil.Cipher
|
Encoder cryptutil.SecureEncoder
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
|
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
|
||||||
|
@ -58,7 +58,7 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
|
||||||
if opts.Name == "" {
|
if opts.Name == "" {
|
||||||
return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
|
return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
|
||||||
}
|
}
|
||||||
if opts.CookieCipher == nil {
|
if opts.Encoder == nil {
|
||||||
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
|
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
|
||||||
}
|
}
|
||||||
if opts.BearerTokenHeader == "" {
|
if opts.BearerTokenHeader == "" {
|
||||||
|
@ -71,7 +71,7 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
|
||||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||||
CookieDomain: opts.CookieDomain,
|
CookieDomain: opts.CookieDomain,
|
||||||
CookieExpire: opts.CookieExpire,
|
CookieExpire: opts.CookieExpire,
|
||||||
CookieCipher: opts.CookieCipher,
|
Encoder: opts.Encoder,
|
||||||
BearerTokenHeader: opts.BearerTokenHeader,
|
BearerTokenHeader: opts.BearerTokenHeader,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -188,7 +188,7 @@ func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) {
|
||||||
if cipherText == "" {
|
if cipherText == "" {
|
||||||
return nil, ErrEmptySession
|
return nil, ErrEmptySession
|
||||||
}
|
}
|
||||||
session, err := UnmarshalSession(cipherText, cs.CookieCipher)
|
session, err := UnmarshalSession(cipherText, cs.Encoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -197,7 +197,7 @@ func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) {
|
||||||
|
|
||||||
// SaveSession saves a session state to a request sessions.
|
// SaveSession saves a session state to a request sessions.
|
||||||
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error {
|
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error {
|
||||||
value, err := MarshalSession(s, cs.CookieCipher)
|
value, err := MarshalSession(s, cs.Encoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,33 +15,21 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockCipher struct{}
|
type MockEncoder struct{}
|
||||||
|
|
||||||
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
|
func (a MockEncoder) Marshal(s interface{}) (string, error) { return "", errors.New("error") }
|
||||||
if string(s) == "error" {
|
func (a MockEncoder) Unmarshal(s string, i interface{}) error {
|
||||||
return []byte(""), errors.New("error encrypting")
|
|
||||||
}
|
|
||||||
return []byte("OK"), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
|
|
||||||
if string(s) == "error" {
|
|
||||||
return []byte(""), errors.New("error encrypting")
|
|
||||||
}
|
|
||||||
return []byte("OK"), nil
|
|
||||||
}
|
|
||||||
func (a mockCipher) Marshal(s interface{}) (string, error) { return "", errors.New("error") }
|
|
||||||
func (a mockCipher) Unmarshal(s string, i interface{}) error {
|
|
||||||
if s == "unmarshal error" || s == "error" {
|
if s == "unmarshal error" || s == "error" {
|
||||||
return errors.New("error")
|
return errors.New("error")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func TestNewCookieStore(t *testing.T) {
|
func TestNewCookieStore(t *testing.T) {
|
||||||
cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
|
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
encoder := cryptutil.NewSecureJSONEncoder(cipher)
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
opts *CookieStoreOptions
|
opts *CookieStoreOptions
|
||||||
|
@ -55,7 +43,7 @@ func TestNewCookieStore(t *testing.T) {
|
||||||
CookieHTTPOnly: true,
|
CookieHTTPOnly: true,
|
||||||
CookieDomain: "pomerium.io",
|
CookieDomain: "pomerium.io",
|
||||||
CookieExpire: 10 * time.Second,
|
CookieExpire: 10 * time.Second,
|
||||||
CookieCipher: cipher,
|
Encoder: encoder,
|
||||||
BearerTokenHeader: "Authorization",
|
BearerTokenHeader: "Authorization",
|
||||||
},
|
},
|
||||||
&CookieStore{
|
&CookieStore{
|
||||||
|
@ -64,7 +52,7 @@ func TestNewCookieStore(t *testing.T) {
|
||||||
CookieHTTPOnly: true,
|
CookieHTTPOnly: true,
|
||||||
CookieDomain: "pomerium.io",
|
CookieDomain: "pomerium.io",
|
||||||
CookieExpire: 10 * time.Second,
|
CookieExpire: 10 * time.Second,
|
||||||
CookieCipher: cipher,
|
Encoder: encoder,
|
||||||
BearerTokenHeader: "Authorization",
|
BearerTokenHeader: "Authorization",
|
||||||
},
|
},
|
||||||
false},
|
false},
|
||||||
|
@ -75,7 +63,7 @@ func TestNewCookieStore(t *testing.T) {
|
||||||
CookieHTTPOnly: true,
|
CookieHTTPOnly: true,
|
||||||
CookieDomain: "pomerium.io",
|
CookieDomain: "pomerium.io",
|
||||||
CookieExpire: 10 * time.Second,
|
CookieExpire: 10 * time.Second,
|
||||||
CookieCipher: cipher,
|
Encoder: encoder,
|
||||||
BearerTokenHeader: "Authorization",
|
BearerTokenHeader: "Authorization",
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
|
@ -87,7 +75,7 @@ func TestNewCookieStore(t *testing.T) {
|
||||||
CookieHTTPOnly: true,
|
CookieHTTPOnly: true,
|
||||||
CookieDomain: "pomerium.io",
|
CookieDomain: "pomerium.io",
|
||||||
CookieExpire: 10 * time.Second,
|
CookieExpire: 10 * time.Second,
|
||||||
CookieCipher: nil,
|
Encoder: nil,
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
true},
|
true},
|
||||||
|
@ -100,7 +88,7 @@ func TestNewCookieStore(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cmpOpts := []cmp.Option{
|
cmpOpts := []cmp.Option{
|
||||||
cmpopts.IgnoreUnexported(cryptutil.XChaCha20Cipher{}),
|
cmpopts.IgnoreUnexported(cryptutil.SecureJSONEncoder{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
|
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
|
||||||
|
@ -111,7 +99,8 @@ func TestNewCookieStore(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCookieStore_makeCookie(t *testing.T) {
|
func TestCookieStore_makeCookie(t *testing.T) {
|
||||||
cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
|
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -145,7 +134,7 @@ func TestCookieStore_makeCookie(t *testing.T) {
|
||||||
CookieHTTPOnly: true,
|
CookieHTTPOnly: true,
|
||||||
CookieDomain: tt.cookieDomain,
|
CookieDomain: tt.cookieDomain,
|
||||||
CookieExpire: 10 * time.Second,
|
CookieExpire: 10 * time.Second,
|
||||||
CookieCipher: cipher})
|
Encoder: cryptutil.NewSecureJSONEncoder(cipher)})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -161,10 +150,12 @@ func TestCookieStore_makeCookie(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCookieStore_SaveSession(t *testing.T) {
|
func TestCookieStore_SaveSession(t *testing.T) {
|
||||||
cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
|
c, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
cipher := cryptutil.NewSecureJSONEncoder(c)
|
||||||
|
|
||||||
hugeString := make([]byte, 4097)
|
hugeString := make([]byte, 4097)
|
||||||
if _, err := rand.Read(hugeString); err != nil {
|
if _, err := rand.Read(hugeString); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -172,12 +163,12 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
State *State
|
State *State
|
||||||
cipher cryptutil.Cipher
|
cipher cryptutil.SecureEncoder
|
||||||
wantErr bool
|
wantErr bool
|
||||||
wantLoadErr bool
|
wantLoadErr bool
|
||||||
}{
|
}{
|
||||||
{"good", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
|
{"good", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
|
||||||
{"bad cipher", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, mockCipher{}, true, true},
|
{"bad cipher", &State{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, MockEncoder{}, true, true},
|
||||||
{"huge cookie", &State{AccessToken: fmt.Sprintf("%x", hugeString), RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
|
{"huge cookie", &State{AccessToken: fmt.Sprintf("%x", hugeString), RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -188,7 +179,7 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
||||||
CookieHTTPOnly: true,
|
CookieHTTPOnly: true,
|
||||||
CookieDomain: "pomerium.io",
|
CookieDomain: "pomerium.io",
|
||||||
CookieExpire: 10 * time.Second,
|
CookieExpire: 10 * time.Second,
|
||||||
CookieCipher: tt.cipher}
|
Encoder: tt.cipher}
|
||||||
|
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
|
@ -81,11 +81,13 @@ func TestVerifier(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
cipher, err := cryptutil.NewCipherFromBase64(cryptutil.NewBase64Key())
|
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||||
|
encoder := cryptutil.NewSecureJSONEncoder(cipher)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
encSession, err := MarshalSession(&tt.state, cipher)
|
encSession, err := MarshalSession(&tt.state, encoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -96,8 +98,8 @@ func TestVerifier(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
cs, err := NewCookieStore(&CookieStoreOptions{
|
cs, err := NewCookieStore(&CookieStoreOptions{
|
||||||
Name: "_pomerium",
|
Name: "_pomerium",
|
||||||
CookieCipher: cipher,
|
Encoder: encoder,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|
|
@ -89,7 +89,7 @@ func (s *State) IssuedAt() (time.Time, error) {
|
||||||
|
|
||||||
// MarshalSession marshals the session state as JSON, encrypts the JSON using the
|
// MarshalSession marshals the session state as JSON, encrypts the JSON using the
|
||||||
// given cipher, and base64-encodes the result
|
// given cipher, and base64-encodes the result
|
||||||
func MarshalSession(s *State, c cryptutil.Cipher) (string, error) {
|
func MarshalSession(s *State, c cryptutil.SecureEncoder) (string, error) {
|
||||||
v, err := c.Marshal(s)
|
v, err := c.Marshal(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -99,7 +99,7 @@ func MarshalSession(s *State, c cryptutil.Cipher) (string, error) {
|
||||||
|
|
||||||
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
// UnmarshalSession takes the marshaled string, base64-decodes into a byte slice, decrypts the
|
||||||
// byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct
|
// byte slice using the passed cipher, and unmarshals the resulting JSON into a session state struct
|
||||||
func UnmarshalSession(value string, c cryptutil.Cipher) (*State, error) {
|
func UnmarshalSession(value string, c cryptutil.SecureEncoder) (*State, error) {
|
||||||
s := &State{}
|
s := &State{}
|
||||||
err := c.Unmarshal(value, s)
|
err := c.Unmarshal(value, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -13,7 +13,8 @@ import (
|
||||||
|
|
||||||
func TestStateSerialization(t *testing.T) {
|
func TestStateSerialization(t *testing.T) {
|
||||||
secret := cryptutil.NewKey()
|
secret := cryptutil.NewKey()
|
||||||
c, err := cryptutil.NewCipher(secret)
|
cipher, err := cryptutil.NewAEADCipher(secret)
|
||||||
|
c := cryptutil.NewSecureJSONEncoder(cipher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -124,10 +125,12 @@ func TestState_Impersonating(t *testing.T) {
|
||||||
|
|
||||||
func TestMarshalSession(t *testing.T) {
|
func TestMarshalSession(t *testing.T) {
|
||||||
secret := cryptutil.NewKey()
|
secret := cryptutil.NewKey()
|
||||||
c, err := cryptutil.NewCipher(secret)
|
cipher, err := cryptutil.NewAEADCipher(secret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||||
}
|
}
|
||||||
|
c := cryptutil.NewSecureJSONEncoder(cipher)
|
||||||
|
|
||||||
hugeString := make([]byte, 4097)
|
hugeString := make([]byte, 4097)
|
||||||
if _, err := rand.Read(hugeString); err != nil {
|
if _, err := rand.Read(hugeString); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|
|
@ -81,9 +81,9 @@ func timestamp() int64 {
|
||||||
|
|
||||||
// SignedRedirectURL takes a destination URL and adds redirect_uri to it's
|
// SignedRedirectURL takes a destination URL and adds redirect_uri to it's
|
||||||
// query params, along with a timestamp and an keyed signature.
|
// query params, along with a timestamp and an keyed signature.
|
||||||
func SignedRedirectURL(key string, destination, urlToSign *url.URL) *url.URL {
|
func SignedRedirectURL(key string, destination, u *url.URL) *url.URL {
|
||||||
now := timestamp()
|
now := timestamp()
|
||||||
rawURL := urlToSign.String()
|
rawURL := u.String()
|
||||||
params, _ := url.ParseQuery(destination.RawQuery) // handled by incoming mux
|
params, _ := url.ParseQuery(destination.RawQuery) // handled by incoming mux
|
||||||
params.Set("redirect_uri", rawURL)
|
params.Set("redirect_uri", rawURL)
|
||||||
params.Set("ts", fmt.Sprint(now))
|
params.Set("ts", fmt.Sprint(now))
|
||||||
|
@ -95,7 +95,7 @@ func SignedRedirectURL(key string, destination, urlToSign *url.URL) *url.URL {
|
||||||
// hmacURL takes a redirect url string and timestamp and returns the base64
|
// hmacURL takes a redirect url string and timestamp and returns the base64
|
||||||
// encoded HMAC result.
|
// encoded HMAC result.
|
||||||
func hmacURL(key, data string, timestamp int64) string {
|
func hmacURL(key, data string, timestamp int64) string {
|
||||||
h := cryptutil.Hash(key, []byte(fmt.Sprint(data, timestamp)))
|
h := cryptutil.GenerateHMAC([]byte(fmt.Sprint(data, timestamp)), key)
|
||||||
return base64.URLEncoding.EncodeToString(h)
|
return base64.URLEncoding.EncodeToString(h)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -123,7 +123,7 @@ func TestProxy_router(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
p.cipher = &cryptutil.MockCipher{MarshalResponse: "foo"}
|
p.encoder = &cryptutil.MockEncoder{MarshalResponse: "foo"}
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, tt.host, nil)
|
req := httptest.NewRequest(http.MethodGet, tt.host, nil)
|
||||||
_, ok := p.router(req)
|
_, ok := p.router(req)
|
||||||
|
@ -203,7 +203,7 @@ func TestProxy_Proxy(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
p.cipher = &cryptutil.MockCipher{MarshalResponse: "foo"}
|
p.encoder = &cryptutil.MockEncoder{MarshalResponse: "foo"}
|
||||||
p.sessionStore = tt.session
|
p.sessionStore = tt.session
|
||||||
p.AuthorizeClient = tt.authorizer
|
p.AuthorizeClient = tt.authorizer
|
||||||
r := httptest.NewRequest(tt.method, tt.host, nil)
|
r := httptest.NewRequest(tt.method, tt.host, nil)
|
||||||
|
@ -231,17 +231,17 @@ func TestProxy_UserDashboard(t *testing.T) {
|
||||||
ctxError error
|
ctxError error
|
||||||
options config.Options
|
options config.Options
|
||||||
method string
|
method string
|
||||||
cipher cryptutil.Cipher
|
cipher cryptutil.SecureEncoder
|
||||||
session sessions.SessionStore
|
session sessions.SessionStore
|
||||||
authorizer clients.Authorizer
|
authorizer clients.Authorizer
|
||||||
|
|
||||||
wantAdminForm bool
|
wantAdminForm bool
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusOK},
|
{"good", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusOK},
|
||||||
{"session context error", errors.New("error"), opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
|
{"session context error", errors.New("error"), opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
|
||||||
{"want admin form good admin authorization", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
|
{"want admin form good admin authorization", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
|
||||||
{"is admin but authorization fails", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
|
{"is admin but authorization fails", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -250,7 +250,7 @@ func TestProxy_UserDashboard(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
p.cipher = tt.cipher
|
p.encoder = tt.cipher
|
||||||
p.sessionStore = tt.session
|
p.sessionStore = tt.session
|
||||||
p.AuthorizeClient = tt.authorizer
|
p.AuthorizeClient = tt.authorizer
|
||||||
|
|
||||||
|
@ -289,17 +289,17 @@ func TestProxy_ForceRefresh(t *testing.T) {
|
||||||
ctxError error
|
ctxError error
|
||||||
options config.Options
|
options config.Options
|
||||||
method string
|
method string
|
||||||
cipher cryptutil.Cipher
|
cipher cryptutil.SecureEncoder
|
||||||
session sessions.SessionStore
|
session sessions.SessionStore
|
||||||
authorizer clients.Authorizer
|
authorizer clients.Authorizer
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
|
{"good", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
|
||||||
{"cannot load session", errors.New("error"), opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
{"cannot load session", errors.New("error"), opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
||||||
{"bad id token", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
{"bad id token", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "bad"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
||||||
{"issue date too soon", nil, timeSinceError, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusBadRequest},
|
{"issue date too soon", nil, timeSinceError, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusBadRequest},
|
||||||
{"refresh failure", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
|
{"refresh failure", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusFound},
|
||||||
{"can't save refreshed session", nil, opts, http.MethodGet, &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
{"can't save refreshed session", nil, opts, http.MethodGet, &cryptutil.MockEncoder{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6IjA3YTA4MjgzOWYyZTcxYTliZjZjNTk2OTk2Yjk0NzM5Nzg1YWZkYzMiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI4NTE4NzcwODIwNTktYmZna3BqMDlub29nN2FzM2dwYzN0N3I2bjlzamJnczYuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMTE0MzI2NTU5NzcyNzMxNTAzMDgiLCJoZCI6InBvbWVyaXVtLmlvIiwiZW1haWwiOiJiZGRAcG9tZXJpdW0uaW8iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6IlppQ1g0WndDYl9tcUVxM2xnbmFZRHciLCJuYW1lIjoiQm9iYnkgRGVTaW1vbmUiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tLy1PX1BzRTlILTgzRS9BQUFBQUFBQUFBSS9BQUFBQUFBQUFBQS9BQ0hpM3JjQ0U0SFRLVDBhQk1pUFVfOEZfVXFOQ3F6RTBRL3M5Ni1jL3Bob3RvLmpwZyIsImdpdmVuX25hbWUiOiJCb2JieSIsImZhbWlseV9uYW1lIjoiRGVTaW1vbmUiLCJsb2NhbGUiOiJlbiIsImlhdCI6MTU1ODY1NDEzNywiZXhwIjoxNTU4NjU3NzM3fQ.Flah31XfqmPhWYh2rJ-6rtowmSQFgp6HqDf1rpS38Wo0DXnIYmXxEQVLanDNV62Z0sLhUk1QO9NqoSgA3NscM-Ww-JsqU80oKnWcMYweUb_KU0kfHyTiUB0iEHMqu6tXn5dA_dIaPnL5oorXZ_gG4sooRxBZrDkaNAjRINLciKDQkUTVaNfnM6IBZ_pWDPd2lWGtj8h8sEIe2PIiH73Z2VLlXz8kw60VTPsi9U2zrF0ZJ9MfRGJhceQ58vW2ZlFfXJixgvbOZjKmcRv8NaJDIUss48l0Bsya6icZ0l1ZK-sAiFr0KVLTl2ywu8d5SQpTJ1X7vDW_u_04xaqDQUdYKA"}}, clients.MockAuthorize{}, http.StatusInternalServerError},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -307,7 +307,7 @@ func TestProxy_ForceRefresh(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
p.cipher = tt.cipher
|
p.encoder = tt.cipher
|
||||||
p.sessionStore = tt.session
|
p.sessionStore = tt.session
|
||||||
p.AuthorizeClient = tt.authorizer
|
p.AuthorizeClient = tt.authorizer
|
||||||
|
|
||||||
|
@ -340,18 +340,18 @@ func TestProxy_Impersonate(t *testing.T) {
|
||||||
email string
|
email string
|
||||||
groups string
|
groups string
|
||||||
csrf string
|
csrf string
|
||||||
cipher cryptutil.Cipher
|
cipher cryptutil.SecureEncoder
|
||||||
sessionStore sessions.SessionStore
|
sessionStore sessions.SessionStore
|
||||||
authorizer clients.Authorizer
|
authorizer clients.Authorizer
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||||
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||||
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||||
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
|
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
|
||||||
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
|
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusForbidden},
|
||||||
{"save session failure", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
{"save session failure", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||||
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockCipher{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{RefreshDeadline: time.Now().Add(10 * time.Second), Email: "user@test.example", IDToken: ""}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -359,7 +359,7 @@ func TestProxy_Impersonate(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
p.cipher = tt.cipher
|
p.encoder = tt.cipher
|
||||||
p.sessionStore = tt.sessionStore
|
p.sessionStore = tt.sessionStore
|
||||||
p.AuthorizeClient = tt.authorizer
|
p.AuthorizeClient = tt.authorizer
|
||||||
postForm := url.Values{}
|
postForm := url.Values{}
|
||||||
|
|
|
@ -43,11 +43,11 @@ const (
|
||||||
// ValidateOptions checks that proper configuration settings are set to create
|
// ValidateOptions checks that proper configuration settings are set to create
|
||||||
// a proper Proxy instance
|
// a proper Proxy instance
|
||||||
func ValidateOptions(o config.Options) error {
|
func ValidateOptions(o config.Options) error {
|
||||||
if _, err := cryptutil.NewCipherFromBase64(o.SharedKey); err != nil {
|
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil {
|
||||||
return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %v", err)
|
return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := cryptutil.NewCipherFromBase64(o.CookieSecret); err != nil {
|
if _, err := cryptutil.NewAEADCipherFromBase64(o.CookieSecret); err != nil {
|
||||||
return fmt.Errorf("proxy: invalid 'COOKIE_SECRET': %v", err)
|
return fmt.Errorf("proxy: invalid 'COOKIE_SECRET': %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,7 +78,8 @@ type Proxy struct {
|
||||||
|
|
||||||
AuthorizeClient clients.Authorizer
|
AuthorizeClient clients.Authorizer
|
||||||
|
|
||||||
cipher cryptutil.Cipher
|
// cipher cipher.AEAD
|
||||||
|
encoder cryptutil.SecureEncoder
|
||||||
cookieName string
|
cookieName string
|
||||||
cookieDomain string
|
cookieDomain string
|
||||||
cookieSecret []byte
|
cookieSecret []byte
|
||||||
|
@ -105,10 +106,12 @@ func New(opts config.Options) (*Proxy, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
cipher, err := cryptutil.NewCipherFromBase64(opts.CookieSecret)
|
cipher, err := cryptutil.NewAEADCipherFromBase64(opts.CookieSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
encoder := cryptutil.NewSecureJSONEncoder(cipher)
|
||||||
|
|
||||||
if opts.CookieDomain == "" {
|
if opts.CookieDomain == "" {
|
||||||
opts.CookieDomain = sessions.ParentSubdomain(opts.AuthenticateURL.String())
|
opts.CookieDomain = sessions.ParentSubdomain(opts.AuthenticateURL.String())
|
||||||
}
|
}
|
||||||
|
@ -120,7 +123,7 @@ func New(opts config.Options) (*Proxy, error) {
|
||||||
CookieSecure: opts.CookieSecure,
|
CookieSecure: opts.CookieSecure,
|
||||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||||
CookieExpire: opts.CookieExpire,
|
CookieExpire: opts.CookieExpire,
|
||||||
CookieCipher: cipher,
|
Encoder: encoder,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -130,7 +133,7 @@ func New(opts config.Options) (*Proxy, error) {
|
||||||
SharedKey: opts.SharedKey,
|
SharedKey: opts.SharedKey,
|
||||||
|
|
||||||
routeConfigs: make(map[string]*routeConfig),
|
routeConfigs: make(map[string]*routeConfig),
|
||||||
cipher: cipher,
|
encoder: encoder,
|
||||||
cookieSecret: decodedCookieSecret,
|
cookieSecret: decodedCookieSecret,
|
||||||
cookieDomain: opts.CookieDomain,
|
cookieDomain: opts.CookieDomain,
|
||||||
cookieName: opts.CookieName,
|
cookieName: opts.CookieName,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue