mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 15:47:36 +02:00
Improve test coverage. (#8)
* Improve test coverage. * Remove unused http status code argument from SignInPageMethod. * Removed log package in internal packages. * Add test to check https scheme is used for authorization url. * Add unit tests for global logging package.
This commit is contained in:
parent
5a75ace403
commit
56c89e8653
14 changed files with 478 additions and 105 deletions
|
@ -92,14 +92,12 @@ func (o *Options) Validate() error {
|
||||||
if o.SharedKey == "" {
|
if o.SharedKey == "" {
|
||||||
return errors.New("missing setting: shared secret")
|
return errors.New("missing setting: shared secret")
|
||||||
}
|
}
|
||||||
|
|
||||||
decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
|
decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("authenticate options: cookie secret invalid"+
|
return fmt.Errorf("authenticate options: cookie secret invalid"+
|
||||||
"must be a base64-encoded, 256 bit key e.g. `head -c32 /dev/urandom | base64`"+
|
"must be a base64-encoded, 256 bit key e.g. `head -c32 /dev/urandom | base64`"+
|
||||||
"got %q", err)
|
"got %q", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
validCookieSecretLength := false
|
validCookieSecretLength := false
|
||||||
for _, i := range []int{32, 64} {
|
for _, i := range []int{32, 64} {
|
||||||
if len(decodedCookieSecret) == i {
|
if len(decodedCookieSecret) == i {
|
||||||
|
@ -108,11 +106,8 @@ func (o *Options) Validate() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !validCookieSecretLength {
|
if !validCookieSecretLength {
|
||||||
return fmt.Errorf("authenticate options: invalid cookie secret strength want 32 to 64 bytes, got %d bytes", len(decodedCookieSecret))
|
return fmt.Errorf("authenticate options: invalid cookie secret strength want"+
|
||||||
}
|
" 32 to 64 bytes, got %d bytes", len(decodedCookieSecret))
|
||||||
|
|
||||||
if o.CookieRefresh >= o.CookieExpire {
|
|
||||||
return fmt.Errorf("cookie_refresh (%s) must be less than cookie_expire (%s)", o.CookieRefresh.String(), o.CookieExpire.String())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
129
authenticate/authenticate_test.go
Normal file
129
authenticate/authenticate_test.go
Normal file
|
@ -0,0 +1,129 @@
|
||||||
|
package authenticate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testOptions() *Options {
|
||||||
|
redirectURL, _ := url.Parse("https://example.com/oauth2/callback")
|
||||||
|
return &Options{
|
||||||
|
ProxyRootDomains: []string{"example.com"},
|
||||||
|
AllowedDomains: []string{"example.com"},
|
||||||
|
RedirectURL: redirectURL,
|
||||||
|
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
ClientSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||||
|
CookieSecret: "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=",
|
||||||
|
CookieRefresh: time.Duration(1) * time.Hour,
|
||||||
|
SessionLifetimeTTL: time.Duration(720) * time.Hour,
|
||||||
|
CookieExpire: time.Duration(168) * time.Hour,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOptions_Validate(t *testing.T) {
|
||||||
|
good := testOptions()
|
||||||
|
badRedirectURL := testOptions()
|
||||||
|
badRedirectURL.RedirectURL = nil
|
||||||
|
malformedRedirectURL := testOptions()
|
||||||
|
redirectURL, _ := url.Parse("https://example.com/oauth3/callback")
|
||||||
|
malformedRedirectURL.RedirectURL = redirectURL
|
||||||
|
emptyClientID := testOptions()
|
||||||
|
emptyClientID.ClientID = ""
|
||||||
|
emptyClientSecret := testOptions()
|
||||||
|
emptyClientSecret.ClientSecret = ""
|
||||||
|
allowedDomains := testOptions()
|
||||||
|
allowedDomains.AllowedDomains = nil
|
||||||
|
proxyRootDomains := testOptions()
|
||||||
|
proxyRootDomains.ProxyRootDomains = nil
|
||||||
|
emptyCookieSecret := testOptions()
|
||||||
|
emptyCookieSecret.CookieSecret = ""
|
||||||
|
invalidCookieSecret := testOptions()
|
||||||
|
invalidCookieSecret.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw^"
|
||||||
|
shortCookieLength := testOptions()
|
||||||
|
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="
|
||||||
|
badSharedKey := testOptions()
|
||||||
|
badSharedKey.SharedKey = ""
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
o *Options
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"minimum options", good, false},
|
||||||
|
{"nil options", &Options{}, true},
|
||||||
|
{"bad redirect url", badRedirectURL, true},
|
||||||
|
{"malformed redirect url", malformedRedirectURL, true},
|
||||||
|
{"no cookie secret", emptyCookieSecret, true},
|
||||||
|
{"invalid cookie secret", invalidCookieSecret, true},
|
||||||
|
{"short cookie secret", shortCookieLength, true},
|
||||||
|
{"no shared secret", badSharedKey, true},
|
||||||
|
{"no client id", emptyClientID, true},
|
||||||
|
{"no client secret", emptyClientSecret, true},
|
||||||
|
{"empty allowed domains", allowedDomains, true},
|
||||||
|
{"empty root domains", proxyRootDomains, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
o := tt.o
|
||||||
|
if err := o.Validate(); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOptionsFromEnvConfig(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
want *Options
|
||||||
|
envKey string
|
||||||
|
envValue string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"good default, no env settings", defaultOptions, "", "", false},
|
||||||
|
{"bad url", nil, "REDIRECT_URL", "%.rjlw", true},
|
||||||
|
{"good duration", defaultOptions, "COOKIE_EXPIRE", "1m", false},
|
||||||
|
{"bad duration", nil, "COOKIE_EXPIRE", "1sm", true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.envKey != "" {
|
||||||
|
os.Setenv(tt.envKey, tt.envValue)
|
||||||
|
defer os.Unsetenv(tt.envKey)
|
||||||
|
}
|
||||||
|
got, err := OptionsFromEnvConfig()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("OptionsFromEnvConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("OptionsFromEnvConfig() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_dotPrependDomains(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
d []string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{"single domain", []string{"google.com"}, []string{".google.com"}},
|
||||||
|
{"multiple domain", []string{"google.com", "bing.com"}, []string{".google.com", ".bing.com"}},
|
||||||
|
{"empty", []string{""}, []string{""}},
|
||||||
|
{"nested subdomain", []string{"some.really.long.domain.com"}, []string{".some.really.long.domain.com"}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := dotPrependDomains(tt.d); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("dotPrependDomains() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -81,10 +81,9 @@ func (p *Authenticator) PingPage(rw http.ResponseWriter, req *http.Request) {
|
||||||
fmt.Fprintf(rw, "OK")
|
fmt.Fprintf(rw, "OK")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignInPage directs the user to the sign in page
|
// SignInPage directs the user to the sign in page. Takes a `redirect_uri` param.
|
||||||
func (p *Authenticator) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
|
func (p *Authenticator) SignInPage(rw http.ResponseWriter, req *http.Request) {
|
||||||
requestLog := log.WithRequest(req, "authenticate.SignInPage")
|
requestLog := log.WithRequest(req, "authenticate.SignInPage")
|
||||||
rw.WriteHeader(code)
|
|
||||||
redirectURL := p.RedirectURL.ResolveReference(req.URL)
|
redirectURL := p.RedirectURL.ResolveReference(req.URL)
|
||||||
// validateRedirectURI middleware already ensures that this is a valid URL
|
// validateRedirectURI middleware already ensures that this is a valid URL
|
||||||
destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri"))
|
destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri"))
|
||||||
|
@ -107,12 +106,12 @@ func (p *Authenticator) SignInPage(rw http.ResponseWriter, req *http.Request, co
|
||||||
Str("Destination", destinationURL.Host).
|
Str("Destination", destinationURL.Host).
|
||||||
Str("AllowedDomains", strings.Join(p.AllowedDomains, ", ")).
|
Str("AllowedDomains", strings.Join(p.AllowedDomains, ", ")).
|
||||||
Msg("authenticate.SignInPage")
|
Msg("authenticate.SignInPage")
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
p.templates.ExecuteTemplate(rw, "sign_in.html", t)
|
p.templates.ExecuteTemplate(rw, "sign_in.html", t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Authenticator) authenticate(rw http.ResponseWriter, req *http.Request) (*sessions.SessionState, error) {
|
func (p *Authenticator) authenticate(rw http.ResponseWriter, req *http.Request) (*sessions.SessionState, error) {
|
||||||
requestLog := log.WithRequest(req, "authenticate.authenticate")
|
requestLog := log.WithRequest(req, "authenticate.authenticate")
|
||||||
|
|
||||||
session, err := p.sessionStore.LoadSession(req)
|
session, err := p.sessionStore.LoadSession(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("authenticate.authenticate")
|
log.Error().Err(err).Msg("authenticate.authenticate")
|
||||||
|
@ -169,7 +168,6 @@ func (p *Authenticator) authenticate(rw http.ResponseWriter, req *http.Request)
|
||||||
requestLog.Error().Msg("invalid email user")
|
requestLog.Error().Msg("invalid email user")
|
||||||
return nil, httputil.ErrUserNotAuthorized
|
return nil, httputil.ErrUserNotAuthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -193,11 +191,11 @@ func (p *Authenticator) SignIn(rw http.ResponseWriter, req *http.Request) {
|
||||||
p.ProxyOAuthRedirect(rw, req, session)
|
p.ProxyOAuthRedirect(rw, req, session)
|
||||||
case http.ErrNoCookie:
|
case http.ErrNoCookie:
|
||||||
log.Error().Err(err).Msg("authenticate.SignIn : err no cookie")
|
log.Error().Err(err).Msg("authenticate.SignIn : err no cookie")
|
||||||
p.SignInPage(rw, req, http.StatusOK)
|
p.SignInPage(rw, req)
|
||||||
case sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
|
case sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
|
||||||
log.Error().Err(err).Msg("authenticate.SignIn : invalid cookie cookie")
|
log.Error().Err(err).Msg("authenticate.SignIn : invalid cookie cookie")
|
||||||
p.sessionStore.ClearSession(rw, req)
|
p.sessionStore.ClearSession(rw, req)
|
||||||
p.SignInPage(rw, req, http.StatusOK)
|
p.SignInPage(rw, req)
|
||||||
default:
|
default:
|
||||||
log.Error().Err(err).Msg("authenticate.SignIn : unknown error cookie")
|
log.Error().Err(err).Msg("authenticate.SignIn : unknown error cookie")
|
||||||
httputil.ErrorResponse(rw, req, err.Error(), httputil.CodeForError(err))
|
httputil.ErrorResponse(rw, req, err.Error(), httputil.CodeForError(err))
|
||||||
|
|
98
authenticate/handlers_test.go
Normal file
98
authenticate/handlers_test.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
package authenticate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/authenticate/providers"
|
||||||
|
"github.com/pomerium/pomerium/internal/templates"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testAuthenticator() *Authenticator {
|
||||||
|
var auth Authenticator
|
||||||
|
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
|
||||||
|
auth.SharedKey = "IzY7MOZwzfOkmELXgozHDKTxoT3nOYhwkcmUVINsRww="
|
||||||
|
auth.AllowedDomains = []string{"*"}
|
||||||
|
auth.ProxyRootDomains = []string{"example.com"}
|
||||||
|
auth.templates = templates.New()
|
||||||
|
auth.provider = providers.NewTestProvider(auth.RedirectURL)
|
||||||
|
return &auth
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticator_PingPage(t *testing.T) {
|
||||||
|
auth := testAuthenticator()
|
||||||
|
req, err := http.NewRequest("GET", "/ping", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler := http.HandlerFunc(auth.PingPage)
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
if status := rr.Code; status != http.StatusOK {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
|
||||||
|
}
|
||||||
|
expected := "OK"
|
||||||
|
if rr.Body.String() != expected {
|
||||||
|
t.Errorf("handler returned wrong body: got %v want %v", rr.Body.String(), expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticator_RobotsTxt(t *testing.T) {
|
||||||
|
auth := testAuthenticator()
|
||||||
|
req, err := http.NewRequest("GET", "/robots.txt", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler := http.HandlerFunc(auth.RobotsTxt)
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
if status := rr.Code; status != http.StatusOK {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
|
||||||
|
}
|
||||||
|
expected := fmt.Sprintf("User-agent: *\nDisallow: /")
|
||||||
|
if rr.Body.String() != expected {
|
||||||
|
t.Errorf("handler returned wrong body: got %v want %v", rr.Body.String(), expected)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticator_SignInPage(t *testing.T) {
|
||||||
|
auth := testAuthenticator()
|
||||||
|
v := url.Values{}
|
||||||
|
v.Set("request_uri", "this-is-a-test-uri")
|
||||||
|
url := fmt.Sprintf("/signin?%s", v.Encode())
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler := http.HandlerFunc(auth.SignInPage)
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
if status := rr.Code; status != http.StatusOK {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
|
||||||
|
}
|
||||||
|
body := []byte(rr.Body.String())
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
value string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"provider name", auth.provider.Data().ProviderName, true},
|
||||||
|
{"destination url", v.Encode(), true},
|
||||||
|
{"shouldn't be found", "this string should not be in the body", false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := bytes.Contains(body, []byte(tt.value)); got != tt.want {
|
||||||
|
t.Errorf("handler body missing expected value %v", tt.value)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -14,6 +14,8 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var defaultSignatureValidityDuration = 5 * time.Minute
|
||||||
|
|
||||||
// validateRedirectURI checks the redirect uri in the query parameters and ensures that
|
// validateRedirectURI checks the redirect uri in the query parameters and ensures that
|
||||||
// the url's domain is one in the list of proxy root domains.
|
// the url's domain is one in the list of proxy root domains.
|
||||||
func validateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
|
func validateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.HandlerFunc {
|
||||||
|
@ -34,11 +36,17 @@ func validateRedirectURI(f http.HandlerFunc, proxyRootDomains []string) http.Han
|
||||||
}
|
}
|
||||||
|
|
||||||
func validRedirectURI(uri string, rootDomains []string) bool {
|
func validRedirectURI(uri string, rootDomains []string) bool {
|
||||||
|
if uri == "" || len(rootDomains) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
redirectURL, err := url.Parse(uri)
|
redirectURL, err := url.Parse(uri)
|
||||||
if uri == "" || err != nil || redirectURL.Host == "" {
|
if err != nil || redirectURL.Host == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
for _, domain := range rootDomains {
|
for _, domain := range rootDomains {
|
||||||
|
if domain == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
if strings.HasSuffix(redirectURL.Hostname(), domain) {
|
if strings.HasSuffix(redirectURL.Hostname(), domain) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -65,6 +73,8 @@ func validateSignature(f http.HandlerFunc, sharedKey string) http.HandlerFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateSignature ensures the validity of the redirect url by comparing the hmac
|
||||||
|
// digest, and ensuring that the included timestamp is fresh
|
||||||
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
||||||
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
|
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
|
||||||
return false
|
return false
|
||||||
|
@ -82,14 +92,15 @@ func validSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
tm := time.Unix(i, 0)
|
tm := time.Unix(i, 0)
|
||||||
ttl := 5 * time.Minute
|
if time.Now().Sub(tm) > defaultSignatureValidityDuration {
|
||||||
if time.Now().Sub(tm) > ttl {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
localSig := redirectURLSignature(redirectURI, tm, secret)
|
localSig := redirectURLSignature(redirectURI, tm, secret)
|
||||||
return hmac.Equal(requestSig, localSig)
|
return hmac.Equal(requestSig, localSig)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// redirectURLSignature generates a hmac digest from a
|
||||||
|
// redirect url, a timestamp, and a secret.
|
||||||
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
|
func redirectURLSignature(rawRedirect string, timestamp time.Time, secret string) []byte {
|
||||||
h := hmac.New(sha256.New, []byte(secret))
|
h := hmac.New(sha256.New, []byte(secret))
|
||||||
h.Write([]byte(rawRedirect))
|
h.Write([]byte(rawRedirect))
|
||||||
|
|
85
authenticate/middleware_test.go
Normal file
85
authenticate/middleware_test.go
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
package authenticate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_validRedirectURI(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
uri string
|
||||||
|
rootDomains []string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"good url redirect", "https://example.com/redirect", []string{"example.com"}, true},
|
||||||
|
{"bad domain", "https://example.com/redirect", []string{"notexample.com"}, false},
|
||||||
|
{"malformed url", "^example.com/redirect", []string{"notexample.com"}, false},
|
||||||
|
{"empty domain list", "https://example.com/redirect", []string{}, false},
|
||||||
|
{"empty domain", "https://example.com/redirect", []string{""}, false},
|
||||||
|
{"empty url", "", []string{"example.com"}, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := validRedirectURI(tt.uri, tt.rootDomains); got != tt.want {
|
||||||
|
t.Errorf("validRedirectURI() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_validSignature(t *testing.T) {
|
||||||
|
goodUrl := "https://example.com/redirect"
|
||||||
|
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
|
||||||
|
now := fmt.Sprint(time.Now().Unix())
|
||||||
|
rawSig := redirectURLSignature(goodUrl, time.Now(), secretA)
|
||||||
|
sig := base64.URLEncoding.EncodeToString(rawSig)
|
||||||
|
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
redirectURI string
|
||||||
|
sigVal string
|
||||||
|
timestamp string
|
||||||
|
secret string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"good signature", goodUrl, string(sig), now, secretA, true},
|
||||||
|
{"empty redirect url", "", string(sig), now, secretA, false},
|
||||||
|
{"bad redirect url", "https://google.com^", string(sig), now, secretA, false},
|
||||||
|
{"malformed signature", goodUrl, string(sig + "^"), now, "&*&@**($&#(", false},
|
||||||
|
{"malformed timestamp", goodUrl, string(sig), now + "^", secretA, false},
|
||||||
|
{"stale timestamp", goodUrl, string(sig), staleTime, secretA, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := validSignature(tt.redirectURI, tt.sigVal, tt.timestamp, tt.secret); got != tt.want {
|
||||||
|
t.Errorf("validSignature() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_redirectURLSignature(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rawRedirect string
|
||||||
|
timestamp time.Time
|
||||||
|
secret string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"good signature", "https://example.com/redirect", time.Unix(1546797901, 0), "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A=", "GIDyWKjrG_7MwXwIq1o51f2pDT_rH9aLHdsHxSBEwy8="},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := redirectURLSignature(tt.rawRedirect, tt.timestamp, tt.secret)
|
||||||
|
out := base64.URLEncoding.EncodeToString(got)
|
||||||
|
if out != tt.want {
|
||||||
|
t.Errorf("redirectURLSignature() = %v, want %v", tt.want, out)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,36 +0,0 @@
|
||||||
|
|
||||||
## Generating random seeds
|
|
||||||
In order of preference:
|
|
||||||
- `head -c32 /dev/urandom | base64`
|
|
||||||
- `openssl rand -base64 32 | head -c 32 | base64`
|
|
||||||
## Encrypting data
|
|
||||||
|
|
||||||
TL;DR -- Nonce reuse is a problem. AEAD isn't a clear choice right now.
|
|
||||||
|
|
||||||
[Miscreant](https://github.com/miscreant/miscreant.go)
|
|
||||||
+ AES-GCM-SIV seems to have ideal properties
|
|
||||||
+ random nonces
|
|
||||||
- ~30% slower encryption
|
|
||||||
- [not maintained by a BigCo](https://github.com/miscreant/miscreant.go/graphs/contributors)
|
|
||||||
|
|
||||||
[nacl/secretbot](https://godoc.org/golang.org/x/crypto/nacl/secretbox)
|
|
||||||
+ Fast
|
|
||||||
+ XSalsa20 wutg Poly1305 MAC provides encryption and authentication together
|
|
||||||
+ A newer standard and may not be considered acceptable in environments that require high levels of review.
|
|
||||||
-/+ maintained as an [/x/ package](https://godoc.org/golang.org/x/crypto/nacl/secretbox)
|
|
||||||
- doesn't use the underlying cipher.AEAD api.
|
|
||||||
|
|
||||||
|
|
||||||
GCM with random nonces
|
|
||||||
+ Fastest
|
|
||||||
+ Go standard library, supported by google $
|
|
||||||
- Easy to get wrong
|
|
||||||
- IV reuse is a known weakness so keys must be rotated before birthday attack. [NIST SP 800-38D](http://csrc.nist.gov/publications/nistpubs/800-38D/SP-800-38D.pdf) recommends using the same key with random 96-bit nonces (the default nonce length) no more than 2^32 times
|
|
||||||
|
|
||||||
Further reading on tradeoffs:
|
|
||||||
- [Introducing Miscreant](https://tonyarcieri.com/introducing-miscreant-a-multi-language-misuse-resistant-encryption-library)
|
|
||||||
- [agl's post AES-GCM-SIV](https://www.imperialviolet.org/2017/05/14/aesgcmsiv.html)
|
|
||||||
- [x/crypto: add chacha20, xchacha20](https://github.com/golang/go/issues/24485s)
|
|
||||||
- [GCM cannot be used with random nonces](https://github.com/gtank/cryptopasta/issues/14s)
|
|
||||||
- [proposal: x/crypto/chacha20poly1305: add support for XChaCha20](https://github.com/golang/go/issues/23885)
|
|
||||||
- [kubernetes](https://kubernetes.io/docs/tasks/administer-cluster/encrypt-data/#providers)
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
"github.com/pomerium/pomerium/internal/templates"
|
"github.com/pomerium/pomerium/internal/templates"
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
)
|
)
|
||||||
|
@ -47,13 +46,6 @@ func ErrorResponse(rw http.ResponseWriter, req *http.Request, message string, co
|
||||||
writeJSONResponse(rw, code, response)
|
writeJSONResponse(rw, code, response)
|
||||||
} else {
|
} else {
|
||||||
title := http.StatusText(code)
|
title := http.StatusText(code)
|
||||||
|
|
||||||
log.Error().
|
|
||||||
Int("http-status", code).
|
|
||||||
Str("page-title", title).
|
|
||||||
Str("page-message", message).
|
|
||||||
Msg("authenticate/errors.ErrorResponse")
|
|
||||||
|
|
||||||
rw.WriteHeader(code)
|
rw.WriteHeader(code)
|
||||||
t := struct {
|
t := struct {
|
||||||
Code int
|
Code int
|
||||||
|
|
|
@ -2,8 +2,6 @@
|
||||||
package log // import "github.com/pomerium/pomerium/internal/log"
|
package log // import "github.com/pomerium/pomerium/internal/log"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
@ -18,11 +16,6 @@ func SetDebugMode() {
|
||||||
Logger = Logger.Output(zerolog.ConsoleWriter{Out: os.Stdout})
|
Logger = Logger.Output(zerolog.ConsoleWriter{Out: os.Stdout})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Output duplicates the global logger and sets w as its output.
|
|
||||||
func Output(w io.Writer) zerolog.Logger {
|
|
||||||
return Logger.Output(w)
|
|
||||||
}
|
|
||||||
|
|
||||||
// With creates a child logger with the field added to its context.
|
// With creates a child logger with the field added to its context.
|
||||||
func With() zerolog.Context {
|
func With() zerolog.Context {
|
||||||
return Logger.With()
|
return Logger.With()
|
||||||
|
@ -46,16 +39,6 @@ func Level(level zerolog.Level) zerolog.Logger {
|
||||||
return Logger.Level(level)
|
return Logger.Level(level)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sample returns a logger with the s sampler.
|
|
||||||
func Sample(s zerolog.Sampler) zerolog.Logger {
|
|
||||||
return Logger.Sample(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hook returns a logger with the h Hook.
|
|
||||||
func Hook(h zerolog.Hook) zerolog.Logger {
|
|
||||||
return Logger.Hook(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debug starts a new message with debug level.
|
// Debug starts a new message with debug level.
|
||||||
//
|
//
|
||||||
// You must call Msg on the returned event in order to send the event.
|
// You must call Msg on the returned event in order to send the event.
|
||||||
|
@ -126,9 +109,3 @@ func Print(v ...interface{}) {
|
||||||
func Printf(format string, v ...interface{}) {
|
func Printf(format string, v ...interface{}) {
|
||||||
Logger.Printf(format, v...)
|
Logger.Printf(format, v...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ctx returns the Logger associated with the ctx. If no logger
|
|
||||||
// is associated, a disabled logger is returned.
|
|
||||||
func Ctx(ctx context.Context) *zerolog.Logger {
|
|
||||||
return zerolog.Ctx(ctx)
|
|
||||||
}
|
|
||||||
|
|
133
internal/log/log_test.go
Normal file
133
internal/log/log_test.go
Normal file
|
@ -0,0 +1,133 @@
|
||||||
|
package log_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"flag"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setup would normally be an init() function, however, there seems
|
||||||
|
// to be something awry with the testing framework when we set the
|
||||||
|
// global Logger from an init()
|
||||||
|
func setup() {
|
||||||
|
// UNIX Time is faster and smaller than most timestamps
|
||||||
|
// If you set zerolog.TimeFieldFormat to an empty string,
|
||||||
|
// logs will write with UNIX time
|
||||||
|
zerolog.TimeFieldFormat = ""
|
||||||
|
// In order to always output a static time to stdout for these
|
||||||
|
// examples to pass, we need to override zerolog.TimestampFunc
|
||||||
|
// and log.Logger globals -- you would not normally need to do this
|
||||||
|
zerolog.TimestampFunc = func() time.Time {
|
||||||
|
return time.Date(2008, 1, 8, 17, 5, 05, 0, time.UTC)
|
||||||
|
}
|
||||||
|
log.Logger = zerolog.New(os.Stdout).With().Timestamp().Logger()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple logging example using the Print function in the log package
|
||||||
|
// Note that both Print and Printf are at the debug log level by default
|
||||||
|
func ExamplePrint() {
|
||||||
|
setup()
|
||||||
|
|
||||||
|
log.Print("hello world")
|
||||||
|
// Output: {"level":"debug","time":1199811905,"message":"hello world"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleWith() {
|
||||||
|
setup()
|
||||||
|
sublog := log.With().Str("foo", "bar").Logger()
|
||||||
|
sublog.Debug().Msg("hello world")
|
||||||
|
// Output: {"level":"debug","foo":"bar","time":1199811905,"message":"hello world"}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple logging example using the Printf function in the log package
|
||||||
|
func ExamplePrintf() {
|
||||||
|
setup()
|
||||||
|
|
||||||
|
log.Printf("hello %s", "world")
|
||||||
|
// Output: {"level":"debug","time":1199811905,"message":"hello world"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example of a log with no particular "level"
|
||||||
|
func ExampleLog() {
|
||||||
|
setup()
|
||||||
|
log.Log().Msg("hello world")
|
||||||
|
|
||||||
|
// Output: {"time":1199811905,"message":"hello world"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example of a log at a particular "level" (in this case, "debug")
|
||||||
|
func ExampleDebug() {
|
||||||
|
setup()
|
||||||
|
log.Debug().Msg("hello world")
|
||||||
|
|
||||||
|
// Output: {"level":"debug","time":1199811905,"message":"hello world"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example of a log at a particular "level" (in this case, "info")
|
||||||
|
func ExampleInfo() {
|
||||||
|
setup()
|
||||||
|
log.Info().Msg("hello world")
|
||||||
|
|
||||||
|
// Output: {"level":"info","time":1199811905,"message":"hello world"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example of a log at a particular "level" (in this case, "warn")
|
||||||
|
func ExampleWarn() {
|
||||||
|
setup()
|
||||||
|
log.Warn().Msg("hello world")
|
||||||
|
|
||||||
|
// Output: {"level":"warn","time":1199811905,"message":"hello world"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example of a log at a particular "level" (in this case, "error")
|
||||||
|
func ExampleError() {
|
||||||
|
setup()
|
||||||
|
log.Error().Msg("hello world")
|
||||||
|
|
||||||
|
// Output: {"level":"error","time":1199811905,"message":"hello world"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example of a log at a particular "level" (in this case, "fatal")
|
||||||
|
func ExampleFatal() {
|
||||||
|
setup()
|
||||||
|
err := errors.New("A repo man spends his life getting into tense situations")
|
||||||
|
service := "myservice"
|
||||||
|
|
||||||
|
log.Fatal().
|
||||||
|
Err(err).
|
||||||
|
Str("service", service).
|
||||||
|
Msg("Cannot start")
|
||||||
|
|
||||||
|
// Outputs: {"level":"fatal","time":1199811905,"error":"A repo man spends his life getting into tense situations","service":"myservice","message":"Cannot start myservice"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This example uses command-line flags to demonstrate various outputs
|
||||||
|
// depending on the chosen log level.
|
||||||
|
func Example() {
|
||||||
|
setup()
|
||||||
|
debug := flag.Bool("debug", false, "sets log level to debug")
|
||||||
|
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// Default level for this example is info, unless debug flag is present
|
||||||
|
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||||
|
if *debug {
|
||||||
|
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msg("This message appears only when log level set to Debug")
|
||||||
|
log.Info().Msg("This message appears when log level set to Debug or Info")
|
||||||
|
|
||||||
|
if e := log.Debug(); e.Enabled() {
|
||||||
|
// Compute log output only if enabled.
|
||||||
|
value := "bar"
|
||||||
|
e.Str("foo", value).Msg("some debug message")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output: {"level":"info","time":1199811905,"message":"This message appears when log level set to Debug or Info"}
|
||||||
|
}
|
|
@ -12,7 +12,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetHeaders ensures that every response includes some basic security headers
|
// SetHeaders ensures that every response includes some basic security headers
|
||||||
|
@ -56,10 +55,6 @@ func ValidateClientSecret(f http.HandlerFunc, sharedSecret string) http.HandlerF
|
||||||
}
|
}
|
||||||
|
|
||||||
if clientSecret != sharedSecret {
|
if clientSecret != sharedSecret {
|
||||||
log.Error().
|
|
||||||
Str("clientSecret", clientSecret).
|
|
||||||
Str("sharedSecret", sharedSecret).
|
|
||||||
Msg("middleware.ValidateClientSecret")
|
|
||||||
httputil.ErrorResponse(rw, req, "Invalid client secret", http.StatusUnauthorized)
|
httputil.ErrorResponse(rw, req, "Invalid client secret", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,11 +5,9 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/aead"
|
"github.com/pomerium/pomerium/internal/aead"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrInvalidSession is an error for invalid sessions.
|
// ErrInvalidSession is an error for invalid sessions.
|
||||||
|
@ -85,9 +83,6 @@ func (s *CookieStore) makeCookie(req *http.Request, name string, value string, e
|
||||||
domain = h
|
domain = h
|
||||||
}
|
}
|
||||||
if s.CookieDomain != "" {
|
if s.CookieDomain != "" {
|
||||||
if !strings.HasSuffix(domain, s.CookieDomain) {
|
|
||||||
log.Warn().Str("cookie-domain", s.CookieDomain).Msg("using configured cookie domain")
|
|
||||||
}
|
|
||||||
domain = s.CookieDomain
|
domain = s.CookieDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -145,7 +140,6 @@ func (s *CookieStore) LoadSession(req *http.Request) (*SessionState, error) {
|
||||||
}
|
}
|
||||||
session, err := UnmarshalSession(c.Value, s.CookieCipher)
|
session, err := UnmarshalSession(c.Value, s.CookieCipher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Str("remote-host", req.Host).Msg("error unmarshaling session")
|
|
||||||
return nil, ErrInvalidSession
|
return nil, ErrInvalidSession
|
||||||
}
|
}
|
||||||
return session, nil
|
return session, nil
|
||||||
|
|
|
@ -91,11 +91,9 @@ func (o *Options) Validate() error {
|
||||||
if o.CookieSecret == "" {
|
if o.CookieSecret == "" {
|
||||||
return errors.New("missing setting: cookie-secret")
|
return errors.New("missing setting: cookie-secret")
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.SharedKey == "" {
|
if o.SharedKey == "" {
|
||||||
return errors.New("missing setting: client-secret")
|
return errors.New("missing setting: client-secret")
|
||||||
}
|
}
|
||||||
|
|
||||||
decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
|
decodedCookieSecret, err := base64.StdEncoding.DecodeString(o.CookieSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("cookie secret is invalid (e.g. `head -c32 /dev/urandom | base64`) ")
|
return errors.New("cookie secret is invalid (e.g. `head -c32 /dev/urandom | base64`) ")
|
||||||
|
|
|
@ -140,6 +140,9 @@ func TestOptions_Validate(t *testing.T) {
|
||||||
badToRoute.Routes = map[string]string{"^": "example.com"}
|
badToRoute.Routes = map[string]string{"^": "example.com"}
|
||||||
badAuthURL := testOptions()
|
badAuthURL := testOptions()
|
||||||
badAuthURL.AuthenticateServiceURL = nil
|
badAuthURL.AuthenticateServiceURL = nil
|
||||||
|
authurl, _ := url.Parse("http://sso-auth.corp.beyondperimeter.com")
|
||||||
|
httpAuthURL := testOptions()
|
||||||
|
httpAuthURL.AuthenticateServiceURL = authurl
|
||||||
emptyCookieSecret := testOptions()
|
emptyCookieSecret := testOptions()
|
||||||
emptyCookieSecret.CookieSecret = ""
|
emptyCookieSecret.CookieSecret = ""
|
||||||
invalidCookieSecret := testOptions()
|
invalidCookieSecret := testOptions()
|
||||||
|
@ -157,14 +160,15 @@ func TestOptions_Validate(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{"good - minimum options", good, false},
|
{"good - minimum options", good, false},
|
||||||
|
|
||||||
{"bad - nil options", &Options{}, true},
|
{"nil options", &Options{}, true},
|
||||||
{"bad - from route", badFromRoute, true},
|
{"from route", badFromRoute, true},
|
||||||
{"bad - to route", badToRoute, true},
|
{"to route", badToRoute, true},
|
||||||
{"bad - auth service url", badAuthURL, true},
|
{"auth service url", badAuthURL, true},
|
||||||
{"bad - no cookie secret", emptyCookieSecret, true},
|
{"auth service url not https", httpAuthURL, true},
|
||||||
{"bad - invalid cookie secret", invalidCookieSecret, true},
|
{"no cookie secret", emptyCookieSecret, true},
|
||||||
{"bad - short cookie secret", shortCookieLength, true},
|
{"invalid cookie secret", invalidCookieSecret, true},
|
||||||
{"bad - no shared secret", badSharedKey, true},
|
{"short cookie secret", shortCookieLength, true},
|
||||||
|
{"no shared secret", badSharedKey, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue