mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-11 08:07:38 +02:00
proxy: use middleware to manage request flow
proxy: remove duplicate error handling in New proxy: remove routeConfigs in favor of using gorilla/mux proxy: add proxy specific middleware proxy: no longer need to use middleware / handler to check if valid route. Can use build in 404 mux. internal/middleware: add cors bypass middleware Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
70c5553d3c
commit
782ffbeb3e
17 changed files with 834 additions and 839 deletions
|
@ -1,4 +1,4 @@
|
|||
package middleware
|
||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
@ -276,60 +276,80 @@ func TestHealthCheck(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Redirect to a fixed URL
|
||||
type handlerHelper struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (rh *handlerHelper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(rh.msg))
|
||||
}
|
||||
|
||||
func handlerHelp(msg string) http.Handler {
|
||||
return &handlerHelper{msg}
|
||||
}
|
||||
func TestValidateHost(t *testing.T) {
|
||||
validHostFunc := func(host string) bool {
|
||||
return host == "google.com"
|
||||
}
|
||||
|
||||
validHostHandler := handlerHelp("google")
|
||||
|
||||
func TestStripCookie(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isValidHost func(string) bool
|
||||
validHostHandler http.Handler
|
||||
clientPath string
|
||||
expected []byte
|
||||
status int
|
||||
name string
|
||||
pomeriumCookie string
|
||||
otherCookies []string
|
||||
}{
|
||||
{"good", validHostFunc, validHostHandler, "google.com", []byte("google"), 200},
|
||||
{"no route", validHostFunc, validHostHandler, "googles.com", []byte("google"), 404},
|
||||
{"good", "pomerium", []string{"x", "y", "z"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, tt.clientPath, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for _, cookie := range r.Cookies() {
|
||||
if cookie.Name == tt.pomeriumCookie {
|
||||
t.Errorf("cookie not stripped %s", r.Cookies())
|
||||
}
|
||||
}
|
||||
})
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
var testHandler http.Handler
|
||||
if tt.isValidHost(tt.clientPath) {
|
||||
tt.validHostHandler.ServeHTTP(rr, req)
|
||||
testHandler = tt.validHostHandler
|
||||
} else {
|
||||
testHandler = handlerHelp("ok")
|
||||
for _, cn := range tt.otherCookies {
|
||||
http.SetCookie(rr, &http.Cookie{
|
||||
Name: cn,
|
||||
Value: "some other cookie",
|
||||
})
|
||||
}
|
||||
handler := ValidateHost(tt.isValidHost)(testHandler)
|
||||
|
||||
http.SetCookie(rr, &http.Cookie{
|
||||
Name: tt.pomeriumCookie,
|
||||
Value: "pomerium cookie!",
|
||||
})
|
||||
|
||||
http.SetCookie(rr, &http.Cookie{
|
||||
Name: tt.pomeriumCookie + "_csrf",
|
||||
Value: "pomerium csrf cookie!",
|
||||
})
|
||||
req := &http.Request{Header: http.Header{"Cookie": rr.Header()["Set-Cookie"]}}
|
||||
|
||||
handler := StripCookie(tt.pomeriumCookie)(testHandler)
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tt.status {
|
||||
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
|
||||
t.Errorf("%s", rr.Body)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutHandlerFunc(t *testing.T) {
|
||||
t.Parallel()
|
||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
timeout time.Duration
|
||||
timeoutError string
|
||||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{"good", 1 * time.Second, "good timed out!?", http.StatusOK, http.StatusText(http.StatusOK)},
|
||||
{"timeout!", 1 * time.Nanosecond, "ruh roh", http.StatusServiceUnavailable, "ruh roh"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
got := TimeoutHandlerFunc(tt.timeout, tt.timeoutError)(fn)
|
||||
got.ServeHTTP(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||
}
|
||||
if body := w.Body.String(); tt.wantBody != body {
|
||||
t.Errorf("SignRequest() body = %v, want %v", body, tt.wantBody)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue