mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 08:50:42 +02:00
Improve test coverage. (#8)
* Improve test coverage. * Remove unused http status code argument from SignInPageMethod. * Removed log package in internal packages. * Add test to check https scheme is used for authorization url. * Add unit tests for global logging package.
This commit is contained in:
parent
5a75ace403
commit
56c89e8653
14 changed files with 478 additions and 105 deletions
|
@ -92,14 +92,12 @@ func (o *Options) Validate() error {
|
|||
if o.SharedKey == "" {
|
||||
return errors.New("missing setting: shared secret")
|
||||
}
|
||||
|
||||
decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authenticate options: cookie secret invalid"+
|
||||
"must be a base64-encoded, 256 bit key e.g. `head -c32 /dev/urandom | base64`"+
|
||||
"got %q", err)
|
||||
}
|
||||
|
||||
validCookieSecretLength := false
|
||||
for _, i := range []int{32, 64} {
|
||||
if len(decodedCookieSecret) == i {
|
||||
|
@ -108,11 +106,8 @@ func (o *Options) Validate() error {
|
|||
}
|
||||
|
||||
if !validCookieSecretLength {
|
||||
return fmt.Errorf("authenticate options: invalid cookie secret strength want 32 to 64 bytes, got %d bytes", len(decodedCookieSecret))
|
||||
}
|
||||
|
||||
if o.CookieRefresh >= o.CookieExpire {
|
||||
return fmt.Errorf("cookie_refresh (%s) must be less than cookie_expire (%s)", o.CookieRefresh.String(), o.CookieExpire.String())
|
||||
return fmt.Errorf("authenticate options: invalid cookie secret strength want"+
|
||||
" 32 to 64 bytes, got %d bytes", len(decodedCookieSecret))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
129
authenticate/authenticate_test.go
Normal file
129
authenticate/authenticate_test.go
Normal file
|
@ -0,0 +1,129 @@
|
|||
package authenticate
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func testOptions() *Options {
|
||||
redirectURL, _ := url.Parse("https://example.com/oauth2/callback")
|
||||
return &Options{
|
||||
ProxyRootDomains: []string{"example.com"},
|
||||
AllowedDomains: []string{"example.com"},
|
||||
RedirectURL: redirectURL,
|
||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||
CookieRefresh: time.Duration(1) * time.Hour,
|
||||
SessionLifetimeTTL: time.Duration(720) * time.Hour,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptions_Validate(t *testing.T) {
|
||||
good := testOptions()
|
||||
badRedirectURL := testOptions()
|
||||
badRedirectURL.RedirectURL = nil
|
||||
malformedRedirectURL := testOptions()
|
||||
redirectURL, _ := url.Parse("https://example.com/oauth3/callback")
|
||||
malformedRedirectURL.RedirectURL = redirectURL
|
||||
emptyClientID := testOptions()
|
||||
emptyClientID.ClientID = ""
|
||||
emptyClientSecret := testOptions()
|
||||
emptyClientSecret.ClientSecret = ""
|
||||
allowedDomains := testOptions()
|
||||
allowedDomains.AllowedDomains = nil
|
||||
proxyRootDomains := testOptions()
|
||||
proxyRootDomains.ProxyRootDomains = nil
|
||||
emptyCookieSecret := testOptions()
|
||||
emptyCookieSecret.CookieSecret = ""
|
||||
invalidCookieSecret := testOptions()
|
||||
invalidCookieSecret.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw^"
|
||||
shortCookieLength := testOptions()
|
||||
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="
|
||||
badSharedKey := testOptions()
|
||||
badSharedKey.SharedKey = ""
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
o *Options
|
||||
wantErr bool
|
||||
}{
|
||||
{"minimum options", good, false},
|
||||
{"nil options", &Options{}, true},
|
||||
{"bad redirect url", badRedirectURL, true},
|
||||
{"malformed redirect url", malformedRedirectURL, true},
|
||||
{"no cookie secret", emptyCookieSecret, true},
|
||||
{"invalid cookie secret", invalidCookieSecret, true},
|
||||
{"short cookie secret", shortCookieLength, true},
|
||||
{"no shared secret", badSharedKey, true},
|
||||
{"no client id", emptyClientID, true},
|
||||
{"no client secret", emptyClientSecret, true},
|
||||
{"empty allowed domains", allowedDomains, true},
|
||||
{"empty root domains", proxyRootDomains, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := tt.o
|
||||
if err := o.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionsFromEnvConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want *Options
|
||||
envKey string
|
||||
envValue string
|
||||
wantErr bool
|
||||
}{
|
||||
{"good default, no env settings", defaultOptions, "", "", false},
|
||||
{"bad url", nil, "REDIRECT_URL", "%.rjlw", true},
|
||||
{"good duration", defaultOptions, "COOKIE_EXPIRE", "1m", false},
|
||||
{"bad duration", nil, "COOKIE_EXPIRE", "1sm", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.envKey != "" {
|
||||
os.Setenv(tt.envKey, tt.envValue)
|
||||
defer os.Unsetenv(tt.envKey)
|
||||
}
|
||||
got, err := OptionsFromEnvConfig()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("OptionsFromEnvConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("OptionsFromEnvConfig() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_dotPrependDomains(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
d []string
|
||||
want []string
|
||||
}{
|
||||
{"single domain", []string{"google.com"}, []string{".google.com"}},
|
||||
{"multiple domain", []string{"google.com", "bing.com"}, []string{".google.com", ".bing.com"}},
|
||||
{"empty", []string{""}, []string{""}},
|
||||
{"nested subdomain", []string{"some.really.long.domain.com"}, []string{".some.really.long.domain.com"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := dotPrependDomains(tt.d); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("dotPrependDomains() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -81,10 +81,9 @@ func (p *Authenticator) PingPage(rw http.ResponseWriter, req *http.Request) {
|
|||
fmt.Fprintf(rw, "OK")
|
||||
}
|
||||
|
||||
// SignInPage directs the user to the sign in page
|
||||
func (p *Authenticator) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
|
||||
// SignInPage directs the user to the sign in page. Takes a `redirect_uri` param.
|
||||
func (p *Authenticator) SignInPage(rw http.ResponseWriter, req *http.Request) {
|
||||
requestLog := log.WithRequest(req, "authenticate.SignInPage")
|
||||
rw.WriteHeader(code)
|
||||
redirectURL := p.RedirectURL.ResolveReference(req.URL)
|
||||
// validateRedirectURI middleware already ensures that this is a valid URL
|
||||
destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri"))
|
||||
|
@ -107,12 +106,12 @@ func (p *Authenticator) SignInPage(rw http.ResponseWriter, req *http.Request, co
|
|||
Str("Destination", destinationURL.Host).
|
||||
Str("AllowedDomains", strings.Join(p.AllowedDomains, ", ")).
|
||||
Msg("authenticate.SignInPage")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
p.templates.ExecuteTemplate(rw, "sign_in.html", t)
|
||||
}
|
||||
|
||||
func (p *Authenticator) authenticate(rw http.ResponseWriter, req *http.Request) (*sessions.SessionState, error) {
|
||||
requestLog := log.WithRequest(req, "authenticate.authenticate")
|
||||
|
||||
session, err := p.sessionStore.LoadSession(req)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("authenticate.authenticate")
|
||||
|
@ -169,7 +168,6 @@ func (p *Authenticator) authenticate(rw http.ResponseWriter, req *http.Request)
|
|||
requestLog.Error().Msg("invalid email user")
|
||||
return nil, httputil.ErrUserNotAuthorized
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
|
@ -193,11 +191,11 @@ func (p *Authenticator) SignIn(rw http.ResponseWriter, req *http.Request) {
|
|||
p.ProxyOAuthRedirect(rw, req, session)
|
||||
case http.ErrNoCookie:
|
||||
log.Error().Err(err).Msg("authenticate.SignIn : err no cookie")
|
||||
p.SignInPage(rw, req, http.StatusOK)
|
||||
p.SignInPage(rw, req)
|
||||
case sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
|
||||
log.Error().Err(err).Msg("authenticate.SignIn : invalid cookie cookie")
|
||||
p.sessionStore.ClearSession(rw, req)
|
||||
p.SignInPage(rw, req, http.StatusOK)
|
||||
p.SignInPage(rw, req)
|
||||
default:
|
||||
log.Error().Err(err).Msg("authenticate.SignIn : unknown error cookie")
|
||||
httputil.ErrorResponse(rw, req, err.Error(), httputil.CodeForError(err))
|
||||
|
|
98
authenticate/handlers_test.go
Normal file
98
authenticate/handlers_test.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package authenticate
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate/providers"
|
||||
"github.com/pomerium/pomerium/internal/templates"
|
||||
)
|
||||
|
||||
func testAuthenticator() *Authenticator {
|
||||
var auth Authenticator
|
||||
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
|
||||
auth.SharedKey = "IzY7MOZwzfOkmELXgozHDKTxoT3nOYhwkcmUVINsRww="
|
||||
auth.AllowedDomains = []string{"*"}
|
||||
auth.ProxyRootDomains = []string{"example.com"}
|
||||
auth.templates = templates.New()
|
||||
auth.provider = providers.NewTestProvider(auth.RedirectURL)
|
||||
return &auth
|
||||
}
|
||||
|
||||
func TestAuthenticator_PingPage(t *testing.T) {
|
||||
auth := testAuthenticator()
|
||||
req, err := http.NewRequest("GET", "/ping", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
handler := http.HandlerFunc(auth.PingPage)
|
||||
handler.ServeHTTP(rr, req)
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
|
||||
}
|
||||
expected := "OK"
|
||||
if rr.Body.String() != expected {
|
||||
t.Errorf("handler returned wrong body: got %v want %v", rr.Body.String(), expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticator_RobotsTxt(t *testing.T) {
|
||||
auth := testAuthenticator()
|
||||
req, err := http.NewRequest("GET", "/robots.txt", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
handler := http.HandlerFunc(auth.RobotsTxt)
|
||||
handler.ServeHTTP(rr, req)
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
|
||||
}
|
||||
expected := fmt.Sprintf("User-agent: *\nDisallow: /")
|
||||
if rr.Body.String() != expected {
|
||||
t.Errorf("handler returned wrong body: got %v want %v", rr.Body.String(), expected)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticator_SignInPage(t *testing.T) {
|
||||
auth := testAuthenticator()
|
||||
v := url.Values{}
|
||||
v.Set("request_uri", "this-is-a-test-uri")
|
||||
url := fmt.Sprintf("/signin?%s", v.Encode())
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
handler := http.HandlerFunc(auth.SignInPage)
|
||||
handler.ServeHTTP(rr, req)
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
|
||||
}
|
||||
body := []byte(rr.Body.String())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
want bool
|
||||
}{
|
||||
{"provider name", auth.provider.Data().ProviderName, true},
|
||||
{"destination url", v.Encode(), true},
|
||||
{"shouldn't be found", "this string should not be in the body", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := bytes.Contains(body, []byte(tt.value)); got != tt.want {
|
||||
t.Errorf("handler body missing expected value %v", tt.value)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -14,6 +14,8 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
)
|
||||
|
||||
var defaultSignatureValidityDuration = 5 * time.Minute
|
||||
|
||||
// validateRedirectURI checks the redirect uri in the query parameters and ensures that
|
||||
// the url's domain is one in the list of proxy root domains.
|
||||
func validateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
|
||||
|
@ -34,11 +36,17 @@ func validateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.Han
|
|||
}
|
||||
|
||||
func validRedirectURI(uri string, rootDomains []string) bool {
|
||||
if uri == "" || len(rootDomains) == 0 {
|
||||
return false
|
||||
}
|
||||
redirectURL, err := url.Parse(uri)
|
||||
if uri == "" || err != nil || redirectURL.Host == "" {
|
||||
if err != nil || redirectURL.Host == "" {
|
||||
return false
|
||||
}
|
||||
for _, domain := range rootDomains {
|
||||
if domain == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasSuffix(redirectURL.Hostname(), domain) {
|
||||
return true
|
||||
}
|
||||
|
@ -65,6 +73,8 @@ func validateSignature(f http.HandlerFunc, sharedKey string) http.HandlerFunc {
|
|||
}
|
||||
}
|
||||
|
||||
// validateSignature ensures the validity of the redirect url by comparing the hmac
|
||||
// digest, and ensuring that the included timestamp is fresh
|
||||
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
||||
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
|
||||
return false
|
||||
|
@ -82,14 +92,15 @@ func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
|||
return false
|
||||
}
|
||||
tm := time.Unix(i, 0)
|
||||
ttl := 5 * time.Minute
|
||||
if time.Now().Sub(tm) > ttl {
|
||||
if time.Now().Sub(tm) > defaultSignatureValidityDuration {
|
||||
return false
|
||||
}
|
||||
localSig := redirectURLSignature(redirectURI, tm, secret)
|
||||
return hmac.Equal(requestSig, localSig)
|
||||
}
|
||||
|
||||
// redirectURLSignature generates a hmac digest from a
|
||||
// redirect url, a timestamp, and a secret.
|
||||
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write([]byte(rawRedirect))
|
||||
|
|
85
authenticate/middleware_test.go
Normal file
85
authenticate/middleware_test.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package authenticate
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_validRedirectURI(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uri string
|
||||
rootDomains []string
|
||||
want bool
|
||||
}{
|
||||
{"good url redirect", "https://example.com/redirect", []string{"example.com"}, true},
|
||||
{"bad domain", "https://example.com/redirect", []string{"notexample.com"}, false},
|
||||
{"malformed url", "^example.com/redirect", []string{"notexample.com"}, false},
|
||||
{"empty domain list", "https://example.com/redirect", []string{}, false},
|
||||
{"empty domain", "https://example.com/redirect", []string{""}, false},
|
||||
{"empty url", "", []string{"example.com"}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := validRedirectURI(tt.uri, tt.rootDomains); got != tt.want {
|
||||
t.Errorf("validRedirectURI() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_validSignature(t *testing.T) {
|
||||
goodUrl := "https://example.com/redirect"
|
||||
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
|
||||
now := fmt.Sprint(time.Now().Unix())
|
||||
rawSig := redirectURLSignature(goodUrl, time.Now(), secretA)
|
||||
sig := base64.URLEncoding.EncodeToString(rawSig)
|
||||
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
redirectURI string
|
||||
sigVal string
|
||||
timestamp string
|
||||
secret string
|
||||
want bool
|
||||
}{
|
||||
{"good signature", goodUrl, string(sig), now, secretA, true},
|
||||
{"empty redirect url", "", string(sig), now, secretA, false},
|
||||
{"bad redirect url", "https://google.com^", string(sig), now, secretA, false},
|
||||
{"malformed signature", goodUrl, string(sig + "^"), now, "&*&@**($&#(", false},
|
||||
{"malformed timestamp", goodUrl, string(sig), now + "^", secretA, false},
|
||||
{"stale timestamp", goodUrl, string(sig), staleTime, secretA, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := validSignature(tt.redirectURI, tt.sigVal, tt.timestamp, tt.secret); got != tt.want {
|
||||
t.Errorf("validSignature() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_redirectURLSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawRedirect string
|
||||
timestamp time.Time
|
||||
secret string
|
||||
want string
|
||||
}{
|
||||
{"good signature", "https://example.com/redirect", time.Unix(1546797901, 0), "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A=", "GIDyWKjrG_7MwXwIq1o51f2pDT_rH9aLHdsHxSBEwy8="},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := redirectURLSignature(tt.rawRedirect, tt.timestamp, tt.secret)
|
||||
out := base64.URLEncoding.EncodeToString(got)
|
||||
if out != tt.want {
|
||||
t.Errorf("redirectURLSignature() = %v, want %v", tt.want, out)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue