pomerium/proxy/middleware_test.go
Caleb Doxsey d9a224a5e8
proxy: move properties to atomically updated state (#1280)
* authenticate: remove cookie options

* authenticate: remove shared key field

* authenticate: remove shared cipher property

* authenticate: move properties to separate state struct

* proxy: allow local state to be updated on configuration changes

* fix test

* return new connection

* use warn, collapse to single line

* address concerns, fix tests
2020-08-14 11:44:58 -06:00

163 lines
5.3 KiB
Go

package proxy
import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions"
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
)
// TestProxy_AuthenticateSession runs a request through jwtClaimMiddleware and
// AuthenticateSession for several session states and checks the HTTP status
// code recorded for each.
func TestProxy_AuthenticateSession(t *testing.T) {
	t.Parallel()

	// fn is the handler wrapped by the middleware under test; it replies with
	// 200 and the matching status text.
	fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		// The status must be written before the body: a WriteHeader issued
		// after the first write has no effect.
		w.WriteHeader(http.StatusOK)
		fmt.Fprint(w, http.StatusText(http.StatusOK))
	})

	tests := []struct {
		name              string
		refreshRespStatus int
		errOnFailure      bool
		session           sessions.SessionStore
		ctxError          error
		provider          identity.Authenticator
		encoder           encoding.MarshalUnmarshaler
		refreshURL        string

		wantStatus int
	}{
		{"good", 200, false, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, nil, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
		{"invalid session", 200, false, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound},
		{"expired", 200, false, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Stand-in refresh endpoint that answers with the status
			// configured for this case.
			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(tt.refreshRespStatus)
				fmt.Fprintln(w, "REFRESH GOOD")
			}))
			defer ts.Close()

			// A case may override the refresh URL; otherwise the local test
			// server is used.
			rURL := ts.URL
			if tt.refreshURL != "" {
				rURL = tt.refreshURL
			}

			a := Proxy{
				state: newAtomicProxyState(&proxyState{
					sharedKey:              "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
					cookieSecret:           []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
					authenticateURL:        uriParseHelper("https://authenticate.corp.example"),
					authenticateSigninURL:  uriParseHelper("https://authenticate.corp.example/sign_in"),
					authenticateRefreshURL: uriParseHelper(rURL),
					sessionStore:           tt.session,
					encoder:                tt.encoder,
				}),
			}

			r := httptest.NewRequest(http.MethodGet, "/", nil)
			// The load error is dropped on purpose: the error the middleware
			// observes is injected separately through tt.ctxError.
			state, _ := tt.session.LoadSession(r)
			ctx := sessions.NewContext(r.Context(), state, tt.ctxError)
			r = r.WithContext(ctx)
			r.Header.Set("Accept", "application/json")

			w := httptest.NewRecorder()
			got := a.jwtClaimMiddleware(false)(a.AuthenticateSession(fn))
			got.ServeHTTP(w, r)

			if status := w.Code; status != tt.wantStatus {
				t.Errorf("AuthenticateSession() status = %v, want %v\n%v", status, tt.wantStatus, w.Body.String())
			}
		})
	}
}
// Test_jwtClaimMiddleware sends a request carrying a signed session through
// jwtClaimMiddleware and checks that no request header is set for a
// configured claim ("missing") that the session does not contain.
func Test_jwtClaimMiddleware(t *testing.T) {
	claimHeaders := []string{"email", "groups", "missing"}
	sharedKey := "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="

	session := &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}

	// Setup failures are fatal: continuing with an unusable encoder or an
	// unsigned session would only produce misleading results further down.
	encoder, err := jws.NewHS256Signer([]byte(sharedKey), "https://authenticate.pomerium.example")
	if err != nil {
		t.Fatalf("failed to create signer: %s", err)
	}
	state, err := encoder.Marshal(session)
	if err != nil {
		t.Fatalf("failed to marshal state: %s", err)
	}

	a := Proxy{
		state: newAtomicProxyState(&proxyState{
			sharedKey:       sharedKey,
			cookieSecret:    []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
			encoder:         encoder,
			jwtClaimHeaders: claimHeaders,
		}),
	}

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	// Attach the signed session to the request context, with no context error.
	r := httptest.NewRequest(http.MethodGet, "/", nil)
	ctx := sessions.NewContext(r.Context(), string(state), nil)
	r = r.WithContext(ctx)
	w := httptest.NewRecorder()

	proxyHandler := a.jwtClaimMiddleware(true)(handler)
	proxyHandler.ServeHTTP(w, r)

	t.Run("missing claim", func(t *testing.T) {
		absentHeader := r.Header.Get("x-pomerium-claim-missing")
		if absentHeader != "" {
			t.Errorf("found claim that should not exist, got=%q", absentHeader)
		}
	})
}
// TestProxy_SetResponseHeaders wraps an echo handler with SetResponseHeaders
// and compares the echoed header lines against the expected text.
func TestProxy_SetResponseHeaders(t *testing.T) {
	t.Parallel()

	// fn writes every header found on the incoming request to the response
	// body as "lower-cased-name: value" lines.
	fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")

		// r.Header is a map, so the order across different keys is not
		// deterministic; the cases below use at most one header.
		var sb strings.Builder
		for k, v := range r.Header {
			k = strings.ToLower(k)
			for _, h := range v {
				fmt.Fprintf(&sb, "%v: %v\n", k, h)
			}
		}

		// The status must be written before the body: a WriteHeader issued
		// after the first write has no effect.
		w.WriteHeader(http.StatusOK)
		fmt.Fprint(w, sb.String())
	})

	tests := []struct {
		name        string
		setHeaders  map[string]string
		wantHeaders string
	}{
		{"good", map[string]string{"x-gonna": "give-it-to-ya"}, "x-gonna: give-it-to-ya\n"},
		{"nil", nil, ""},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			r := httptest.NewRequest(http.MethodGet, "/", nil)
			w := httptest.NewRecorder()

			got := SetResponseHeaders(tt.setHeaders)(fn)
			got.ServeHTTP(w, r)

			if diff := cmp.Diff(tt.wantHeaders, w.Body.String()); diff != "" {
				t.Errorf("SetResponseHeaders() mismatch (-want +got):\n %s", diff)
			}
		})
	}
}