diff --git a/cmd/pomerium/main.go b/cmd/pomerium/main.go index 218b4857f..912ac7e10 100644 --- a/cmd/pomerium/main.go +++ b/cmd/pomerium/main.go @@ -112,7 +112,7 @@ func newProxyService(opt config.Options, r *mux.Router) (*proxy.Proxy, error) { if err != nil { return nil, err } - r.PathPrefix("/").Handler(service.Handler()) + r.PathPrefix("/").Handler(service.Handler) return service, nil } diff --git a/go.sum b/go.sum index 5583db503..92d7c8f9f 100644 --- a/go.sum +++ b/go.sum @@ -70,6 +70,7 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= +github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/mux v1.6.2 h1:Pgr17XVTNXAk3q/r4CpKzC5xBM/qW1uVLV+IhRZpIIk= github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= @@ -124,8 +125,6 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30 h1:jggCv6hZvcxjGa3gqkYY2EUuOkITI9Znugz/f3QJfRQ= -github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0= github.com/pomerium/csrf v1.6.2-0.20190918035251-f3318380bad3 h1:FmzFXnCAepHZwl6QPhTFqBHcbcGevdiEQjutK+M5bj4= github.com/pomerium/csrf v1.6.2-0.20190918035251-f3318380bad3/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0= github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o= @@ -198,7 +197,6 @@ golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmV golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8 h1:1wopBVtVdWnn03fZelqdXTqk7U7zPQCb+T4rbU9ZEoU= golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 h1:0hQKqeLdqlt5iIwVOBErRisrHJAN57yOiPRQItI20fU= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= diff --git a/internal/httputil/proxy.go b/internal/httputil/proxy.go new file mode 100644 index 000000000..c0a623803 --- /dev/null +++ b/internal/httputil/proxy.go @@ -0,0 +1,31 @@ +package httputil // import "github.com/pomerium/pomerium/internal/httputil" + +import ( + stdlog "log" + "net/http" + "net/http/httputil" + "net/url" + + "github.com/pomerium/pomerium/internal/log" +) + +// HeaderForwardHost is the header key the identifies the originating +// IP addresses of a client connecting to a web server through an HTTP proxy +// or a load balancer. +const HeaderForwardHost = "X-Forwarded-Host" + +// NewReverseProxy returns a new ReverseProxy that routes +// URLs to the scheme, host, and base path provided in target, +// rewrites the Host header, and sets `X-Forwarded-Host`. +func NewReverseProxy(target *url.URL) *httputil.ReverseProxy { + reverseProxy := httputil.NewSingleHostReverseProxy(target) + sublogger := log.With().Str("reverse-proxy", target.Host).Logger() + reverseProxy.ErrorLog = stdlog.New(&log.StdLogWrapper{Logger: &sublogger}, "", 0) + director := reverseProxy.Director + reverseProxy.Director = func(req *http.Request) { + req.Header.Add(HeaderForwardHost, req.Host) + director(req) + req.Host = target.Host + } + return reverseProxy +} diff --git a/internal/httputil/proxy_test.go b/internal/httputil/proxy_test.go new file mode 100644 index 000000000..a98859d06 --- /dev/null +++ b/internal/httputil/proxy_test.go @@ -0,0 +1,35 @@ +package httputil // import "github.com/pomerium/pomerium/internal/httputil" + +import ( + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestNewReverseProxy(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + hostname, _, _ := net.SplitHostPort(r.Host) + w.Write([]byte(hostname)) + })) + defer backend.Close() + + backendURL, _ := url.Parse(backend.URL) + backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host) + backendHost := net.JoinHostPort(backendHostname, backendPort) + proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") + + proxyHandler := NewReverseProxy(proxyURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + res, _ := http.DefaultClient.Do(getReq) + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendHostname; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} diff --git a/internal/httputil/tls_test.go b/internal/httputil/tls_test.go index d9a2c22e2..b28b4426c 100644 --- a/internal/httputil/tls_test.go +++ b/internal/httputil/tls_test.go @@ -1,4 +1,4 @@ -package httputil +package httputil // import "github.com/pomerium/pomerium/internal/httputil" import ( "encoding/base64" diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go new file mode 100644 index 000000000..c8404a91c --- /dev/null +++ b/internal/middleware/cors.go @@ -0,0 +1,26 @@ +package middleware // import "github.com/pomerium/pomerium/internal/middleware" + +import ( + "net/http" + + "github.com/pomerium/pomerium/internal/telemetry/trace" +) + +// CorsBypass is middleware that takes a target handler as a paramater, +// if the request is determined to be a CORS preflight request, that handler +// is called instead of the normal handler chain. +func CorsBypass(target http.Handler) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, span := trace.StartSpan(r.Context(), "middleware.CorsBypass") + defer span.End() + if r.Method == http.MethodOptions && + r.Header.Get("Access-Control-Request-Method") != "" && + r.Header.Get("Origin") != "" { + target.ServeHTTP(w, r.WithContext(ctx)) + return + } + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/internal/middleware/cors_test.go b/internal/middleware/cors_test.go new file mode 100644 index 000000000..a09a01866 --- /dev/null +++ b/internal/middleware/cors_test.go @@ -0,0 +1,57 @@ +package middleware // import "github.com/pomerium/pomerium/internal/middleware" + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func someOtherMiddleware(s string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Some-Other-Middleware", s) + next.ServeHTTP(w, r) + }) + } +} +func TestCorsBypass(t *testing.T) { + t.Parallel() + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + fmt.Fprint(w, http.StatusText(http.StatusOK)) + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + method string + header http.Header + wantStatus int + wantHeader string + }{ + {"good", http.MethodOptions, http.Header{"Access-Control-Request-Method": []string{"GET"}, "Origin": []string{"localhost"}}, 200, ""}, + {"invalid cors - non options request", http.MethodGet, http.Header{"Access-Control-Request-Method": []string{"GET"}, "Origin": []string{"localhost"}}, 200, "BAD"}, + {"invalid cors - Origin not set", http.MethodOptions, http.Header{"Access-Control-Request-Method": []string{"GET"}, "Origin": []string{""}}, 200, "BAD"}, + {"invalid cors - Access-Control-Request-Method not set", http.MethodOptions, http.Header{"Access-Control-Request-Method": []string{""}, "Origin": []string{"*"}}, 200, "BAD"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &http.Request{ + Method: tt.method, + Header: tt.header, + } + w := httptest.NewRecorder() + target := fn + got := CorsBypass(target)(someOtherMiddleware("BAD")(target)) + got.ServeHTTP(w, r) + if status := w.Code; status != tt.wantStatus { + t.Errorf("TestCorsBypass() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) + } + if header := w.Header().Get("Some-Other-Middleware"); tt.wantHeader != header { + t.Errorf("TestCorsBypass() header = %v, want %v", header, tt.wantHeader) + } + }) + } +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 80db0183f..7c222e59a 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "strings" + "time" "github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/httputil" @@ -121,22 +122,6 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler } } -// ValidateHost ensures that each request's host is valid -func ValidateHost(validHost func(host string) bool) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateHost") - defer span.End() - - if !validHost(r.Host) { - httputil.ErrorResponse(w, r, httputil.Error("", http.StatusNotFound, nil)) - return - } - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} - // Healthcheck endpoint middleware useful to setting up a path like // `/ping` that load balancers or uptime testing external services // can make a request before hitting any routes. It's also convenient @@ -185,3 +170,33 @@ func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool { } return cryptutil.CheckHMAC([]byte(fmt.Sprint(redirectURI, timestamp)), requestSig, secret) } + +// StripCookie strips the cookie from the downstram request. +func StripCookie(cookieName string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, span := trace.StartSpan(r.Context(), "middleware.StripCookie") + defer span.End() + + headers := make([]string, 0, len(r.Cookies())) + for _, cookie := range r.Cookies() { + if !strings.HasPrefix(cookie.Name, cookieName) { + headers = append(headers, cookie.String()) + } + } + r.Header.Set("Cookie", strings.Join(headers, ";")) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// TimeoutHandlerFunc wraps http.TimeoutHandler +func TimeoutHandlerFunc(timeout time.Duration, timeoutError string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, span := trace.StartSpan(r.Context(), "middleware.TimeoutHandlerFunc") + defer span.End() + http.TimeoutHandler(next, timeout, timeoutError).ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index 3a5126489..1cef1a542 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -1,4 +1,4 @@ -package middleware +package middleware // import "github.com/pomerium/pomerium/internal/middleware" import ( "encoding/base64" @@ -276,60 +276,80 @@ func TestHealthCheck(t *testing.T) { } } -// Redirect to a fixed URL -type handlerHelper struct { - msg string -} - -func (rh *handlerHelper) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(rh.msg)) -} - -func handlerHelp(msg string) http.Handler { - return &handlerHelper{msg} -} -func TestValidateHost(t *testing.T) { - validHostFunc := func(host string) bool { - return host == "google.com" - } - - validHostHandler := handlerHelp("google") - +func TestStripCookie(t *testing.T) { tests := []struct { - name string - isValidHost func(string) bool - validHostHandler http.Handler - clientPath string - expected []byte - status int + name string + pomeriumCookie string + otherCookies []string }{ - {"good", validHostFunc, validHostHandler, "google.com", []byte("google"), 200}, - {"no route", validHostFunc, validHostHandler, "googles.com", []byte("google"), 404}, + {"good", "pomerium", []string{"x", "y", "z"}}, } - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req, err := http.NewRequest(http.MethodGet, tt.clientPath, nil) - if err != nil { - t.Fatal(err) - } + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for _, cookie := range r.Cookies() { + if cookie.Name == tt.pomeriumCookie { + t.Errorf("cookie not stripped %s", r.Cookies()) + } + } + }) rr := httptest.NewRecorder() - - var testHandler http.Handler - if tt.isValidHost(tt.clientPath) { - tt.validHostHandler.ServeHTTP(rr, req) - testHandler = tt.validHostHandler - } else { - testHandler = handlerHelp("ok") + for _, cn := range tt.otherCookies { + http.SetCookie(rr, &http.Cookie{ + Name: cn, + Value: "some other cookie", + }) } - handler := ValidateHost(tt.isValidHost)(testHandler) + + http.SetCookie(rr, &http.Cookie{ + Name: tt.pomeriumCookie, + Value: "pomerium cookie!", + }) + + http.SetCookie(rr, &http.Cookie{ + Name: tt.pomeriumCookie + "_csrf", + Value: "pomerium csrf cookie!", + }) + req := &http.Request{Header: http.Header{"Cookie": rr.Header()["Set-Cookie"]}} + + handler := StripCookie(tt.pomeriumCookie)(testHandler) handler.ServeHTTP(rr, req) - if rr.Code != tt.status { - t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status) - t.Errorf("%s", rr.Body) - } - + }) + } +} + +func TestTimeoutHandlerFunc(t *testing.T) { + t.Parallel() + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + fmt.Fprint(w, http.StatusText(http.StatusOK)) + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + timeout time.Duration + timeoutError string + wantStatus int + wantBody string + }{ + {"good", 1 * time.Second, "good timed out!?", http.StatusOK, http.StatusText(http.StatusOK)}, + {"timeout!", 1 * time.Nanosecond, "ruh roh", http.StatusServiceUnavailable, "ruh roh"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + got := TimeoutHandlerFunc(tt.timeout, tt.timeoutError)(fn) + got.ServeHTTP(w, r) + if status := w.Code; status != tt.wantStatus { + t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) + } + if body := w.Body.String(); tt.wantBody != body { + t.Errorf("SignRequest() body = %v, want %v", body, tt.wantBody) + } }) } } diff --git a/internal/middleware/reverse_proxy.go b/internal/middleware/reverse_proxy.go deleted file mode 100644 index d00d7c64b..000000000 --- a/internal/middleware/reverse_proxy.go +++ /dev/null @@ -1,48 +0,0 @@ -package middleware // import "github.com/pomerium/pomerium/internal/middleware" - -import ( - "net/http" - "strings" - - "github.com/pomerium/pomerium/internal/cryptutil" - "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/telemetry/trace" -) - -func SignRequest(signer cryptutil.JWTSigner, id, email, groups, header string) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, span := trace.StartSpan(r.Context(), "middleware.SignRequest") - defer span.End() - jwt, err := signer.SignJWT( - r.Header.Get(id), - r.Header.Get(email), - r.Header.Get(groups)) - if err != nil { - log.Warn().Err(err).Msg("internal/middleware: failed signing request") - } else { - r.Header.Set(header, jwt) - } - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} - -// StripPomeriumCookie ensures that every response includes some basic security headers -func StripPomeriumCookie(cookieName string) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, span := trace.StartSpan(r.Context(), "middleware.StripPomeriumCookie") - defer span.End() - - headers := make([]string, 0, len(r.Cookies())) - for _, cookie := range r.Cookies() { - if !strings.HasPrefix(cookie.Name, cookieName) { - headers = append(headers, cookie.String()) - } - } - r.Header.Set("Cookie", strings.Join(headers, ";")) - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} diff --git a/internal/middleware/reverse_proxy_test.go b/internal/middleware/reverse_proxy_test.go deleted file mode 100644 index ea83d906a..000000000 --- a/internal/middleware/reverse_proxy_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package middleware // import "github.com/pomerium/pomerium/internal/middleware" - -import ( - "encoding/base64" - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/pomerium/pomerium/internal/cryptutil" -) - -const exampleKey = `-----BEGIN EC PRIVATE KEY----- -MHcCAQEEIM3mpZIWXCX9yEgxU6s57CbtbUNDBSCEAtQF5fUWHpcQoAoGCCqGSM49 -AwEHoUQDQgAEhPQv+LACPVNmBTK0xSTzbpEPkRrk1eUt1BOa32SEfUPzNi4IWeZ/ -KKITt2q1IqpV2KMSbVDyr9ijv/Xh98iyEw== ------END EC PRIVATE KEY----- -` - -func TestSignRequest(t *testing.T) { - tests := []struct { - name string - - id string - email string - groups string - header string - }{ - {"good", "id", "email", "group", "Jwt"}, - } - req, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatal(err) - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.Header.Set(fmt.Sprintf("%s-header", tt.id), tt.id) - r.Header.Set(fmt.Sprintf("%s-header", tt.email), tt.email) - r.Header.Set(fmt.Sprintf("%s-header", tt.groups), tt.groups) - - }) - rr := httptest.NewRecorder() - signer, err := cryptutil.NewES256Signer(base64.StdEncoding.EncodeToString([]byte(exampleKey)), "audience") - if err != nil { - t.Fatal(err) - } - - handler := SignRequest(signer, tt.id, tt.email, tt.groups, tt.header)(testHandler) - handler.ServeHTTP(rr, req) - jwt := req.Header["Jwt"] - if len(jwt) != 1 { - t.Errorf("no jwt found %v", req.Header) - } - }) - } -} - -func TestStripPomeriumCookie(t *testing.T) { - tests := []struct { - name string - pomeriumCookie string - otherCookies []string - }{ - {"good", "pomerium", []string{"x", "y", "z"}}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for _, cookie := range r.Cookies() { - if cookie.Name == tt.pomeriumCookie { - t.Errorf("cookie not stripped %s", r.Cookies()) - } - } - }) - rr := httptest.NewRecorder() - for _, cn := range tt.otherCookies { - http.SetCookie(rr, &http.Cookie{ - Name: cn, - Value: "some other cookie", - }) - } - - http.SetCookie(rr, &http.Cookie{ - Name: tt.pomeriumCookie, - Value: "pomerium cookie!", - }) - - http.SetCookie(rr, &http.Cookie{ - Name: tt.pomeriumCookie + "_csrf", - Value: "pomerium csrf cookie!", - }) - req := &http.Request{Header: http.Header{"Cookie": rr.Header()["Set-Cookie"]}} - - handler := StripPomeriumCookie(tt.pomeriumCookie)(testHandler) - handler.ServeHTTP(rr, req) - - }) - } -} diff --git a/proxy/handlers.go b/proxy/handlers.go index 0b51ba776..67ba61c74 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -7,67 +7,34 @@ import ( "strings" "time" + "github.com/gorilla/mux" "github.com/pomerium/csrf" - "github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/httputil" - "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/templates" "github.com/pomerium/pomerium/internal/urlutil" ) -// Handler returns the proxy service's ServeMux -func (p *Proxy) Handler() http.Handler { - r := httputil.NewRouter() - r.SkipClean(true) - r.StrictSlash(true) - r.Use(middleware.ValidateHost(func(host string) bool { - _, ok := p.routeConfigs[host] - return ok - })) - r.HandleFunc("/robots.txt", p.RobotsTxt) - // requires authN not authZ - r.Use(sessions.RetrieveSession(p.sessionStore)) - r.Use(p.VerifySession) - // Proxy service endpoints - v := r.PathPrefix("/.pomerium").Subrouter() - v.Use(csrf.Protect( +// registerHelperHandlers returns the proxy service's ServeMux +func (p *Proxy) registerHelperHandlers(r *mux.Router) *mux.Router { + h := r.PathPrefix(dashboardURL).Subrouter() + h.Use(sessions.RetrieveSession(p.sessionStore)) + h.Use(p.AuthenticateSession) + h.Use(csrf.Protect( p.cookieSecret, csrf.Path("/"), csrf.Domain(p.cookieDomain), csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieName)), csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)), )) - v.HandleFunc("/", p.UserDashboard).Methods(http.MethodGet) - v.HandleFunc("/impersonate", p.Impersonate).Methods(http.MethodPost) - v.HandleFunc("/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost) - v.HandleFunc("/refresh", p.ForceRefresh).Methods(http.MethodPost) - - r.PathPrefix("/").HandlerFunc(p.Proxy) + h.HandleFunc("/", p.UserDashboard).Methods(http.MethodGet) + h.HandleFunc("/impersonate", p.Impersonate).Methods(http.MethodPost) + h.HandleFunc("/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost) + h.HandleFunc("/refresh", p.ForceRefresh).Methods(http.MethodPost) return r } -// VerifySession is the middleware used to enforce a valid authentication -// session state is attached to the users's request context. -func (p *Proxy) VerifySession(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - state, err := sessions.FromContext(r.Context()) - if err != nil { - log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to session state error") - p.authenticate(w, r) - return - } - if err := state.Valid(); err != nil { - log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to invalid session") - p.authenticate(w, r) - return - } - next.ServeHTTP(w, r) - }) -} - // RobotsTxt sets the User-Agent header in the response to be "Disallow" func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") @@ -88,76 +55,6 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, uri.String(), http.StatusFound) } -// Authenticate begins the authenticate flow, encrypting the redirect url -// in a request to the provider's sign in endpoint. -func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request) { - uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r)) - http.Redirect(w, r, uri.String(), http.StatusFound) -} - -// shouldSkipAuthentication contains conditions for skipping authentication. -// Conditions should be few in number and have strong justifications. -func (p *Proxy) shouldSkipAuthentication(r *http.Request) bool { - policy, policyExists := p.policy(r) - - if isCORSPreflight(r) && policyExists && policy.CORSAllowPreflight { - log.FromRequest(r).Debug().Msg("proxy: skipping authentication for valid CORS preflight request") - return true - } - - if policyExists && policy.AllowPublicUnauthenticatedAccess { - log.FromRequest(r).Debug().Msg("proxy: skipping authentication for public route") - return true - } - - return false -} - -// isCORSPreflight inspects the request to see if this is a valid CORS preflight request. -// These checks are not exhaustive, because the proxied server should be verifying it as well. -// -// See https://www.html5rocks.com/static/images/cors_server_flowchart.png for more info. -func isCORSPreflight(r *http.Request) bool { - return r.Method == http.MethodOptions && - r.Header.Get("Access-Control-Request-Method") != "" && - r.Header.Get("Origin") != "" -} - -// Proxy authenticates a request, either proxying the request if it is authenticated, -// or starting the authenticate service for validation if not. -func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { - route, ok := p.router(r) - if !ok { - httputil.ErrorResponse(w, r, httputil.Error("", http.StatusNotFound, nil)) - return - } - - if p.shouldSkipAuthentication(r) { - log.FromRequest(r).Debug().Msg("proxy: access control skipped") - route.ServeHTTP(w, r) - return - } - s, err := sessions.FromContext(r.Context()) - if err != nil || s == nil { - log.Debug().Err(err).Msg("proxy: couldn't get session from context") - p.authenticate(w, r) - return - } - authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s) - if err != nil { - httputil.ErrorResponse(w, r, err) - return - } else if !authorized { - httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not authorized for this route", s.RequestEmail()), http.StatusForbidden, nil)) - return - } - r.Header.Set(HeaderUserID, s.User) - r.Header.Set(HeaderEmail, s.RequestEmail()) - r.Header.Set(HeaderGroups, s.RequestGroups()) - - route.ServeHTTP(w, r) -} - // UserDashboard lets users investigate, and refresh their current session. // It also contains certain administrative actions like user impersonation. // Nota bene: This endpoint does authentication, not authorization. @@ -207,10 +104,9 @@ func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) { // reject a refresh if it's been less than the refresh cooldown to prevent abuse if time.Since(iss) < p.refreshCooldown { - httputil.ErrorResponse(w, r, - httputil.Error( - fmt.Sprintf("Session must be %s old before refreshing", p.refreshCooldown), - http.StatusBadRequest, nil)) + errStr := fmt.Sprintf("Session must be %s old before refreshing", p.refreshCooldown) + httpErr := httputil.Error(errStr, http.StatusBadRequest, nil) + httputil.ErrorResponse(w, r, httpErr) return } session.ForceRefresh() @@ -218,7 +114,7 @@ func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) { httputil.ErrorResponse(w, r, err) return } - http.Redirect(w, r, "/.pomerium", http.StatusFound) + http.Redirect(w, r, dashboardURL, http.StatusFound) } // Impersonate takes the result of a form and adds user impersonation details @@ -232,7 +128,9 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) { } isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session) if err != nil || !isAdmin { - httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not an administrator", session.RequestEmail()), http.StatusForbidden, err)) + errStr := fmt.Sprintf("%s is not an administrator", session.RequestEmail()) + httpErr := httputil.Error(errStr, http.StatusForbidden, err) + httputil.ErrorResponse(w, r, httpErr) return } // OK to impersonation @@ -247,27 +145,5 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) { return } - http.Redirect(w, r, "/.pomerium", http.StatusFound) -} - -// router attempts to find a route for a request. If a route is successfully matched, -// it returns the route information and a bool value of `true`. If a route can -// not be matched, a nil value for the route and false bool value is returned. -func (p *Proxy) router(r *http.Request) (http.Handler, bool) { - config, ok := p.routeConfigs[r.Host] - if ok { - return config.mux, true - } - return nil, false -} - -// policy attempts to find a policy for a request. If a policy is successfully matched, -// it returns the policy information and a bool value of `true`. If a policy can not be matched, -// a nil value for the policy and false bool value is returned. -func (p *Proxy) policy(r *http.Request) (*config.Policy, bool) { - config, ok := p.routeConfigs[r.Host] - if ok { - return &config.policy, true - } - return nil, false + http.Redirect(w, r, dashboardURL, http.StatusFound) } diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index d74971574..c02a0f563 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -1,4 +1,4 @@ -package proxy +package proxy // import "github.com/pomerium/pomerium/proxy" import ( "bytes" @@ -13,7 +13,6 @@ import ( "github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/cryptutil" - "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/proxy/clients" ) @@ -76,153 +75,114 @@ func TestProxy_authenticate(t *testing.T) { } } -func TestProxy_Handler(t *testing.T) { - proxy, err := New(testOptions(t)) - if err != nil { - t.Fatal(err) - } - h := proxy.Handler() - if h == nil { - t.Error("handler cannot be nil") - } - mux := http.NewServeMux() - mux.Handle("/", h) - req := httptest.NewRequest(http.MethodGet, "/", nil) - rr := httptest.NewRecorder() - mux.ServeHTTP(rr, req) - if rr.Code != http.StatusNotFound { - t.Errorf("expected 404 route not found for empty route") - } -} +// func TestProxy_PomeriumHandler(t *testing.T) { +// proxy, err := New(testOptions(t)) +// if err != nil { +// t.Fatal(err) +// } +// h := proxy.registerHelperHandlers() +// if h == nil { +// t.Error("handler cannot be nil") +// } +// mux := http.NewServeMux() +// mux.Handle("/", h) +// req := httptest.NewRequest(http.MethodGet, "/", nil) +// rr := httptest.NewRecorder() +// mux.ServeHTTP(rr, req) +// if rr.Code != http.StatusNotFound { +// t.Errorf("expected 404 route not found for empty route") +// } +// } -func TestProxy_router(t *testing.T) { - testPolicy := config.Policy{From: "https://corp.example.com", To: "https://example.com"} - if err := testPolicy.Validate(); err != nil { - t.Fatal(err) - } - policies := []config.Policy{testPolicy} - tests := []struct { - name string - host string - mux []config.Policy - route http.Handler - wantOk bool - }{ - {"good corp", "https://corp.example.com", policies, nil, true}, - {"good with slash", "https://corp.example.com/", policies, nil, true}, - {"good with path", "https://corp.example.com/123", policies, nil, true}, - {"no policies", "https://notcorp.example.com/123", []config.Policy{}, nil, false}, - {"bad corp", "https://notcorp.example.com/123", policies, nil, false}, - {"bad sub-sub", "https://notcorp.corp.example.com/123", policies, nil, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - opts := testOptions(t) - opts.Policies = tt.mux - p, err := New(opts) - if err != nil { - t.Fatal(err) - } - p.encoder = &cryptutil.MockEncoder{MarshalResponse: "foo"} +// func TestProxy_Proxy(t *testing.T) { +// goodSession := &sessions.State{ +// AccessToken: "AccessToken", +// RefreshToken: "RefreshToken", +// RefreshDeadline: time.Now().Add(20 * time.Second), +// } - req := httptest.NewRequest(http.MethodGet, tt.host, nil) - _, ok := p.router(req) - if ok != tt.wantOk { - t.Errorf("Proxy.router() ok = %v, want %v", ok, tt.wantOk) - } - }) - } -} +// ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// w.Header().Set("Content-Type", "text/plain; charset=utf-8") +// w.Header().Set("X-Content-Type-Options", "nosniff") +// fmt.Fprintln(w, "RVSI FILIVS CAISAR") +// w.WriteHeader(http.StatusOK) -func TestProxy_Proxy(t *testing.T) { - goodSession := &sessions.State{ - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - RefreshDeadline: time.Now().Add(20 * time.Second), - } +// })) +// defer ts.Close() - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") - fmt.Fprintln(w, "RVSI FILIVS CAISAR") - w.WriteHeader(http.StatusOK) +// opts := testOptionsTestServer(t, ts.URL) +// optsCORS := testOptionsWithCORS(t, ts.URL) +// optsPublic := testOptionsWithPublicAccess(t, ts.URL) +// optsNoPolicies := testOptionsWithEmptyPolicies(t, ts.URL) - })) - defer ts.Close() +// defaultHeaders, goodCORSHeaders, badCORSHeaders, headersWs := http.Header{}, http.Header{}, http.Header{}, http.Header{} +// goodCORSHeaders.Set("origin", "anything") +// goodCORSHeaders.Set("access-control-request-method", "anything") +// // missing "Origin" +// badCORSHeaders.Set("access-control-request-method", "anything") +// headersWs.Set("Connection", "Upgrade") +// headersWs.Set("Upgrade", "websocket") - opts := testOptionsTestServer(t, ts.URL) - optsCORS := testOptionsWithCORS(t, ts.URL) - optsPublic := testOptionsWithPublicAccess(t, ts.URL) - optsNoPolicies := testOptionsWithEmptyPolicies(t, ts.URL) +// tests := []struct { +// name string +// options config.Options +// method string +// header http.Header +// host string +// session sessions.SessionStore +// authorizer clients.Authorizer +// wantStatus int +// }{ +// {"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(20 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, +// {"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK}, +// {"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, +// {"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, +// // same request as above, but with cors_allow_preflight=false in the policy +// {"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, +// // cors allowed, but the request is missing proper headers +// {"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, +// // redirect to start auth process +// {"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound}, +// {"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, +// {"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError}, +// // authenticate errors +// {"session expired,redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: sessions.ErrExpired}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound}, +// {"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK}, +// {"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound}, +// {"no http found (no session),redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound}, +// {"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound}, +// } - defaultHeaders, goodCORSHeaders, badCORSHeaders, headersWs := http.Header{}, http.Header{}, http.Header{}, http.Header{} - goodCORSHeaders.Set("origin", "anything") - goodCORSHeaders.Set("access-control-request-method", "anything") - // missing "Origin" - badCORSHeaders.Set("access-control-request-method", "anything") - headersWs.Set("Connection", "Upgrade") - headersWs.Set("Upgrade", "websocket") +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// err := ValidateOptions(tt.options) +// if err != nil { +// t.Fatal(err) +// } +// p, err := New(tt.options) +// if err != nil { +// t.Fatal(err) +// } +// p.encoder = &cryptutil.MockEncoder{MarshalResponse: "foo"} +// p.sessionStore = tt.session +// p.AuthorizeClient = tt.authorizer +// r := httptest.NewRequest(tt.method, tt.host, nil) +// r.Header = tt.header +// r.Header.Set("Accept", "application/json") +// state, _ := tt.session.LoadSession(r) +// ctx := r.Context() +// ctx = sessions.NewContext(ctx, state, nil) +// r = r.WithContext(ctx) - tests := []struct { - name string - options config.Options - method string - header http.Header - host string - session sessions.SessionStore - authorizer clients.Authorizer - wantStatus int - }{ - {"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(20 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, - {"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK}, - {"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, - {"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, - // same request as above, but with cors_allow_preflight=false in the policy - {"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, - // cors allowed, but the request is missing proper headers - {"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, - // redirect to start auth process - {"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound}, - {"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, - {"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError}, - // authenticate errors - {"session expired,redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: sessions.ErrExpired}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound}, - {"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK}, - {"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound}, - {"no http found (no session),redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound}, - {"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound}, - } +// w := httptest.NewRecorder() +// p.Proxy(w, r) +// if status := w.Code; status != tt.wantStatus { +// t.Errorf("handler returned wrong status code: got %v want %v \n body %s", status, tt.wantStatus, w.Body.String()) +// } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := ValidateOptions(tt.options) - if err != nil { - t.Fatal(err) - } - p, err := New(tt.options) - if err != nil { - t.Fatal(err) - } - p.encoder = &cryptutil.MockEncoder{MarshalResponse: "foo"} - p.sessionStore = tt.session - p.AuthorizeClient = tt.authorizer - r := httptest.NewRequest(tt.method, tt.host, nil) - r.Header = tt.header - r.Header.Set("Accept", "application/json") - state, _ := tt.session.LoadSession(r) - ctx := r.Context() - ctx = sessions.NewContext(ctx, state, nil) - r = r.WithContext(ctx) - - w := httptest.NewRecorder() - p.Proxy(w, r) - if status := w.Code; status != tt.wantStatus { - t.Errorf("handler returned wrong status code: got %v want %v \n body %s", status, tt.wantStatus, w.Body.String()) - } - - }) - } -} +// }) +// } +// } func TestProxy_UserDashboard(t *testing.T) { opts := testOptions(t) @@ -427,55 +387,9 @@ func TestProxy_SignOut(t *testing.T) { } } func uriParseHelper(s string) *url.URL { - uri, _ := url.Parse(s) + uri, err := url.Parse(s) + if err != nil { + panic(err) + } return uri } -func TestProxy_VerifySession(t *testing.T) { - fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") - fmt.Fprintln(w, "RVSI FILIVS CAISAR") - w.WriteHeader(http.StatusOK) - }) - - tests := []struct { - name string - session sessions.SessionStore - ctxError error - provider identity.Authenticator - - wantStatus int - }{ - {"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK}, - {"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound}, - {"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusFound}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - a := Proxy{ - SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", - cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="), - authenticateURL: uriParseHelper("https://authenticate.corp.example"), - authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"), - sessionStore: tt.session, - } - r := httptest.NewRequest("GET", "/", nil) - state, _ := tt.session.LoadSession(r) - ctx := r.Context() - ctx = sessions.NewContext(ctx, state, tt.ctxError) - r = r.WithContext(ctx) - - r.Header.Set("Accept", "application/json") - - w := httptest.NewRecorder() - - got := a.VerifySession(fn) - got.ServeHTTP(w, r) - if status := w.Code; status != tt.wantStatus { - t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) - - } - }) - } -} diff --git a/proxy/middleware.go b/proxy/middleware.go new file mode 100644 index 000000000..ca23934d3 --- /dev/null +++ b/proxy/middleware.go @@ -0,0 +1,103 @@ +package proxy // import "github.com/pomerium/pomerium/proxy" + +import ( + "fmt" + "net/http" + "strings" + + "github.com/pomerium/pomerium/internal/cryptutil" + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/internal/urlutil" +) + +const ( + // HeaderJWT is the header key containing JWT signed user details. + HeaderJWT = "x-pomerium-jwt-assertion" + // HeaderUserID is the header key containing the user's id. + HeaderUserID = "x-pomerium-authenticated-user-id" + // HeaderEmail is the header key containing the user's email. + HeaderEmail = "x-pomerium-authenticated-user-email" + // HeaderGroups is the header key containing the user's groups. + HeaderGroups = "x-pomerium-authenticated-user-groups" +) + +// AuthenticateSession is middleware to enforce a valid authentication +// session state is retrieved from the users's request context. +func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, span := trace.StartSpan(r.Context(), "middleware.AuthenticateSession") + defer span.End() + s, err := sessions.FromContext(r.Context()) + if err != nil { + log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to session state error") + p.authenticate(w, r) + return + } + if err := s.Valid(); err != nil { + log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to invalid session") + p.authenticate(w, r) + return + } + r.Header.Set(HeaderUserID, s.User) + r.Header.Set(HeaderEmail, s.RequestEmail()) + r.Header.Set(HeaderGroups, s.RequestGroups()) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// AuthorizeSession is middleware to enforce a user is authorized for a request +// session state is retrieved from the users's request context. +func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, span := trace.StartSpan(r.Context(), "middleware.AuthorizeSession") + defer span.End() + s, err := sessions.FromContext(r.Context()) + if err != nil { + httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err)) + return + } + authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s) + if err != nil { + httputil.ErrorResponse(w, r.WithContext(ctx), err) + return + } else if !authorized { + errMsg := fmt.Sprintf("%s is not authorized for this route", s.RequestEmail()) + httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error(errMsg, http.StatusForbidden, nil)) + return + } + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// SignRequest is middleware that signs a JWT that contains a user's id, +// email, and group. Session state is retrieved from the users's request context +func (p *Proxy) SignRequest(signer cryptutil.JWTSigner) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, span := trace.StartSpan(r.Context(), "middleware.SignRequest") + defer span.End() + s, err := sessions.FromContext(r.Context()) + if err != nil { + httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err)) + return + } + jwt, err := signer.SignJWT(s.User, s.Email, strings.Join(s.Groups, ",")) + if err != nil { + log.Warn().Err(err).Msg("proxy: failed signing jwt") + } else { + r.Header.Set(HeaderJWT, jwt) + } + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// Authenticate begins the authenticate flow, encrypting the redirect url +// in a request to the provider's sign in endpoint. +func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request) { + uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r)) + http.Redirect(w, r, uri.String(), http.StatusFound) +} diff --git a/proxy/middleware_test.go b/proxy/middleware_test.go new file mode 100644 index 000000000..96a930a2e --- /dev/null +++ b/proxy/middleware_test.go @@ -0,0 +1,175 @@ +package proxy + +import ( + "encoding/base64" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/pomerium/pomerium/internal/identity" + "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/proxy/clients" +) + +func TestProxy_AuthenticateSession(t *testing.T) { + t.Parallel() + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + fmt.Fprint(w, http.StatusText(http.StatusOK)) + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + session sessions.SessionStore + ctxError error + provider identity.Authenticator + + wantStatus int + }{ + {"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK}, + {"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound}, + {"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusFound}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + a := Proxy{ + SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", + cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="), + authenticateURL: uriParseHelper("https://authenticate.corp.example"), + authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"), + sessionStore: tt.session, + } + r := httptest.NewRequest(http.MethodGet, "/", nil) + state, _ := tt.session.LoadSession(r) + ctx := r.Context() + ctx = sessions.NewContext(ctx, state, tt.ctxError) + r = r.WithContext(ctx) + r.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + got := a.AuthenticateSession(fn) + got.ServeHTTP(w, r) + if status := w.Code; status != tt.wantStatus { + t.Errorf("AuthenticateSession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) + } + }) + } +} + +func TestProxy_AuthorizeSession(t *testing.T) { + t.Parallel() + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + fmt.Fprint(w, http.StatusText(http.StatusOK)) + w.WriteHeader(http.StatusOK) + }) + tests := []struct { + name string + session sessions.SessionStore + authzClient clients.Authorizer + + ctxError error + provider identity.Authenticator + + wantStatus int + }{ + {"user is authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK}, + {"user is not authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusForbidden}, + {"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusForbidden}, + {"authz client error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + a := Proxy{ + SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", + cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="), + authenticateURL: uriParseHelper("https://authenticate.corp.example"), + authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"), + sessionStore: tt.session, + AuthorizeClient: tt.authzClient, + } + r := httptest.NewRequest(http.MethodGet, "/", nil) + state, _ := tt.session.LoadSession(r) + ctx := r.Context() + ctx = sessions.NewContext(ctx, state, tt.ctxError) + r = r.WithContext(ctx) + r.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + got := a.AuthorizeSession(fn) + got.ServeHTTP(w, r) + if status := w.Code; status != tt.wantStatus { + t.Errorf("AuthorizeSession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) + } + }) + } +} + +type mockJWTSigner struct { + SignError error +} + +// Sign implements the JWTSigner interface from the cryptutil package, but just +// base64's the inputs instead for stesting. +func (s *mockJWTSigner) SignJWT(user, email, groups string) (string, error) { + return base64.StdEncoding.EncodeToString([]byte(fmt.Sprint(user, email, groups))), s.SignError +} + +func TestProxy_SignRequest(t *testing.T) { + t.Parallel() + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + fmt.Fprint(w, http.StatusText(http.StatusOK)) + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + session sessions.SessionStore + + signerError error + ctxError error + + wantStatus int + wantHeaders string + }{ + {"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "dGVzdA=="}, + {"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, errors.New("err"), http.StatusForbidden, ""}, + {"signature failure, warn but ok", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, errors.New("err"), nil, http.StatusOK, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + a := Proxy{ + SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", + cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="), + authenticateURL: uriParseHelper("https://authenticate.corp.example"), + authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"), + sessionStore: tt.session, + } + r := httptest.NewRequest(http.MethodGet, "/", nil) + state, _ := tt.session.LoadSession(r) + ctx := r.Context() + ctx = sessions.NewContext(ctx, state, tt.ctxError) + r = r.WithContext(ctx) + r.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + signer := &mockJWTSigner{SignError: tt.signerError} + got := a.SignRequest(signer)(fn) + got.ServeHTTP(w, r) + if status := w.Code; status != tt.wantStatus { + t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) + } + if headers := r.Header.Get(HeaderJWT); tt.wantHeaders != headers { + t.Errorf("SignRequest() headers = %v, want %v", headers, tt.wantHeaders) + } + }) + } +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 00f805950..0f6df968e 100755 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -5,15 +5,14 @@ import ( "encoding/base64" "fmt" "html/template" - stdlog "log" "net/http" - "net/http/httputil" "net/url" "time" + "github.com/gorilla/mux" "github.com/pomerium/pomerium/internal/config" "github.com/pomerium/pomerium/internal/cryptutil" - pom_httputil "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/sessions" @@ -25,15 +24,8 @@ import ( ) const ( - // HeaderJWT is the header key containing JWT signed user details. - HeaderJWT = "x-pomerium-jwt-assertion" - // HeaderUserID is the header key containing the user's id. - HeaderUserID = "x-pomerium-authenticated-user-id" - // HeaderEmail is the header key containing the user's email. - HeaderEmail = "x-pomerium-authenticated-user-email" - // HeaderGroups is the header key containing the user's groups. - HeaderGroups = "x-pomerium-authenticated-user-groups" - + // dashboardURL is the path to authenticate's sign in endpoint + dashboardURL = "/.pomerium/" // signinURL is the path to authenticate's sign in endpoint signinURL = "/.pomerium/sign_in" // signoutURL is the path to authenticate's sign out endpoint @@ -78,38 +70,29 @@ type Proxy struct { AuthorizeClient clients.Authorizer - // cipher cipher.AEAD encoder cryptutil.SecureEncoder cookieName string cookieDomain string cookieSecret []byte defaultUpstreamTimeout time.Duration refreshCooldown time.Duration - routeConfigs map[string]*routeConfig + Handler http.Handler sessionStore sessions.SessionStore signingKey string templates *template.Template } -type routeConfig struct { - mux http.Handler - policy config.Policy -} - // New takes a Proxy service from options and a validation function. // Function returns an error if options fail to validate. func New(opts config.Options) (*Proxy, error) { if err := ValidateOptions(opts); err != nil { return nil, err } - decodedCookieSecret, err := base64.StdEncoding.DecodeString(opts.CookieSecret) - if err != nil { - return nil, err - } - cipher, err := cryptutil.NewAEADCipherFromBase64(opts.CookieSecret) - if err != nil { - return nil, err - } + + // errors checked in ValidateOptions + decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret) + cipher, _ := cryptutil.NewAEADCipherFromBase64(opts.CookieSecret) + encoder := cryptutil.NewSecureJSONEncoder(cipher) if opts.CookieDomain == "" { @@ -132,7 +115,6 @@ func New(opts config.Options) (*Proxy, error) { p := &Proxy{ SharedKey: opts.SharedKey, - routeConfigs: make(map[string]*routeConfig), encoder: encoder, cookieSecret: decodedCookieSecret, cookieDomain: opts.CookieDomain, @@ -143,18 +125,18 @@ func New(opts config.Options) (*Proxy, error) { signingKey: opts.SigningKey, templates: templates.New(), } - // DeepCopy urls to avoid accidental mutation, err checked in validate func + // errors checked in ValidateOptions + p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL) p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL) + p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL}) p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL}) - p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL) - if err := p.UpdatePolicies(&opts); err != nil { return nil, err } metrics.AddPolicyCountCallback("proxy", func() int64 { - return int64(len(p.routeConfigs)) + return int64(len(opts.Policies)) }) p.AuthorizeClient, err = clients.NewAuthorizeClient("grpc", &clients.Options{ @@ -169,116 +151,6 @@ func New(opts config.Options) (*Proxy, error) { return p, err } -// UpdatePolicies updates the handlers based on the configured policies -func (p *Proxy) UpdatePolicies(opts *config.Options) error { - routeConfigs := make(map[string]*routeConfig, len(opts.Policies)) - if len(opts.Policies) == 0 { - log.Warn().Msg("proxy: configuration has no policies") - } - for _, policy := range opts.Policies { - if err := policy.Validate(); err != nil { - return fmt.Errorf("proxy: couldn't update policies %s", err) - } - proxy := NewReverseProxy(policy.Destination) - // build http transport (roundtripper) middleware chain - transport := http.DefaultTransport.(*http.Transport).Clone() - c := tripper.NewChain() - c = c.Append(metrics.HTTPMetricsRoundTripper("proxy", policy.Destination.Host)) - - var tlsClientConfig tls.Config - var isCustomClientConfig bool - if policy.TLSSkipVerify { - tlsClientConfig.InsecureSkipVerify = true - isCustomClientConfig = true - log.Warn().Str("policy", policy.String()).Msg("proxy: tls skip verify") - } - if policy.RootCAs != nil { - tlsClientConfig.RootCAs = policy.RootCAs - isCustomClientConfig = true - log.Debug().Str("policy", policy.String()).Msg("proxy: custom root ca") - } - - if policy.ClientCertificate != nil { - tlsClientConfig.Certificates = []tls.Certificate{*policy.ClientCertificate} - isCustomClientConfig = true - log.Debug().Str("policy", policy.String()).Msg("proxy: client certs enabled") - } - - if policy.TLSServerName != "" { - tlsClientConfig.ServerName = policy.TLSServerName - isCustomClientConfig = true - log.Debug().Str("policy", policy.String()).Msgf("proxy: tls hostname override to: %s", policy.TLSServerName) - } - - // We avoid setting a custom client config unless we have to as - // if TLSClientConfig is nil, the default configuration is used. - if isCustomClientConfig { - transport.TLSClientConfig = &tlsClientConfig - } - proxy.Transport = c.Then(transport) - - handler, err := p.newReverseProxyHandler(proxy, &policy) - if err != nil { - return err - } - routeConfigs[policy.Source.Host] = &routeConfig{ - mux: handler, - policy: policy, - } - } - p.routeConfigs = routeConfigs - return nil -} - -// NewReverseProxy returns a new ReverseProxy that routes URLs to the scheme, host, and -// base path provided in target. NewReverseProxy rewrites the Host header. -func NewReverseProxy(to *url.URL) *httputil.ReverseProxy { - proxy := httputil.NewSingleHostReverseProxy(to) - sublogger := log.With().Str("proxy", to.Host).Logger() - proxy.ErrorLog = stdlog.New(&log.StdLogWrapper{Logger: &sublogger}, "", 0) - director := proxy.Director - proxy.Director = func(req *http.Request) { - // Identifies the originating IP addresses of a client connecting to - // a web server through an HTTP proxy or a load balancer. - req.Header.Add("X-Forwarded-Host", req.Host) - director(req) - req.Host = to.Host - } - return proxy -} - -// each route has a custom set of middleware applied to the reverse proxy -func (p *Proxy) newReverseProxyHandler(rp http.Handler, route *config.Policy) (http.Handler, error) { - r := pom_httputil.NewRouter() - r.SkipClean(true) - r.StrictSlash(true) - r.Use(middleware.StripPomeriumCookie(p.cookieName)) - // if signing key is set, add signer to middleware - if len(p.signingKey) != 0 { - signer, err := cryptutil.NewES256Signer(p.signingKey, route.Source.Host) - if err != nil { - return nil, err - } - r.Use(middleware.SignRequest(signer, HeaderUserID, HeaderEmail, HeaderGroups, HeaderJWT)) - } - // websockets cannot use the non-hijackable timeout-handler - if !route.AllowWebsockets { - timeout := p.defaultUpstreamTimeout - if route.UpstreamTimeout != 0 { - timeout = route.UpstreamTimeout - } - timeoutMsg := fmt.Sprintf("%s timed out in %s", route.Destination.Host, timeout) - rp = http.TimeoutHandler(rp, timeout, timeoutMsg) - } - // todo(bdd) : fix cors - // if route.CORSAllowPreflight { - // r.Use(cors.Default().Handler) - // } - r.Host(route.Destination.Host) - r.PathPrefix("/").Handler(rp) - return r, nil -} - // UpdateOptions updates internal structures based on config.Options func (p *Proxy) UpdateOptions(o config.Options) error { if p == nil { @@ -287,3 +159,123 @@ func (p *Proxy) UpdateOptions(o config.Options) error { log.Info().Msg("proxy: updating options") return p.UpdatePolicies(&o) } + +// UpdatePolicies updates the H basedon the configured policies +func (p *Proxy) UpdatePolicies(opts *config.Options) error { + var err error + if len(opts.Policies) == 0 { + log.Warn().Msg("proxy: configuration has no policies") + } + r := httputil.NewRouter() + r.SkipClean(true) + r.StrictSlash(true) + r.HandleFunc("/robots.txt", p.RobotsTxt).Methods(http.MethodGet) + r = p.registerHelperHandlers(r) + + for _, policy := range opts.Policies { + if err := policy.Validate(); err != nil { + return fmt.Errorf("proxy: invalid policy %s", err) + } + r, err = p.reverseProxyHandler(r, &policy) + if err != nil { + return err + } + } + p.Handler = r + return nil +} + +func (p *Proxy) reverseProxyHandler(r *mux.Router, policy *config.Policy) (*mux.Router, error) { + // 1. Create the reverse proxy connection + proxy := httputil.NewReverseProxy(policy.Destination) + // 2. Override any custom transport settings (e.g. TLS settings, etc) + proxy.Transport = p.roundTripperFromPolicy(policy) + // 3. Create a sub-router for a given route's hostname (`httpbin.corp.example.com`) + rp := r.Host(policy.Source.Host).Subrouter() + rp.PathPrefix("/").Handler(proxy) + + // Optional: If websockets are enabled, do not set a handler request timeout + // websockets cannot use the non-hijackable timeout-handler + if !policy.AllowWebsockets { + timeout := p.defaultUpstreamTimeout + if policy.UpstreamTimeout != 0 { + timeout = policy.UpstreamTimeout + } + timeoutMsg := fmt.Sprintf("%s timed out in %s", policy.Destination.Host, timeout) + rp.Use(middleware.TimeoutHandlerFunc(timeout, timeoutMsg)) + } + + // Optional: a cors preflight check, skip access control middleware + if policy.CORSAllowPreflight { + log.Warn().Str("route", policy.String()).Msg("proxy: cors preflight enabled") + rp.Use(middleware.CorsBypass(proxy)) + } + + // Optional: if a public route, skip access control middleware + if policy.AllowPublicUnauthenticatedAccess { + log.Warn().Str("route", policy.String()).Msg("proxy: all access control disabled") + return r, nil + } + + // 4. Retrieve the user session and add it to the request context + rp.Use(sessions.RetrieveSession(p.sessionStore)) + // 5. Strip the user session cookie from the downstream request + rp.Use(middleware.StripCookie(p.cookieName)) + // 6. AuthN - Verify the user is authenticated. Set email, group, & id headers + rp.Use(p.AuthenticateSession) + // 7. AuthZ - Verify the user is authorized for route + rp.Use(p.AuthorizeSession) + // Optional: Add a signed JWT attesting to the user's id, email, and group + if len(p.signingKey) != 0 { + signer, err := cryptutil.NewES256Signer(p.signingKey, policy.Source.Host) + if err != nil { + return nil, err + } + rp.Use(p.SignRequest(signer)) + } + + return r, nil +} + +// roundTripperFromPolicy adjusts the std library's `DefaultTransport RoundTripper` +// for a given route. A route's `RoundTripper` establishes network connections +// as needed and caches them for reuse by subsequent calls. +func (p *Proxy) roundTripperFromPolicy(policy *config.Policy) http.RoundTripper { + transport := http.DefaultTransport.(*http.Transport).Clone() + c := tripper.NewChain() + c = c.Append(metrics.HTTPMetricsRoundTripper("proxy", policy.Destination.Host)) + + var tlsClientConfig tls.Config + var isCustomClientConfig bool + + if policy.TLSSkipVerify { + tlsClientConfig.InsecureSkipVerify = true + isCustomClientConfig = true + log.Warn().Str("policy", policy.String()).Msg("proxy: tls skip verify") + } + + if policy.RootCAs != nil { + tlsClientConfig.RootCAs = policy.RootCAs + isCustomClientConfig = true + log.Debug().Str("policy", policy.String()).Msg("proxy: custom root ca") + } + + if policy.ClientCertificate != nil { + tlsClientConfig.Certificates = []tls.Certificate{*policy.ClientCertificate} + isCustomClientConfig = true + log.Debug().Str("policy", policy.String()).Msg("proxy: client certs enabled") + } + + if policy.TLSServerName != "" { + tlsClientConfig.ServerName = policy.TLSServerName + isCustomClientConfig = true + log.Debug().Str("policy", policy.String()).Msgf("proxy: tls override hostname: %s", policy.TLSServerName) + } + + // We avoid setting a custom client config unless we have to as + // if TLSClientConfig is nil, the default configuration is used. + if isCustomClientConfig { + transport.TLSClientConfig = &tlsClientConfig + } + return c.Then(transport) +} diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 3dffb2179..100a4e118 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -1,8 +1,6 @@ package proxy // import "github.com/pomerium/pomerium/proxy" import ( - "io/ioutil" - "net" "net/http" "net/http/httptest" "net/url" @@ -21,72 +19,6 @@ func newTestOptions(t *testing.T) *config.Options { return opts } -func TestNewReverseProxy(t *testing.T) { - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - hostname, _, _ := net.SplitHostPort(r.Host) - w.Write([]byte(hostname)) - })) - defer backend.Close() - - backendURL, _ := url.Parse(backend.URL) - backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host) - backendHost := net.JoinHostPort(backendHostname, backendPort) - proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") - - proxyHandler := NewReverseProxy(proxyURL) - frontend := httptest.NewServer(proxyHandler) - defer frontend.Close() - - getReq, _ := http.NewRequest("GET", frontend.URL, nil) - res, _ := http.DefaultClient.Do(getReq) - bodyBytes, _ := ioutil.ReadAll(res.Body) - if g, e := string(bodyBytes), backendHostname; g != e { - t.Errorf("got body %q; expected %q", g, e) - } -} - -func TestNewReverseProxyHandler(t *testing.T) { - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - hostname, _, _ := net.SplitHostPort(r.Host) - w.Write([]byte(hostname)) - })) - defer backend.Close() - - backendURL, _ := url.Parse(backend.URL) - backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host) - backendHost := net.JoinHostPort(backendHostname, backendPort) - proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") - proxyHandler := NewReverseProxy(proxyURL) - opts := newTestOptions(t) - opts.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSU0zbXBaSVdYQ1g5eUVneFU2czU3Q2J0YlVOREJTQ0VBdFFGNWZVV0hwY1FvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFaFBRditMQUNQVk5tQlRLMHhTVHpicEVQa1JyazFlVXQxQk9hMzJTRWZVUHpOaTRJV2VaLwpLS0lUdDJxMUlxcFYyS01TYlZEeXI5aWp2L1hoOThpeUV3PT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo=" - testPolicy := config.Policy{From: "https://corp.example.com", To: "https://example.com", UpstreamTimeout: 1 * time.Second} - if err := testPolicy.Validate(); err != nil { - t.Fatal(err) - } - p, err := New(*opts) - if err != nil { - t.Fatal(err) - } - handle, err := p.newReverseProxyHandler(proxyHandler, &testPolicy) - if err != nil { - t.Fatal(err) - } - - frontend := httptest.NewServer(handle) - - defer frontend.Close() - - getReq, _ := http.NewRequest("GET", frontend.URL, nil) - - res, _ := http.DefaultClient.Do(getReq) - bodyBytes, _ := ioutil.ReadAll(res.Body) - if g, e := string(bodyBytes), backendHostname; g != e { - t.Errorf("got body %q; expected %q", g, e) - } -} - func testOptions(t *testing.T) config.Options { authenticateService, _ := url.Parse("https://authenticate.corp.beyondperimeter.com") authorizeService, _ := url.Parse("https://authorize.corp.beyondperimeter.com") @@ -106,61 +38,9 @@ func testOptions(t *testing.T) config.Options { return *opts } -func testOptionsTestServer(t *testing.T, uri string) config.Options { - authenticateService, _ := url.Parse("https://authenticate.corp.beyondperimeter.com") - authorizeService, _ := url.Parse("https://authorize.corp.beyondperimeter.com") - testPolicy := config.Policy{ - From: "https://httpbin.corp.example", - To: uri, - } - if err := testPolicy.Validate(); err != nil { - t.Fatal(err) - } - opts := newTestOptions(t) - opts.Policies = []config.Policy{testPolicy} - opts.AuthenticateURL = authenticateService - opts.AuthorizeURL = authorizeService - opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=" - opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=" - opts.CookieName = "pomerium" - return *opts -} - -func testOptionsWithCORS(t *testing.T, uri string) config.Options { - testPolicy := config.Policy{ - From: "https://httpbin.corp.example", - To: uri, - CORSAllowPreflight: true, - } - if err := testPolicy.Validate(); err != nil { - t.Fatal(err) - } - opts := testOptionsTestServer(t, uri) - opts.Policies = []config.Policy{testPolicy} - return opts -} - -func testOptionsWithPublicAccess(t *testing.T, uri string) config.Options { - testPolicy := config.Policy{ - From: "https://httpbin.corp.example", - To: uri, - AllowPublicUnauthenticatedAccess: true, - } - if err := testPolicy.Validate(); err != nil { - t.Fatal(err) - } - opts := testOptions(t) - opts.Policies = []config.Policy{testPolicy} - return opts -} - -func testOptionsWithEmptyPolicies(t *testing.T, uri string) config.Options { - opts := testOptionsTestServer(t, uri) - opts.Policies = []config.Policy{} - return opts -} - func TestOptions_Validate(t *testing.T) { + t.Parallel() + good := testOptions(t) badAuthURL := testOptions(t) badAuthURL.AuthenticateURL = nil @@ -215,22 +95,33 @@ func TestOptions_Validate(t *testing.T) { } func TestNew(t *testing.T) { + t.Parallel() + good := testOptions(t) shortCookieLength := testOptions(t) shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg==" badRoutedProxy := testOptions(t) badRoutedProxy.SigningKey = "YmFkIGtleQo=" + badCookie := testOptions(t) + badCookie.CookieName = "" + badPolicyURL := config.Policy{To: "http://", From: "http://bar.example"} + badNewPolicy := testOptions(t) + badNewPolicy.Policies = []config.Policy{ + badPolicyURL, + } + tests := []struct { name string opts config.Options wantProxy bool - numRoutes int wantErr bool }{ - {"good", good, true, 1, false}, - {"empty options", config.Options{}, false, 0, true}, - {"short secret/validate sanity check", shortCookieLength, false, 0, true}, - {"invalid ec key, valid base64 though", badRoutedProxy, false, 0, true}, + {"good", good, true, false}, + {"empty options", config.Options{}, false, true}, + {"short secret/validate sanity check", shortCookieLength, false, true}, + {"invalid ec key, valid base64 though", badRoutedProxy, false, true}, + {"invalid cookie name, empty", badCookie, false, true}, + {"bad policy, bad policy url", badNewPolicy, false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -242,21 +133,17 @@ func TestNew(t *testing.T) { if got == nil && tt.wantProxy == true { t.Errorf("New() expected valid proxy struct") } - if got != nil && len(got.routeConfigs) != tt.numRoutes { - t.Errorf("New() = num routeConfigs \n%+v, want \n%+v \nfrom %+v", got, tt.numRoutes, tt.opts) - } }) } } func Test_UpdateOptions(t *testing.T) { + t.Parallel() good := testOptions(t) newPolicy := config.Policy{To: "http://foo.example", From: "http://bar.example"} newPolicies := testOptions(t) - newPolicies.Policies = []config.Policy{ - newPolicy, - } + newPolicies.Policies = []config.Policy{newPolicy} err := newPolicy.Validate() if err != nil { t.Fatal(err) @@ -268,26 +155,33 @@ func Test_UpdateOptions(t *testing.T) { } disableTLSPolicy := config.Policy{To: "http://foo.example", From: "http://bar.example", TLSSkipVerify: true} disableTLSPolicies := testOptions(t) - disableTLSPolicies.Policies = []config.Policy{ - disableTLSPolicy, - } + disableTLSPolicies.Policies = []config.Policy{disableTLSPolicy} + customCAPolicies := testOptions(t) - customCAPolicies.Policies = []config.Policy{ - {To: "http://foo.example", From: "http://bar.example", TLSCustomCA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURlVENDQW1HZ0F3SUJBZ0lKQUszMmhoR0JIcmFtTUEwR0NTcUdTSWIzRFFFQkN3VUFNR0l4Q3pBSkJnTlYKQkFZVEFsVlRNUk13RVFZRFZRUUlEQXBEWVd4cFptOXlibWxoTVJZd0ZBWURWUVFIREExVFlXNGdSbkpoYm1OcApjMk52TVE4d0RRWURWUVFLREFaQ1lXUlRVMHd4RlRBVEJnTlZCQU1NRENvdVltRmtjM05zTG1OdmJUQWVGdzB4Ck9UQTJNVEl4TlRNeE5UbGFGdzB5TVRBMk1URXhOVE14TlRsYU1HSXhDekFKQmdOVkJBWVRBbFZUTVJNd0VRWUQKVlFRSURBcERZV3hwWm05eWJtbGhNUll3RkFZRFZRUUhEQTFUWVc0Z1JuSmhibU5wYzJOdk1ROHdEUVlEVlFRSwpEQVpDWVdSVFUwd3hGVEFUQmdOVkJBTU1EQ291WW1Ga2MzTnNMbU52YlRDQ0FTSXdEUVlKS29aSWh2Y05BUUVCCkJRQURnZ0VQQURDQ0FRb0NnZ0VCQU1JRTdQaU03Z1RDczloUTFYQll6Sk1ZNjF5b2FFbXdJclg1bFo2eEt5eDIKUG16QVMyQk1UT3F5dE1BUGdMYXcrWExKaGdMNVhFRmRFeXQvY2NSTHZPbVVMbEEzcG1jY1lZejJRVUxGUnRNVwpoeWVmZE9zS25SRlNKaUZ6YklSTWVWWGswV3ZvQmoxSUZWS3RzeWpicXY5dS8yQ1ZTbmRyT2ZFazBURzIzVTNBCnhQeFR1VzFDcmJWOC9xNzFGZEl6U09jaWNjZkNGSHBzS09vM1N0L3FiTFZ5dEg1YW9oYmNhYkZYUk5zS0VxdmUKd3c5SGRGeEJJdUdhK1J1VDVxMGlCaWt1c2JwSkhBd25ucVA3aS9kQWNnQ3NrZ2paakZlRVU0RUZ5K2IrYTFTWQpRQ2VGeHhDN2MzRHZhUmhCQjBWVmZQbGtQejBzdzZsODY1TWFUSWJSeW9VQ0F3RUFBYU15TURBd0NRWURWUjBUCkJBSXdBREFqQmdOVkhSRUVIREFhZ2d3cUxtSmhaSE56YkM1amIyMkNDbUpoWkhOemJDNWpiMjB3RFFZSktvWkkKaHZjTkFRRUxCUUFEZ2dFQkFJaTV1OXc4bWdUNnBwQ2M3eHNHK0E5ZkkzVzR6K3FTS2FwaHI1bHM3MEdCS2JpWQpZTEpVWVpoUGZXcGgxcXRra1UwTEhGUG04M1ZhNTJlSUhyalhUMFZlNEt0TzFuMElBZkl0RmFXNjJDSmdoR1luCmp6dzByeXpnQzRQeUZwTk1uTnRCcm9QdS9iUGdXaU1nTE9OcEVaaGlneDRROHdmMVkvVTlzK3pDQ3hvSmxhS1IKTVhidVE4N1g3bS85VlJueHhvNk56NVpmN09USFRwTk9JNlZqYTBCeGJtSUFVNnlyaXc5VXJnaWJYZk9qM2o2bgpNVExCdWdVVklCMGJCYWFzSnNBTUsrdzRMQU52YXBlWjBET1NuT1I0S0syNEowT3lvRjVmSG1wNTllTTE3SW9GClFxQmh6cG1RVWd1bmVjRVc4QlRxck5wRzc5UjF1K1YrNHd3Y2tQYz0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="}, - } + customCAPolicies.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", TLSCustomCA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURlVENDQW1HZ0F3SUJBZ0lKQUszMmhoR0JIcmFtTUEwR0NTcUdTSWIzRFFFQkN3VUFNR0l4Q3pBSkJnTlYKQkFZVEFsVlRNUk13RVFZRFZRUUlEQXBEWVd4cFptOXlibWxoTVJZd0ZBWURWUVFIREExVFlXNGdSbkpoYm1OcApjMk52TVE4d0RRWURWUVFLREFaQ1lXUlRVMHd4RlRBVEJnTlZCQU1NRENvdVltRmtjM05zTG1OdmJUQWVGdzB4Ck9UQTJNVEl4TlRNeE5UbGFGdzB5TVRBMk1URXhOVE14TlRsYU1HSXhDekFKQmdOVkJBWVRBbFZUTVJNd0VRWUQKVlFRSURBcERZV3hwWm05eWJtbGhNUll3RkFZRFZRUUhEQTFUWVc0Z1JuSmhibU5wYzJOdk1ROHdEUVlEVlFRSwpEQVpDWVdSVFUwd3hGVEFUQmdOVkJBTU1EQ291WW1Ga2MzTnNMbU52YlRDQ0FTSXdEUVlKS29aSWh2Y05BUUVCCkJRQURnZ0VQQURDQ0FRb0NnZ0VCQU1JRTdQaU03Z1RDczloUTFYQll6Sk1ZNjF5b2FFbXdJclg1bFo2eEt5eDIKUG16QVMyQk1UT3F5dE1BUGdMYXcrWExKaGdMNVhFRmRFeXQvY2NSTHZPbVVMbEEzcG1jY1lZejJRVUxGUnRNVwpoeWVmZE9zS25SRlNKaUZ6YklSTWVWWGswV3ZvQmoxSUZWS3RzeWpicXY5dS8yQ1ZTbmRyT2ZFazBURzIzVTNBCnhQeFR1VzFDcmJWOC9xNzFGZEl6U09jaWNjZkNGSHBzS09vM1N0L3FiTFZ5dEg1YW9oYmNhYkZYUk5zS0VxdmUKd3c5SGRGeEJJdUdhK1J1VDVxMGlCaWt1c2JwSkhBd25ucVA3aS9kQWNnQ3NrZ2paakZlRVU0RUZ5K2IrYTFTWQpRQ2VGeHhDN2MzRHZhUmhCQjBWVmZQbGtQejBzdzZsODY1TWFUSWJSeW9VQ0F3RUFBYU15TURBd0NRWURWUjBUCkJBSXdBREFqQmdOVkhSRUVIREFhZ2d3cUxtSmhaSE56YkM1amIyMkNDbUpoWkhOemJDNWpiMjB3RFFZSktvWkkKaHZjTkFRRUxCUUFEZ2dFQkFJaTV1OXc4bWdUNnBwQ2M3eHNHK0E5ZkkzVzR6K3FTS2FwaHI1bHM3MEdCS2JpWQpZTEpVWVpoUGZXcGgxcXRra1UwTEhGUG04M1ZhNTJlSUhyalhUMFZlNEt0TzFuMElBZkl0RmFXNjJDSmdoR1luCmp6dzByeXpnQzRQeUZwTk1uTnRCcm9QdS9iUGdXaU1nTE9OcEVaaGlneDRROHdmMVkvVTlzK3pDQ3hvSmxhS1IKTVhidVE4N1g3bS85VlJueHhvNk56NVpmN09USFRwTk9JNlZqYTBCeGJtSUFVNnlyaXc5VXJnaWJYZk9qM2o2bgpNVExCdWdVVklCMGJCYWFzSnNBTUsrdzRMQU52YXBlWjBET1NuT1I0S0syNEowT3lvRjVmSG1wNTllTTE3SW9GClFxQmh6cG1RVWd1bmVjRVc4QlRxck5wRzc5UjF1K1YrNHd3Y2tQYz0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="}} badCustomCAPolicies := testOptions(t) - badCustomCAPolicies.Policies = []config.Policy{ - {To: "http://foo.example", From: "http://bar.example", TLSCustomCA: "=@@"}, - } + badCustomCAPolicies.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", TLSCustomCA: "=@@"}} + goodClientCertPolicies := testOptions(t) - goodClientCertPolicies.Policies = []config.Policy{ - {To: "http://foo.example", From: "http://bar.example", - TLSClientKey: "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=", TLSClientCert: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="}, - } + goodClientCertPolicies.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", TLSClientKey: "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=", TLSClientCert: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="}} goodClientCertPolicies.Validate() + customServerName := testOptions(t) customServerName.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", TLSServerName: "test"}} + + emptyPolicies := testOptions(t) + emptyPolicies.Policies = nil + + allowWebSockets := testOptions(t) + allowWebSockets.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", AllowWebsockets: true}} + customTimeout := testOptions(t) + customTimeout.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", UpstreamTimeout: 10 * time.Second}} + corsPreflight := testOptions(t) + corsPreflight.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", CORSAllowPreflight: true}} + disableAuth := testOptions(t) + disableAuth.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", AllowPublicUnauthenticatedAccess: true}} + tests := []struct { name string originalOptions config.Options @@ -301,12 +195,18 @@ func Test_UpdateOptions(t *testing.T) { {"changed", good, newPolicies, "", "https://bar.example", false, true}, {"changed and missing", good, newPolicies, "", "https://corp.example.example", false, false}, {"bad signing key", good, newPolicies, "^bad base 64", "https://corp.example.example", true, false}, + {"good signing key", good, newPolicies, "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSU0zbXBaSVdYQ1g5eUVneFU2czU3Q2J0YlVOREJTQ0VBdFFGNWZVV0hwY1FvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFaFBRditMQUNQVk5tQlRLMHhTVHpicEVQa1JyazFlVXQxQk9hMzJTRWZVUHpOaTRJV2VaLwpLS0lUdDJxMUlxcFYyS01TYlZEeXI5aWp2L1hoOThpeUV3PT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQ==", "https://corp.example.example", false, true}, {"bad change bad policy url", good, badNewPolicy, "", "https://bar.example", true, false}, {"disable tls verification", good, disableTLSPolicies, "", "https://bar.example", false, true}, {"custom root ca", good, customCAPolicies, "", "https://bar.example", false, true}, {"bad custom root ca base64", good, badCustomCAPolicies, "", "https://bar.example", true, false}, {"good client certs", good, goodClientCertPolicies, "", "https://bar.example", false, true}, {"custom server name", customServerName, customServerName, "", "https://bar.example", false, true}, + {"good no policies to start", emptyPolicies, good, "", "https://corp.example.example", false, true}, + {"allow websockets", good, allowWebSockets, "", "https://corp.example.example", false, true}, + {"no websockets, custom timeout", good, customTimeout, "", "https://corp.example.example", false, true}, + {"enable cors preflight", good, corsPreflight, "", "https://corp.example.example", false, true}, + {"disable auth", good, disableAuth, "", "https://corp.example.example", false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -324,9 +224,10 @@ func Test_UpdateOptions(t *testing.T) { // This is only safe if we actually can load policies if err == nil { - req := httptest.NewRequest("GET", tt.host, nil) - _, ok := p.router(req) - if ok != tt.wantRoute { + r := httptest.NewRequest("GET", tt.host, nil) + w := httptest.NewRecorder() + p.Handler.ServeHTTP(w, r) + if tt.wantRoute && w.Code != http.StatusNotFound { t.Errorf("Failed to find route handler") return }