pomerium/proxy/handlers_test.go
Bobby DeSimone 805f0198d2
authenticate: add tests, fix signout (#45)
- authenticate: a bug where sign out failed to revoke the remote session
- docs: add code coverage to readme
- authenticate: Rename shorthand receiver variable name
- authenticate: consolidate sign in
2019-02-14 00:01:50 -08:00

477 lines
16 KiB
Go

package proxy
import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"time"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version"
"github.com/pomerium/pomerium/proxy/authenticator"
)
type mockCipher struct{}
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Marshal(s interface{}) (string, error) { return "ok", nil }
func (a mockCipher) Unmarshal(s string, i interface{}) error {
if string(s) == "unmarshal error" || string(s) == "error" {
return errors.New("error")
}
return nil
}
func TestProxy_RobotsTxt(t *testing.T) {
proxy := Proxy{}
req := httptest.NewRequest("GET", "/robots.txt", nil)
rr := httptest.NewRecorder()
proxy.RobotsTxt(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
}
expected := fmt.Sprintf("User-agent: *\nDisallow: /")
if rr.Body.String() != expected {
t.Errorf("handler returned wrong body: got %v want %v", rr.Body.String(), expected)
}
}
func TestProxy_GetRedirectURL(t *testing.T) {
tests := []struct {
name string
host string
want *url.URL
}{
{"google", "google.com", &url.URL{Scheme: "https", Host: "google.com", Path: "/.pomerium/callback"}},
{"pomerium", "pomerium.io", &url.URL{Scheme: "https", Host: "pomerium.io", Path: "/.pomerium/callback"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Proxy{redirectURL: &url.URL{Path: "/.pomerium/callback"}}
if got := p.GetRedirectURL(tt.host); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Proxy.GetRedirectURL() = %v, want %v", got, tt.want)
}
})
}
}
func TestProxy_signRedirectURL(t *testing.T) {
tests := []struct {
name string
rawRedirect string
timestamp time.Time
want string
}{
{"pomerium", "https://pomerium.io/.pomerium/callback", fixedDate, "wq3rAjRGN96RXS8TAzH-uxQTD0XgY_8ZYEKMiOLD5P4="},
{"google", "https://google.com/.pomerium/callback", fixedDate, "7EYHZObq167CuyuPm5CqOtkU4zg5dFeUCs7W7QOrgNQ="},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Proxy{}
if got := p.signRedirectURL(tt.rawRedirect, tt.timestamp); got != tt.want {
t.Errorf("Proxy.signRedirectURL() = %v, want %v", got, tt.want)
}
})
}
}
func TestProxy_GetSignOutURL(t *testing.T) {
tests := []struct {
name string
authenticate string
redirect string
wantPrefix string
}{
{"without scheme", "auth.corp.pomerium.io", "hello.corp.pomerium.io", "https://auth.corp.pomerium.io/sign_out?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io"},
{"with scheme", "https://auth.corp.pomerium.io", "https://hello.corp.pomerium.io", "https://auth.corp.pomerium.io/sign_out?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
authenticateURL, _ := urlParse(tt.authenticate)
redirectURL, _ := urlParse(tt.redirect)
p := &Proxy{}
// signature is ignored as it is tested above. Avoids testing time.Now
if got := p.GetSignOutURL(authenticateURL, redirectURL); !strings.HasPrefix(got.String(), tt.wantPrefix) {
t.Errorf("Proxy.GetSignOutURL() = %v, wantPrefix %v", got.String(), tt.wantPrefix)
}
})
}
}
func TestProxy_GetSignInURL(t *testing.T) {
tests := []struct {
name string
authenticate string
redirect string
state string
wantPrefix string
}{
{"without scheme", "auth.corp.pomerium.io", "hello.corp.pomerium.io", "example_state", "https://auth.corp.pomerium.io/sign_in?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io&response_type=code&shared_secret=shared-secret"},
{"with scheme", "https://auth.corp.pomerium.io", "https://hello.corp.pomerium.io", "example_state", "https://auth.corp.pomerium.io/sign_in?redirect_uri=https%3A%2F%2Fhello.corp.pomerium.io&response_type=code&shared_secret=shared-secret"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Proxy{SharedKey: "shared-secret"}
authenticateURL, _ := urlParse(tt.authenticate)
redirectURL, _ := urlParse(tt.redirect)
if got := p.GetSignInURL(authenticateURL, redirectURL, tt.state); !strings.HasPrefix(got.String(), tt.wantPrefix) {
t.Errorf("Proxy.GetSignOutURL() = %v, wantPrefix %v", got.String(), tt.wantPrefix)
}
})
}
}
func TestProxy_Favicon(t *testing.T) {
proxy, err := New(testOptions())
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("GET", "/favicon.ico", nil)
rr := httptest.NewRecorder()
proxy.Favicon(rr, req)
if status := rr.Code; status != http.StatusNotFound {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusNotFound)
}
// todo(bdd) : good way of mocking auth then serving a simple favicon?
}
func TestProxy_Signout(t *testing.T) {
proxy, err := New(testOptions())
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("GET", "/.pomerium/sign_out", nil)
rr := httptest.NewRecorder()
proxy.SignOut(rr, req)
if status := rr.Code; status != http.StatusFound {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
}
// todo(bdd) : good way of mocking auth then serving a simple favicon?
}
func TestProxy_OAuthStart(t *testing.T) {
proxy, err := New(testOptions())
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("GET", "/oauth-start", nil)
rr := httptest.NewRecorder()
proxy.OAuthStart(rr, req)
// expect oauth redirect
if status := rr.Code; status != http.StatusFound {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound)
}
// expected url
expected := `<a href="https://sso-auth.corp.beyondperimeter.com/sign_in`
body := rr.Body.String()
if !strings.HasPrefix(body, expected) {
t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
}
}
func TestProxy_Handler(t *testing.T) {
proxy, err := New(testOptions())
if err != nil {
t.Fatal(err)
}
h := proxy.Handler()
if h == nil {
t.Error("handler cannot be nil")
}
mux := http.NewServeMux()
mux.Handle("/", h)
req := httptest.NewRequest("GET", "/ping", nil)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
expected := version.UserAgent()
body := rr.Body.String()
if body != expected {
t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
}
}
func TestProxy_OAuthCallback(t *testing.T) {
//todo(bdd): test malformed requests
// https://github.com/golang/go/blob/master/src/net/http/request_test.go#L110
normalSession := sessions.MockSessionStore{
Session: &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(-10 * time.Second),
},
}
normalAuth := authenticator.MockAuthenticate{
RedeemResponse: &authenticator.RedeemResponse{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
Expiry: time.Now().Add(10 * time.Second),
},
}
normalCsrf := sessions.MockCSRFStore{
ResponseCSRF: "ok",
GetError: nil,
Cookie: &http.Cookie{
Name: "something_csrf",
Value: "csrf_state",
}}
tests := []struct {
name string
csrf sessions.MockCSRFStore
session sessions.MockSessionStore
authenticator authenticator.MockAuthenticate
params map[string]string
wantCode int
}{
{"good", normalCsrf, normalSession, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusFound},
{"error", normalCsrf, normalSession, normalAuth, map[string]string{"error": "some error"}, http.StatusForbidden},
{"code err", normalCsrf, normalSession, authenticator.MockAuthenticate{RedeemError: errors.New("error")}, map[string]string{"code": "error"}, http.StatusInternalServerError},
{"state err", normalCsrf, normalSession, normalAuth, map[string]string{"code": "code", "state": "error"}, http.StatusInternalServerError},
{"csrf err", sessions.MockCSRFStore{GetError: errors.New("error")}, normalSession, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusBadRequest},
{"unmarshal err", sessions.MockCSRFStore{
Cookie: &http.Cookie{
Name: "something_csrf",
Value: "unmarshal error",
},
}, normalSession, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
{"encrypted state != CSRF", normalCsrf, normalSession, normalAuth, map[string]string{"code": "code", "state": "csrf_state"}, http.StatusBadRequest},
{"session save err", normalCsrf, sessions.MockSessionStore{SaveError: errors.New("error")}, normalAuth, map[string]string{"code": "code", "state": "state"}, http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proxy, err := New(testOptions())
if err != nil {
t.Fatal(err)
}
proxy.sessionStore = tt.session
proxy.csrfStore = tt.csrf
proxy.AuthenticateClient = tt.authenticator
proxy.cipher = mockCipher{}
// proxy.Csrf
req := httptest.NewRequest(http.MethodPost, "/.pomerium/callback", nil)
q := req.URL.Query()
for k, v := range tt.params {
q.Add(k, v)
}
req.URL.RawQuery = q.Encode()
w := httptest.NewRecorder()
proxy.OAuthCallback(w, req)
if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
}
})
}
}
func Test_extendDeadline(t *testing.T) {
tests := []struct {
name string
ttl time.Duration
want time.Time
}{
{"good", time.Second, time.Now().Add(time.Second).Truncate(time.Second)},
{"test nanoseconds truncated", 500 * time.Nanosecond, time.Now().Truncate(time.Second)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := extendDeadline(tt.ttl); !reflect.DeepEqual(got, tt.want) {
t.Errorf("extendDeadline() = %v, want %v", got, tt.want)
}
})
}
}
func TestProxy_router(t *testing.T) {
tests := []struct {
name string
host string
mux map[string]string
route http.Handler
wantOk bool
}{
{"good corp", "https://corp.example.com", map[string]string{"corp.example.com": "example.com"}, nil, true},
{"good with slash", "https://corp.example.com/", map[string]string{"corp.example.com": "example.com"}, nil, true},
{"good with path", "https://corp.example.com/123", map[string]string{"corp.example.com": "example.com"}, nil, true},
{"multiple", "https://corp.example.com/", map[string]string{"corp.unrelated.com": "unrelated.com", "corp.example.com": "example.com"}, nil, true},
{"bad corp", "https://notcorp.example.com/123", map[string]string{"corp.example.com": "example.com"}, nil, false},
{"bad sub-sub", "https://notcorp.corp.example.com/123", map[string]string{"corp.example.com": "example.com"}, nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := testOptions()
opts.Routes = tt.mux
p, err := New(opts)
if err != nil {
t.Fatal(err)
}
p.AuthenticateClient = authenticator.MockAuthenticate{}
p.cipher = mockCipher{}
req := httptest.NewRequest("GET", tt.host, nil)
_, ok := p.router(req)
if ok != tt.wantOk {
t.Errorf("Proxy.router() ok = %v, want %v", ok, tt.wantOk)
}
})
}
}
func TestProxy_Proxy(t *testing.T) {
goodSession := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}
expiredLifetime := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(-10 * time.Second),
}
tests := []struct {
name string
host string
session sessions.SessionStore
authenticator authenticator.Authenticator
wantStatus int
}{
// weirdly, we want 503 here because that means proxy is trying to route a domain (example.com) that we dont control. Weird. I know.
{"good", "https://corp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusServiceUnavailable},
{"unexpected error", "https://corp.example.com/test", sessions.MockSessionStore{LoadError: errors.New("ok")}, authenticator.MockAuthenticate{}, http.StatusInternalServerError},
// redirect to start auth process
{"expired lifetime", "https://corp.example.com/test", sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, http.StatusFound},
{"unknown host", "https://notcorp.example.com/test", sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, http.StatusNotFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := testOptions()
p, err := New(opts)
if err != nil {
t.Fatal(err)
}
p.cipher = mockCipher{}
p.sessionStore = tt.session
p.AuthenticateClient = tt.authenticator
r := httptest.NewRequest("GET", tt.host, nil)
w := httptest.NewRecorder()
p.Proxy(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("handler returned wrong status code: got %v want %v \n body %s", status, tt.wantStatus, w.Body.String())
}
})
}
}
func TestProxy_Authenticate(t *testing.T) {
goodSession := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(10 * time.Second),
ValidDeadline: time.Now().Add(10 * time.Second),
}
expiredLifetime := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(-10 * time.Second),
}
expiredDeadline := &sessions.SessionState{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
LifetimeDeadline: time.Now().Add(10 * time.Second),
RefreshDeadline: time.Now().Add(-10 * time.Second),
}
tests := []struct {
name string
host string
mux map[string]string
session sessions.SessionStore
authenticator authenticator.Authenticator
wantErr bool
}{
{"cannot save session",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: goodSession, SaveError: errors.New("error")},
authenticator.MockAuthenticate{}, true},
{"cannot load session",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{LoadError: errors.New("error")}, authenticator.MockAuthenticate{}, true},
{"expired lifetime",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: expiredLifetime}, authenticator.MockAuthenticate{}, true},
{"expired session",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: expiredDeadline}, authenticator.MockAuthenticate{}, false},
{"bad refresh authenticator",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{
Session: expiredDeadline,
},
authenticator.MockAuthenticate{RefreshError: errors.New("error")},
true},
{"good",
"https://corp.example.com/",
map[string]string{"corp.example.com": "example.com"},
sessions.MockSessionStore{Session: goodSession}, authenticator.MockAuthenticate{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := testOptions()
opts.Routes = tt.mux
p, err := New(opts)
if err != nil {
t.Fatal(err)
}
p.sessionStore = tt.session
p.AuthenticateClient = tt.authenticator
p.cipher = mockCipher{}
r := httptest.NewRequest("GET", tt.host, nil)
w := httptest.NewRecorder()
if err := p.Authenticate(w, r); (err != nil) != tt.wantErr {
t.Errorf("Proxy.Authenticate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}