mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 23:57:34 +02:00
proxy: use middleware to manage request flow
proxy: remove duplicate error handling in New proxy: remove routeConfigs in favor of using gorilla/mux proxy: add proxy specific middleware proxy: no longer need to use middleware / handler to check if valid route. Can use build in 404 mux. internal/middleware: add cors bypass middleware Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
70c5553d3c
commit
782ffbeb3e
17 changed files with 834 additions and 839 deletions
|
@ -112,7 +112,7 @@ func newProxyService(opt config.Options, r *mux.Router) (*proxy.Proxy, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.PathPrefix("/").Handler(service.Handler())
|
||||
r.PathPrefix("/").Handler(service.Handler)
|
||||
return service, nil
|
||||
}
|
||||
|
||||
|
|
4
go.sum
4
go.sum
|
@ -70,6 +70,7 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
|
|||
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
||||
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
||||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
||||
github.com/gorilla/mux v1.6.2 h1:Pgr17XVTNXAk3q/r4CpKzC5xBM/qW1uVLV+IhRZpIIk=
|
||||
github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
|
||||
|
@ -124,8 +125,6 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
|||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30 h1:jggCv6hZvcxjGa3gqkYY2EUuOkITI9Znugz/f3QJfRQ=
|
||||
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0=
|
||||
github.com/pomerium/csrf v1.6.2-0.20190918035251-f3318380bad3 h1:FmzFXnCAepHZwl6QPhTFqBHcbcGevdiEQjutK+M5bj4=
|
||||
github.com/pomerium/csrf v1.6.2-0.20190918035251-f3318380bad3/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0=
|
||||
github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o=
|
||||
|
@ -198,7 +197,6 @@ golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmV
|
|||
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8 h1:1wopBVtVdWnn03fZelqdXTqk7U7zPQCb+T4rbU9ZEoU=
|
||||
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 h1:0hQKqeLdqlt5iIwVOBErRisrHJAN57yOiPRQItI20fU=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||
|
|
31
internal/httputil/proxy.go
Normal file
31
internal/httputil/proxy.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
||||
|
||||
import (
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// HeaderForwardHost is the header key the identifies the originating
|
||||
// IP addresses of a client connecting to a web server through an HTTP proxy
|
||||
// or a load balancer.
|
||||
const HeaderForwardHost = "X-Forwarded-Host"
|
||||
|
||||
// NewReverseProxy returns a new ReverseProxy that routes
|
||||
// URLs to the scheme, host, and base path provided in target,
|
||||
// rewrites the Host header, and sets `X-Forwarded-Host`.
|
||||
func NewReverseProxy(target *url.URL) *httputil.ReverseProxy {
|
||||
reverseProxy := httputil.NewSingleHostReverseProxy(target)
|
||||
sublogger := log.With().Str("reverse-proxy", target.Host).Logger()
|
||||
reverseProxy.ErrorLog = stdlog.New(&log.StdLogWrapper{Logger: &sublogger}, "", 0)
|
||||
director := reverseProxy.Director
|
||||
reverseProxy.Director = func(req *http.Request) {
|
||||
req.Header.Add(HeaderForwardHost, req.Host)
|
||||
director(req)
|
||||
req.Host = target.Host
|
||||
}
|
||||
return reverseProxy
|
||||
}
|
35
internal/httputil/proxy_test.go
Normal file
35
internal/httputil/proxy_test.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewReverseProxy(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
hostname, _, _ := net.SplitHostPort(r.Host)
|
||||
w.Write([]byte(hostname))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
backendURL, _ := url.Parse(backend.URL)
|
||||
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
|
||||
backendHost := net.JoinHostPort(backendHostname, backendPort)
|
||||
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
|
||||
|
||||
proxyHandler := NewReverseProxy(proxyURL)
|
||||
frontend := httptest.NewServer(proxyHandler)
|
||||
defer frontend.Close()
|
||||
|
||||
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
|
||||
res, _ := http.DefaultClient.Do(getReq)
|
||||
bodyBytes, _ := ioutil.ReadAll(res.Body)
|
||||
if g, e := string(bodyBytes), backendHostname; g != e {
|
||||
t.Errorf("got body %q; expected %q", g, e)
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package httputil
|
||||
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
|
26
internal/middleware/cors.go
Normal file
26
internal/middleware/cors.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
)
|
||||
|
||||
// CorsBypass is middleware that takes a target handler as a paramater,
|
||||
// if the request is determined to be a CORS preflight request, that handler
|
||||
// is called instead of the normal handler chain.
|
||||
func CorsBypass(target http.Handler) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "middleware.CorsBypass")
|
||||
defer span.End()
|
||||
if r.Method == http.MethodOptions &&
|
||||
r.Header.Get("Access-Control-Request-Method") != "" &&
|
||||
r.Header.Get("Origin") != "" {
|
||||
target.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
57
internal/middleware/cors_test.go
Normal file
57
internal/middleware/cors_test.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func someOtherMiddleware(s string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Some-Other-Middleware", s)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestCorsBypass(t *testing.T) {
|
||||
t.Parallel()
|
||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
header http.Header
|
||||
wantStatus int
|
||||
wantHeader string
|
||||
}{
|
||||
{"good", http.MethodOptions, http.Header{"Access-Control-Request-Method": []string{"GET"}, "Origin": []string{"localhost"}}, 200, ""},
|
||||
{"invalid cors - non options request", http.MethodGet, http.Header{"Access-Control-Request-Method": []string{"GET"}, "Origin": []string{"localhost"}}, 200, "BAD"},
|
||||
{"invalid cors - Origin not set", http.MethodOptions, http.Header{"Access-Control-Request-Method": []string{"GET"}, "Origin": []string{""}}, 200, "BAD"},
|
||||
{"invalid cors - Access-Control-Request-Method not set", http.MethodOptions, http.Header{"Access-Control-Request-Method": []string{""}, "Origin": []string{"*"}}, 200, "BAD"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &http.Request{
|
||||
Method: tt.method,
|
||||
Header: tt.header,
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
target := fn
|
||||
got := CorsBypass(target)(someOtherMiddleware("BAD")(target))
|
||||
got.ServeHTTP(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("TestCorsBypass() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||
}
|
||||
if header := w.Header().Get("Some-Other-Middleware"); tt.wantHeader != header {
|
||||
t.Errorf("TestCorsBypass() header = %v, want %v", header, tt.wantHeader)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -6,6 +6,7 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
|
@ -121,22 +122,6 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler
|
|||
}
|
||||
}
|
||||
|
||||
// ValidateHost ensures that each request's host is valid
|
||||
func ValidateHost(validHost func(host string) bool) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateHost")
|
||||
defer span.End()
|
||||
|
||||
if !validHost(r.Host) {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusNotFound, nil))
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Healthcheck endpoint middleware useful to setting up a path like
|
||||
// `/ping` that load balancers or uptime testing external services
|
||||
// can make a request before hitting any routes. It's also convenient
|
||||
|
@ -185,3 +170,33 @@ func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
|||
}
|
||||
return cryptutil.CheckHMAC([]byte(fmt.Sprint(redirectURI, timestamp)), requestSig, secret)
|
||||
}
|
||||
|
||||
// StripCookie strips the cookie from the downstram request.
|
||||
func StripCookie(cookieName string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "middleware.StripCookie")
|
||||
defer span.End()
|
||||
|
||||
headers := make([]string, 0, len(r.Cookies()))
|
||||
for _, cookie := range r.Cookies() {
|
||||
if !strings.HasPrefix(cookie.Name, cookieName) {
|
||||
headers = append(headers, cookie.String())
|
||||
}
|
||||
}
|
||||
r.Header.Set("Cookie", strings.Join(headers, ";"))
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TimeoutHandlerFunc wraps http.TimeoutHandler
|
||||
func TimeoutHandlerFunc(timeout time.Duration, timeoutError string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "middleware.TimeoutHandlerFunc")
|
||||
defer span.End()
|
||||
http.TimeoutHandler(next, timeout, timeoutError).ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package middleware
|
||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
@ -276,60 +276,80 @@ func TestHealthCheck(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Redirect to a fixed URL
|
||||
type handlerHelper struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (rh *handlerHelper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(rh.msg))
|
||||
}
|
||||
|
||||
func handlerHelp(msg string) http.Handler {
|
||||
return &handlerHelper{msg}
|
||||
}
|
||||
func TestValidateHost(t *testing.T) {
|
||||
validHostFunc := func(host string) bool {
|
||||
return host == "google.com"
|
||||
}
|
||||
|
||||
validHostHandler := handlerHelp("google")
|
||||
|
||||
func TestStripCookie(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isValidHost func(string) bool
|
||||
validHostHandler http.Handler
|
||||
clientPath string
|
||||
expected []byte
|
||||
status int
|
||||
name string
|
||||
pomeriumCookie string
|
||||
otherCookies []string
|
||||
}{
|
||||
{"good", validHostFunc, validHostHandler, "google.com", []byte("google"), 200},
|
||||
{"no route", validHostFunc, validHostHandler, "googles.com", []byte("google"), 404},
|
||||
{"good", "pomerium", []string{"x", "y", "z"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, tt.clientPath, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for _, cookie := range r.Cookies() {
|
||||
if cookie.Name == tt.pomeriumCookie {
|
||||
t.Errorf("cookie not stripped %s", r.Cookies())
|
||||
}
|
||||
}
|
||||
})
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
var testHandler http.Handler
|
||||
if tt.isValidHost(tt.clientPath) {
|
||||
tt.validHostHandler.ServeHTTP(rr, req)
|
||||
testHandler = tt.validHostHandler
|
||||
} else {
|
||||
testHandler = handlerHelp("ok")
|
||||
for _, cn := range tt.otherCookies {
|
||||
http.SetCookie(rr, &http.Cookie{
|
||||
Name: cn,
|
||||
Value: "some other cookie",
|
||||
})
|
||||
}
|
||||
handler := ValidateHost(tt.isValidHost)(testHandler)
|
||||
|
||||
http.SetCookie(rr, &http.Cookie{
|
||||
Name: tt.pomeriumCookie,
|
||||
Value: "pomerium cookie!",
|
||||
})
|
||||
|
||||
http.SetCookie(rr, &http.Cookie{
|
||||
Name: tt.pomeriumCookie + "_csrf",
|
||||
Value: "pomerium csrf cookie!",
|
||||
})
|
||||
req := &http.Request{Header: http.Header{"Cookie": rr.Header()["Set-Cookie"]}}
|
||||
|
||||
handler := StripCookie(tt.pomeriumCookie)(testHandler)
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tt.status {
|
||||
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
|
||||
t.Errorf("%s", rr.Body)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutHandlerFunc(t *testing.T) {
|
||||
t.Parallel()
|
||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
timeout time.Duration
|
||||
timeoutError string
|
||||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{"good", 1 * time.Second, "good timed out!?", http.StatusOK, http.StatusText(http.StatusOK)},
|
||||
{"timeout!", 1 * time.Nanosecond, "ruh roh", http.StatusServiceUnavailable, "ruh roh"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
got := TimeoutHandlerFunc(tt.timeout, tt.timeoutError)(fn)
|
||||
got.ServeHTTP(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||
}
|
||||
if body := w.Body.String(); tt.wantBody != body {
|
||||
t.Errorf("SignRequest() body = %v, want %v", body, tt.wantBody)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,48 +0,0 @@
|
|||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
)
|
||||
|
||||
func SignRequest(signer cryptutil.JWTSigner, id, email, groups, header string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "middleware.SignRequest")
|
||||
defer span.End()
|
||||
jwt, err := signer.SignJWT(
|
||||
r.Header.Get(id),
|
||||
r.Header.Get(email),
|
||||
r.Header.Get(groups))
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("internal/middleware: failed signing request")
|
||||
} else {
|
||||
r.Header.Set(header, jwt)
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// StripPomeriumCookie ensures that every response includes some basic security headers
|
||||
func StripPomeriumCookie(cookieName string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "middleware.StripPomeriumCookie")
|
||||
defer span.End()
|
||||
|
||||
headers := make([]string, 0, len(r.Cookies()))
|
||||
for _, cookie := range r.Cookies() {
|
||||
if !strings.HasPrefix(cookie.Name, cookieName) {
|
||||
headers = append(headers, cookie.String())
|
||||
}
|
||||
}
|
||||
r.Header.Set("Cookie", strings.Join(headers, ";"))
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,100 +0,0 @@
|
|||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
const exampleKey = `-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIM3mpZIWXCX9yEgxU6s57CbtbUNDBSCEAtQF5fUWHpcQoAoGCCqGSM49
|
||||
AwEHoUQDQgAEhPQv+LACPVNmBTK0xSTzbpEPkRrk1eUt1BOa32SEfUPzNi4IWeZ/
|
||||
KKITt2q1IqpV2KMSbVDyr9ijv/Xh98iyEw==
|
||||
-----END EC PRIVATE KEY-----
|
||||
`
|
||||
|
||||
func TestSignRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
id string
|
||||
email string
|
||||
groups string
|
||||
header string
|
||||
}{
|
||||
{"good", "id", "email", "group", "Jwt"},
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.Header.Set(fmt.Sprintf("%s-header", tt.id), tt.id)
|
||||
r.Header.Set(fmt.Sprintf("%s-header", tt.email), tt.email)
|
||||
r.Header.Set(fmt.Sprintf("%s-header", tt.groups), tt.groups)
|
||||
|
||||
})
|
||||
rr := httptest.NewRecorder()
|
||||
signer, err := cryptutil.NewES256Signer(base64.StdEncoding.EncodeToString([]byte(exampleKey)), "audience")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
handler := SignRequest(signer, tt.id, tt.email, tt.groups, tt.header)(testHandler)
|
||||
handler.ServeHTTP(rr, req)
|
||||
jwt := req.Header["Jwt"]
|
||||
if len(jwt) != 1 {
|
||||
t.Errorf("no jwt found %v", req.Header)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripPomeriumCookie(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pomeriumCookie string
|
||||
otherCookies []string
|
||||
}{
|
||||
{"good", "pomerium", []string{"x", "y", "z"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for _, cookie := range r.Cookies() {
|
||||
if cookie.Name == tt.pomeriumCookie {
|
||||
t.Errorf("cookie not stripped %s", r.Cookies())
|
||||
}
|
||||
}
|
||||
})
|
||||
rr := httptest.NewRecorder()
|
||||
for _, cn := range tt.otherCookies {
|
||||
http.SetCookie(rr, &http.Cookie{
|
||||
Name: cn,
|
||||
Value: "some other cookie",
|
||||
})
|
||||
}
|
||||
|
||||
http.SetCookie(rr, &http.Cookie{
|
||||
Name: tt.pomeriumCookie,
|
||||
Value: "pomerium cookie!",
|
||||
})
|
||||
|
||||
http.SetCookie(rr, &http.Cookie{
|
||||
Name: tt.pomeriumCookie + "_csrf",
|
||||
Value: "pomerium csrf cookie!",
|
||||
})
|
||||
req := &http.Request{Header: http.Header{"Cookie": rr.Header()["Set-Cookie"]}}
|
||||
|
||||
handler := StripPomeriumCookie(tt.pomeriumCookie)(testHandler)
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
|
@ -7,67 +7,34 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/pomerium/csrf"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/config"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/templates"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
// Handler returns the proxy service's ServeMux
|
||||
func (p *Proxy) Handler() http.Handler {
|
||||
r := httputil.NewRouter()
|
||||
r.SkipClean(true)
|
||||
r.StrictSlash(true)
|
||||
r.Use(middleware.ValidateHost(func(host string) bool {
|
||||
_, ok := p.routeConfigs[host]
|
||||
return ok
|
||||
}))
|
||||
r.HandleFunc("/robots.txt", p.RobotsTxt)
|
||||
// requires authN not authZ
|
||||
r.Use(sessions.RetrieveSession(p.sessionStore))
|
||||
r.Use(p.VerifySession)
|
||||
// Proxy service endpoints
|
||||
v := r.PathPrefix("/.pomerium").Subrouter()
|
||||
v.Use(csrf.Protect(
|
||||
// registerHelperHandlers returns the proxy service's ServeMux
|
||||
func (p *Proxy) registerHelperHandlers(r *mux.Router) *mux.Router {
|
||||
h := r.PathPrefix(dashboardURL).Subrouter()
|
||||
h.Use(sessions.RetrieveSession(p.sessionStore))
|
||||
h.Use(p.AuthenticateSession)
|
||||
h.Use(csrf.Protect(
|
||||
p.cookieSecret,
|
||||
csrf.Path("/"),
|
||||
csrf.Domain(p.cookieDomain),
|
||||
csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieName)),
|
||||
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
|
||||
))
|
||||
v.HandleFunc("/", p.UserDashboard).Methods(http.MethodGet)
|
||||
v.HandleFunc("/impersonate", p.Impersonate).Methods(http.MethodPost)
|
||||
v.HandleFunc("/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost)
|
||||
v.HandleFunc("/refresh", p.ForceRefresh).Methods(http.MethodPost)
|
||||
|
||||
r.PathPrefix("/").HandlerFunc(p.Proxy)
|
||||
h.HandleFunc("/", p.UserDashboard).Methods(http.MethodGet)
|
||||
h.HandleFunc("/impersonate", p.Impersonate).Methods(http.MethodPost)
|
||||
h.HandleFunc("/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost)
|
||||
h.HandleFunc("/refresh", p.ForceRefresh).Methods(http.MethodPost)
|
||||
return r
|
||||
}
|
||||
|
||||
// VerifySession is the middleware used to enforce a valid authentication
|
||||
// session state is attached to the users's request context.
|
||||
func (p *Proxy) VerifySession(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to session state error")
|
||||
p.authenticate(w, r)
|
||||
return
|
||||
}
|
||||
if err := state.Valid(); err != nil {
|
||||
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to invalid session")
|
||||
p.authenticate(w, r)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// RobotsTxt sets the User-Agent header in the response to be "Disallow"
|
||||
func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
|
@ -88,76 +55,6 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
|||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// Authenticate begins the authenticate flow, encrypting the redirect url
|
||||
// in a request to the provider's sign in endpoint.
|
||||
func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request) {
|
||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
|
||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// shouldSkipAuthentication contains conditions for skipping authentication.
|
||||
// Conditions should be few in number and have strong justifications.
|
||||
func (p *Proxy) shouldSkipAuthentication(r *http.Request) bool {
|
||||
policy, policyExists := p.policy(r)
|
||||
|
||||
if isCORSPreflight(r) && policyExists && policy.CORSAllowPreflight {
|
||||
log.FromRequest(r).Debug().Msg("proxy: skipping authentication for valid CORS preflight request")
|
||||
return true
|
||||
}
|
||||
|
||||
if policyExists && policy.AllowPublicUnauthenticatedAccess {
|
||||
log.FromRequest(r).Debug().Msg("proxy: skipping authentication for public route")
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isCORSPreflight inspects the request to see if this is a valid CORS preflight request.
|
||||
// These checks are not exhaustive, because the proxied server should be verifying it as well.
|
||||
//
|
||||
// See https://www.html5rocks.com/static/images/cors_server_flowchart.png for more info.
|
||||
func isCORSPreflight(r *http.Request) bool {
|
||||
return r.Method == http.MethodOptions &&
|
||||
r.Header.Get("Access-Control-Request-Method") != "" &&
|
||||
r.Header.Get("Origin") != ""
|
||||
}
|
||||
|
||||
// Proxy authenticates a request, either proxying the request if it is authenticated,
|
||||
// or starting the authenticate service for validation if not.
|
||||
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||
route, ok := p.router(r)
|
||||
if !ok {
|
||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusNotFound, nil))
|
||||
return
|
||||
}
|
||||
|
||||
if p.shouldSkipAuthentication(r) {
|
||||
log.FromRequest(r).Debug().Msg("proxy: access control skipped")
|
||||
route.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
s, err := sessions.FromContext(r.Context())
|
||||
if err != nil || s == nil {
|
||||
log.Debug().Err(err).Msg("proxy: couldn't get session from context")
|
||||
p.authenticate(w, r)
|
||||
return
|
||||
}
|
||||
authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
} else if !authorized {
|
||||
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not authorized for this route", s.RequestEmail()), http.StatusForbidden, nil))
|
||||
return
|
||||
}
|
||||
r.Header.Set(HeaderUserID, s.User)
|
||||
r.Header.Set(HeaderEmail, s.RequestEmail())
|
||||
r.Header.Set(HeaderGroups, s.RequestGroups())
|
||||
|
||||
route.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// UserDashboard lets users investigate, and refresh their current session.
|
||||
// It also contains certain administrative actions like user impersonation.
|
||||
// Nota bene: This endpoint does authentication, not authorization.
|
||||
|
@ -207,10 +104,9 @@ func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// reject a refresh if it's been less than the refresh cooldown to prevent abuse
|
||||
if time.Since(iss) < p.refreshCooldown {
|
||||
httputil.ErrorResponse(w, r,
|
||||
httputil.Error(
|
||||
fmt.Sprintf("Session must be %s old before refreshing", p.refreshCooldown),
|
||||
http.StatusBadRequest, nil))
|
||||
errStr := fmt.Sprintf("Session must be %s old before refreshing", p.refreshCooldown)
|
||||
httpErr := httputil.Error(errStr, http.StatusBadRequest, nil)
|
||||
httputil.ErrorResponse(w, r, httpErr)
|
||||
return
|
||||
}
|
||||
session.ForceRefresh()
|
||||
|
@ -218,7 +114,7 @@ func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) {
|
|||
httputil.ErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, "/.pomerium", http.StatusFound)
|
||||
http.Redirect(w, r, dashboardURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// Impersonate takes the result of a form and adds user impersonation details
|
||||
|
@ -232,7 +128,9 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
|
||||
if err != nil || !isAdmin {
|
||||
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not an administrator", session.RequestEmail()), http.StatusForbidden, err))
|
||||
errStr := fmt.Sprintf("%s is not an administrator", session.RequestEmail())
|
||||
httpErr := httputil.Error(errStr, http.StatusForbidden, err)
|
||||
httputil.ErrorResponse(w, r, httpErr)
|
||||
return
|
||||
}
|
||||
// OK to impersonation
|
||||
|
@ -247,27 +145,5 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "/.pomerium", http.StatusFound)
|
||||
}
|
||||
|
||||
// router attempts to find a route for a request. If a route is successfully matched,
|
||||
// it returns the route information and a bool value of `true`. If a route can
|
||||
// not be matched, a nil value for the route and false bool value is returned.
|
||||
func (p *Proxy) router(r *http.Request) (http.Handler, bool) {
|
||||
config, ok := p.routeConfigs[r.Host]
|
||||
if ok {
|
||||
return config.mux, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// policy attempts to find a policy for a request. If a policy is successfully matched,
|
||||
// it returns the policy information and a bool value of `true`. If a policy can not be matched,
|
||||
// a nil value for the policy and false bool value is returned.
|
||||
func (p *Proxy) policy(r *http.Request) (*config.Policy, bool) {
|
||||
config, ok := p.routeConfigs[r.Host]
|
||||
if ok {
|
||||
return &config.policy, true
|
||||
}
|
||||
return nil, false
|
||||
http.Redirect(w, r, dashboardURL, http.StatusFound)
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package proxy
|
||||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -13,7 +13,6 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/internal/config"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/proxy/clients"
|
||||
)
|
||||
|
@ -76,153 +75,114 @@ func TestProxy_authenticate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestProxy_Handler(t *testing.T) {
|
||||
proxy, err := New(testOptions(t))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
h := proxy.Handler()
|
||||
if h == nil {
|
||||
t.Error("handler cannot be nil")
|
||||
}
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/", h)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 route not found for empty route")
|
||||
}
|
||||
}
|
||||
// func TestProxy_PomeriumHandler(t *testing.T) {
|
||||
// proxy, err := New(testOptions(t))
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// h := proxy.registerHelperHandlers()
|
||||
// if h == nil {
|
||||
// t.Error("handler cannot be nil")
|
||||
// }
|
||||
// mux := http.NewServeMux()
|
||||
// mux.Handle("/", h)
|
||||
// req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
// rr := httptest.NewRecorder()
|
||||
// mux.ServeHTTP(rr, req)
|
||||
// if rr.Code != http.StatusNotFound {
|
||||
// t.Errorf("expected 404 route not found for empty route")
|
||||
// }
|
||||
// }
|
||||
|
||||
func TestProxy_router(t *testing.T) {
|
||||
testPolicy := config.Policy{From: "https://corp.example.com", To: "https://example.com"}
|
||||
if err := testPolicy.Validate(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
policies := []config.Policy{testPolicy}
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
mux []config.Policy
|
||||
route http.Handler
|
||||
wantOk bool
|
||||
}{
|
||||
{"good corp", "https://corp.example.com", policies, nil, true},
|
||||
{"good with slash", "https://corp.example.com/", policies, nil, true},
|
||||
{"good with path", "https://corp.example.com/123", policies, nil, true},
|
||||
{"no policies", "https://notcorp.example.com/123", []config.Policy{}, nil, false},
|
||||
{"bad corp", "https://notcorp.example.com/123", policies, nil, false},
|
||||
{"bad sub-sub", "https://notcorp.corp.example.com/123", policies, nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
opts := testOptions(t)
|
||||
opts.Policies = tt.mux
|
||||
p, err := New(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.encoder = &cryptutil.MockEncoder{MarshalResponse: "foo"}
|
||||
// func TestProxy_Proxy(t *testing.T) {
|
||||
// goodSession := &sessions.State{
|
||||
// AccessToken: "AccessToken",
|
||||
// RefreshToken: "RefreshToken",
|
||||
// RefreshDeadline: time.Now().Add(20 * time.Second),
|
||||
// }
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tt.host, nil)
|
||||
_, ok := p.router(req)
|
||||
if ok != tt.wantOk {
|
||||
t.Errorf("Proxy.router() ok = %v, want %v", ok, tt.wantOk)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
// ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
// w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
// fmt.Fprintln(w, "RVSI FILIVS CAISAR")
|
||||
// w.WriteHeader(http.StatusOK)
|
||||
|
||||
func TestProxy_Proxy(t *testing.T) {
|
||||
goodSession := &sessions.State{
|
||||
AccessToken: "AccessToken",
|
||||
RefreshToken: "RefreshToken",
|
||||
RefreshDeadline: time.Now().Add(20 * time.Second),
|
||||
}
|
||||
// }))
|
||||
// defer ts.Close()
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
fmt.Fprintln(w, "RVSI FILIVS CAISAR")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
// opts := testOptionsTestServer(t, ts.URL)
|
||||
// optsCORS := testOptionsWithCORS(t, ts.URL)
|
||||
// optsPublic := testOptionsWithPublicAccess(t, ts.URL)
|
||||
// optsNoPolicies := testOptionsWithEmptyPolicies(t, ts.URL)
|
||||
|
||||
}))
|
||||
defer ts.Close()
|
||||
// defaultHeaders, goodCORSHeaders, badCORSHeaders, headersWs := http.Header{}, http.Header{}, http.Header{}, http.Header{}
|
||||
// goodCORSHeaders.Set("origin", "anything")
|
||||
// goodCORSHeaders.Set("access-control-request-method", "anything")
|
||||
// // missing "Origin"
|
||||
// badCORSHeaders.Set("access-control-request-method", "anything")
|
||||
// headersWs.Set("Connection", "Upgrade")
|
||||
// headersWs.Set("Upgrade", "websocket")
|
||||
|
||||
opts := testOptionsTestServer(t, ts.URL)
|
||||
optsCORS := testOptionsWithCORS(t, ts.URL)
|
||||
optsPublic := testOptionsWithPublicAccess(t, ts.URL)
|
||||
optsNoPolicies := testOptionsWithEmptyPolicies(t, ts.URL)
|
||||
// tests := []struct {
|
||||
// name string
|
||||
// options config.Options
|
||||
// method string
|
||||
// header http.Header
|
||||
// host string
|
||||
// session sessions.SessionStore
|
||||
// authorizer clients.Authorizer
|
||||
// wantStatus int
|
||||
// }{
|
||||
// {"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(20 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
// {"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
|
||||
// {"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
// {"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
// // same request as above, but with cors_allow_preflight=false in the policy
|
||||
// {"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
// // cors allowed, but the request is missing proper headers
|
||||
// {"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
// // redirect to start auth process
|
||||
// {"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
|
||||
// {"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
// {"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError},
|
||||
// // authenticate errors
|
||||
// {"session expired,redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: sessions.ErrExpired}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
||||
// {"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
|
||||
// {"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound},
|
||||
// {"no http found (no session),redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
||||
// {"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
|
||||
// }
|
||||
|
||||
defaultHeaders, goodCORSHeaders, badCORSHeaders, headersWs := http.Header{}, http.Header{}, http.Header{}, http.Header{}
|
||||
goodCORSHeaders.Set("origin", "anything")
|
||||
goodCORSHeaders.Set("access-control-request-method", "anything")
|
||||
// missing "Origin"
|
||||
badCORSHeaders.Set("access-control-request-method", "anything")
|
||||
headersWs.Set("Connection", "Upgrade")
|
||||
headersWs.Set("Upgrade", "websocket")
|
||||
// for _, tt := range tests {
|
||||
// t.Run(tt.name, func(t *testing.T) {
|
||||
// err := ValidateOptions(tt.options)
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// p, err := New(tt.options)
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// p.encoder = &cryptutil.MockEncoder{MarshalResponse: "foo"}
|
||||
// p.sessionStore = tt.session
|
||||
// p.AuthorizeClient = tt.authorizer
|
||||
// r := httptest.NewRequest(tt.method, tt.host, nil)
|
||||
// r.Header = tt.header
|
||||
// r.Header.Set("Accept", "application/json")
|
||||
// state, _ := tt.session.LoadSession(r)
|
||||
// ctx := r.Context()
|
||||
// ctx = sessions.NewContext(ctx, state, nil)
|
||||
// r = r.WithContext(ctx)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
options config.Options
|
||||
method string
|
||||
header http.Header
|
||||
host string
|
||||
session sessions.SessionStore
|
||||
authorizer clients.Authorizer
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(20 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
{"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
|
||||
{"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
{"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
// same request as above, but with cors_allow_preflight=false in the policy
|
||||
{"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
// cors allowed, but the request is missing proper headers
|
||||
{"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
// redirect to start auth process
|
||||
{"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
|
||||
{"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
{"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError},
|
||||
// authenticate errors
|
||||
{"session expired,redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: sessions.ErrExpired}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
||||
{"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
|
||||
{"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound},
|
||||
{"no http found (no session),redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
||||
{"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
|
||||
}
|
||||
// w := httptest.NewRecorder()
|
||||
// p.Proxy(w, r)
|
||||
// if status := w.Code; status != tt.wantStatus {
|
||||
// t.Errorf("handler returned wrong status code: got %v want %v \n body %s", status, tt.wantStatus, w.Body.String())
|
||||
// }
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateOptions(tt.options)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p, err := New(tt.options)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.encoder = &cryptutil.MockEncoder{MarshalResponse: "foo"}
|
||||
p.sessionStore = tt.session
|
||||
p.AuthorizeClient = tt.authorizer
|
||||
r := httptest.NewRequest(tt.method, tt.host, nil)
|
||||
r.Header = tt.header
|
||||
r.Header.Set("Accept", "application/json")
|
||||
state, _ := tt.session.LoadSession(r)
|
||||
ctx := r.Context()
|
||||
ctx = sessions.NewContext(ctx, state, nil)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
p.Proxy(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v \n body %s", status, tt.wantStatus, w.Body.String())
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
func TestProxy_UserDashboard(t *testing.T) {
|
||||
opts := testOptions(t)
|
||||
|
@ -427,55 +387,9 @@ func TestProxy_SignOut(t *testing.T) {
|
|||
}
|
||||
}
|
||||
func uriParseHelper(s string) *url.URL {
|
||||
uri, _ := url.Parse(s)
|
||||
uri, err := url.Parse(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return uri
|
||||
}
|
||||
func TestProxy_VerifySession(t *testing.T) {
|
||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
fmt.Fprintln(w, "RVSI FILIVS CAISAR")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
session sessions.SessionStore
|
||||
ctxError error
|
||||
provider identity.Authenticator
|
||||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK},
|
||||
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
||||
{"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusFound},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
a := Proxy{
|
||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
||||
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
|
||||
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
|
||||
sessionStore: tt.session,
|
||||
}
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
state, _ := tt.session.LoadSession(r)
|
||||
ctx := r.Context()
|
||||
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
r.Header.Set("Accept", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
got := a.VerifySession(fn)
|
||||
got.ServeHTTP(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
103
proxy/middleware.go
Normal file
103
proxy/middleware.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
const (
|
||||
// HeaderJWT is the header key containing JWT signed user details.
|
||||
HeaderJWT = "x-pomerium-jwt-assertion"
|
||||
// HeaderUserID is the header key containing the user's id.
|
||||
HeaderUserID = "x-pomerium-authenticated-user-id"
|
||||
// HeaderEmail is the header key containing the user's email.
|
||||
HeaderEmail = "x-pomerium-authenticated-user-email"
|
||||
// HeaderGroups is the header key containing the user's groups.
|
||||
HeaderGroups = "x-pomerium-authenticated-user-groups"
|
||||
)
|
||||
|
||||
// AuthenticateSession is middleware to enforce a valid authentication
|
||||
// session state is retrieved from the users's request context.
|
||||
func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "middleware.AuthenticateSession")
|
||||
defer span.End()
|
||||
s, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to session state error")
|
||||
p.authenticate(w, r)
|
||||
return
|
||||
}
|
||||
if err := s.Valid(); err != nil {
|
||||
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to invalid session")
|
||||
p.authenticate(w, r)
|
||||
return
|
||||
}
|
||||
r.Header.Set(HeaderUserID, s.User)
|
||||
r.Header.Set(HeaderEmail, s.RequestEmail())
|
||||
r.Header.Set(HeaderGroups, s.RequestGroups())
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// AuthorizeSession is middleware to enforce a user is authorized for a request
|
||||
// session state is retrieved from the users's request context.
|
||||
func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "middleware.AuthorizeSession")
|
||||
defer span.End()
|
||||
s, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err))
|
||||
return
|
||||
}
|
||||
authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s)
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r.WithContext(ctx), err)
|
||||
return
|
||||
} else if !authorized {
|
||||
errMsg := fmt.Sprintf("%s is not authorized for this route", s.RequestEmail())
|
||||
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error(errMsg, http.StatusForbidden, nil))
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// SignRequest is middleware that signs a JWT that contains a user's id,
|
||||
// email, and group. Session state is retrieved from the users's request context
|
||||
func (p *Proxy) SignRequest(signer cryptutil.JWTSigner) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "middleware.SignRequest")
|
||||
defer span.End()
|
||||
s, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err))
|
||||
return
|
||||
}
|
||||
jwt, err := signer.SignJWT(s.User, s.Email, strings.Join(s.Groups, ","))
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("proxy: failed signing jwt")
|
||||
} else {
|
||||
r.Header.Set(HeaderJWT, jwt)
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Authenticate begins the authenticate flow, encrypting the redirect url
|
||||
// in a request to the provider's sign in endpoint.
|
||||
func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request) {
|
||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
|
||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
}
|
175
proxy/middleware_test.go
Normal file
175
proxy/middleware_test.go
Normal file
|
@ -0,0 +1,175 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/proxy/clients"
|
||||
)
|
||||
|
||||
func TestProxy_AuthenticateSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
session sessions.SessionStore
|
||||
ctxError error
|
||||
provider identity.Authenticator
|
||||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK},
|
||||
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
||||
{"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusFound},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
a := Proxy{
|
||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
||||
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
|
||||
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
|
||||
sessionStore: tt.session,
|
||||
}
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
state, _ := tt.session.LoadSession(r)
|
||||
ctx := r.Context()
|
||||
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||
r = r.WithContext(ctx)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
got := a.AuthenticateSession(fn)
|
||||
got.ServeHTTP(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("AuthenticateSession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_AuthorizeSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
tests := []struct {
|
||||
name string
|
||||
session sessions.SessionStore
|
||||
authzClient clients.Authorizer
|
||||
|
||||
ctxError error
|
||||
provider identity.Authenticator
|
||||
|
||||
wantStatus int
|
||||
}{
|
||||
{"user is authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK},
|
||||
{"user is not authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusForbidden},
|
||||
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusForbidden},
|
||||
{"authz client error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
a := Proxy{
|
||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
||||
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
|
||||
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
|
||||
sessionStore: tt.session,
|
||||
AuthorizeClient: tt.authzClient,
|
||||
}
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
state, _ := tt.session.LoadSession(r)
|
||||
ctx := r.Context()
|
||||
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||
r = r.WithContext(ctx)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
got := a.AuthorizeSession(fn)
|
||||
got.ServeHTTP(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("AuthorizeSession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockJWTSigner struct {
|
||||
SignError error
|
||||
}
|
||||
|
||||
// Sign implements the JWTSigner interface from the cryptutil package, but just
|
||||
// base64's the inputs instead for stesting.
|
||||
func (s *mockJWTSigner) SignJWT(user, email, groups string) (string, error) {
|
||||
return base64.StdEncoding.EncodeToString([]byte(fmt.Sprint(user, email, groups))), s.SignError
|
||||
}
|
||||
|
||||
func TestProxy_SignRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
session sessions.SessionStore
|
||||
|
||||
signerError error
|
||||
ctxError error
|
||||
|
||||
wantStatus int
|
||||
wantHeaders string
|
||||
}{
|
||||
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "dGVzdA=="},
|
||||
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, errors.New("err"), http.StatusForbidden, ""},
|
||||
{"signature failure, warn but ok", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, errors.New("err"), nil, http.StatusOK, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
a := Proxy{
|
||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
||||
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
|
||||
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
|
||||
sessionStore: tt.session,
|
||||
}
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
state, _ := tt.session.LoadSession(r)
|
||||
ctx := r.Context()
|
||||
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||
r = r.WithContext(ctx)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
signer := &mockJWTSigner{SignError: tt.signerError}
|
||||
got := a.SignRequest(signer)(fn)
|
||||
got.ServeHTTP(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||
}
|
||||
if headers := r.Header.Get(HeaderJWT); tt.wantHeaders != headers {
|
||||
t.Errorf("SignRequest() headers = %v, want %v", headers, tt.wantHeaders)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
276
proxy/proxy.go
276
proxy/proxy.go
|
@ -5,15 +5,14 @@ import (
|
|||
"encoding/base64"
|
||||
"fmt"
|
||||
"html/template"
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/pomerium/pomerium/internal/config"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
pom_httputil "github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
|
@ -25,15 +24,8 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
// HeaderJWT is the header key containing JWT signed user details.
|
||||
HeaderJWT = "x-pomerium-jwt-assertion"
|
||||
// HeaderUserID is the header key containing the user's id.
|
||||
HeaderUserID = "x-pomerium-authenticated-user-id"
|
||||
// HeaderEmail is the header key containing the user's email.
|
||||
HeaderEmail = "x-pomerium-authenticated-user-email"
|
||||
// HeaderGroups is the header key containing the user's groups.
|
||||
HeaderGroups = "x-pomerium-authenticated-user-groups"
|
||||
|
||||
// dashboardURL is the path to authenticate's sign in endpoint
|
||||
dashboardURL = "/.pomerium/"
|
||||
// signinURL is the path to authenticate's sign in endpoint
|
||||
signinURL = "/.pomerium/sign_in"
|
||||
// signoutURL is the path to authenticate's sign out endpoint
|
||||
|
@ -78,38 +70,29 @@ type Proxy struct {
|
|||
|
||||
AuthorizeClient clients.Authorizer
|
||||
|
||||
// cipher cipher.AEAD
|
||||
encoder cryptutil.SecureEncoder
|
||||
cookieName string
|
||||
cookieDomain string
|
||||
cookieSecret []byte
|
||||
defaultUpstreamTimeout time.Duration
|
||||
refreshCooldown time.Duration
|
||||
routeConfigs map[string]*routeConfig
|
||||
Handler http.Handler
|
||||
sessionStore sessions.SessionStore
|
||||
signingKey string
|
||||
templates *template.Template
|
||||
}
|
||||
|
||||
type routeConfig struct {
|
||||
mux http.Handler
|
||||
policy config.Policy
|
||||
}
|
||||
|
||||
// New takes a Proxy service from options and a validation function.
|
||||
// Function returns an error if options fail to validate.
|
||||
func New(opts config.Options) (*Proxy, error) {
|
||||
if err := ValidateOptions(opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedCookieSecret, err := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cipher, err := cryptutil.NewAEADCipherFromBase64(opts.CookieSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// errors checked in ValidateOptions
|
||||
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
|
||||
cipher, _ := cryptutil.NewAEADCipherFromBase64(opts.CookieSecret)
|
||||
|
||||
encoder := cryptutil.NewSecureJSONEncoder(cipher)
|
||||
|
||||
if opts.CookieDomain == "" {
|
||||
|
@ -132,7 +115,6 @@ func New(opts config.Options) (*Proxy, error) {
|
|||
p := &Proxy{
|
||||
SharedKey: opts.SharedKey,
|
||||
|
||||
routeConfigs: make(map[string]*routeConfig),
|
||||
encoder: encoder,
|
||||
cookieSecret: decodedCookieSecret,
|
||||
cookieDomain: opts.CookieDomain,
|
||||
|
@ -143,18 +125,18 @@ func New(opts config.Options) (*Proxy, error) {
|
|||
signingKey: opts.SigningKey,
|
||||
templates: templates.New(),
|
||||
}
|
||||
// DeepCopy urls to avoid accidental mutation, err checked in validate func
|
||||
// errors checked in ValidateOptions
|
||||
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
|
||||
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
|
||||
|
||||
p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
|
||||
p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL})
|
||||
|
||||
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
|
||||
|
||||
if err := p.UpdatePolicies(&opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
metrics.AddPolicyCountCallback("proxy", func() int64 {
|
||||
return int64(len(p.routeConfigs))
|
||||
return int64(len(opts.Policies))
|
||||
})
|
||||
p.AuthorizeClient, err = clients.NewAuthorizeClient("grpc",
|
||||
&clients.Options{
|
||||
|
@ -169,116 +151,6 @@ func New(opts config.Options) (*Proxy, error) {
|
|||
return p, err
|
||||
}
|
||||
|
||||
// UpdatePolicies updates the handlers based on the configured policies
|
||||
func (p *Proxy) UpdatePolicies(opts *config.Options) error {
|
||||
routeConfigs := make(map[string]*routeConfig, len(opts.Policies))
|
||||
if len(opts.Policies) == 0 {
|
||||
log.Warn().Msg("proxy: configuration has no policies")
|
||||
}
|
||||
for _, policy := range opts.Policies {
|
||||
if err := policy.Validate(); err != nil {
|
||||
return fmt.Errorf("proxy: couldn't update policies %s", err)
|
||||
}
|
||||
proxy := NewReverseProxy(policy.Destination)
|
||||
// build http transport (roundtripper) middleware chain
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
c := tripper.NewChain()
|
||||
c = c.Append(metrics.HTTPMetricsRoundTripper("proxy", policy.Destination.Host))
|
||||
|
||||
var tlsClientConfig tls.Config
|
||||
var isCustomClientConfig bool
|
||||
if policy.TLSSkipVerify {
|
||||
tlsClientConfig.InsecureSkipVerify = true
|
||||
isCustomClientConfig = true
|
||||
log.Warn().Str("policy", policy.String()).Msg("proxy: tls skip verify")
|
||||
}
|
||||
if policy.RootCAs != nil {
|
||||
tlsClientConfig.RootCAs = policy.RootCAs
|
||||
isCustomClientConfig = true
|
||||
log.Debug().Str("policy", policy.String()).Msg("proxy: custom root ca")
|
||||
}
|
||||
|
||||
if policy.ClientCertificate != nil {
|
||||
tlsClientConfig.Certificates = []tls.Certificate{*policy.ClientCertificate}
|
||||
isCustomClientConfig = true
|
||||
log.Debug().Str("policy", policy.String()).Msg("proxy: client certs enabled")
|
||||
}
|
||||
|
||||
if policy.TLSServerName != "" {
|
||||
tlsClientConfig.ServerName = policy.TLSServerName
|
||||
isCustomClientConfig = true
|
||||
log.Debug().Str("policy", policy.String()).Msgf("proxy: tls hostname override to: %s", policy.TLSServerName)
|
||||
}
|
||||
|
||||
// We avoid setting a custom client config unless we have to as
|
||||
// if TLSClientConfig is nil, the default configuration is used.
|
||||
if isCustomClientConfig {
|
||||
transport.TLSClientConfig = &tlsClientConfig
|
||||
}
|
||||
proxy.Transport = c.Then(transport)
|
||||
|
||||
handler, err := p.newReverseProxyHandler(proxy, &policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
routeConfigs[policy.Source.Host] = &routeConfig{
|
||||
mux: handler,
|
||||
policy: policy,
|
||||
}
|
||||
}
|
||||
p.routeConfigs = routeConfigs
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewReverseProxy returns a new ReverseProxy that routes URLs to the scheme, host, and
|
||||
// base path provided in target. NewReverseProxy rewrites the Host header.
|
||||
func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
|
||||
proxy := httputil.NewSingleHostReverseProxy(to)
|
||||
sublogger := log.With().Str("proxy", to.Host).Logger()
|
||||
proxy.ErrorLog = stdlog.New(&log.StdLogWrapper{Logger: &sublogger}, "", 0)
|
||||
director := proxy.Director
|
||||
proxy.Director = func(req *http.Request) {
|
||||
// Identifies the originating IP addresses of a client connecting to
|
||||
// a web server through an HTTP proxy or a load balancer.
|
||||
req.Header.Add("X-Forwarded-Host", req.Host)
|
||||
director(req)
|
||||
req.Host = to.Host
|
||||
}
|
||||
return proxy
|
||||
}
|
||||
|
||||
// each route has a custom set of middleware applied to the reverse proxy
|
||||
func (p *Proxy) newReverseProxyHandler(rp http.Handler, route *config.Policy) (http.Handler, error) {
|
||||
r := pom_httputil.NewRouter()
|
||||
r.SkipClean(true)
|
||||
r.StrictSlash(true)
|
||||
r.Use(middleware.StripPomeriumCookie(p.cookieName))
|
||||
// if signing key is set, add signer to middleware
|
||||
if len(p.signingKey) != 0 {
|
||||
signer, err := cryptutil.NewES256Signer(p.signingKey, route.Source.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.Use(middleware.SignRequest(signer, HeaderUserID, HeaderEmail, HeaderGroups, HeaderJWT))
|
||||
}
|
||||
// websockets cannot use the non-hijackable timeout-handler
|
||||
if !route.AllowWebsockets {
|
||||
timeout := p.defaultUpstreamTimeout
|
||||
if route.UpstreamTimeout != 0 {
|
||||
timeout = route.UpstreamTimeout
|
||||
}
|
||||
timeoutMsg := fmt.Sprintf("%s timed out in %s", route.Destination.Host, timeout)
|
||||
rp = http.TimeoutHandler(rp, timeout, timeoutMsg)
|
||||
}
|
||||
// todo(bdd) : fix cors
|
||||
// if route.CORSAllowPreflight {
|
||||
// r.Use(cors.Default().Handler)
|
||||
// }
|
||||
r.Host(route.Destination.Host)
|
||||
r.PathPrefix("/").Handler(rp)
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// UpdateOptions updates internal structures based on config.Options
|
||||
func (p *Proxy) UpdateOptions(o config.Options) error {
|
||||
if p == nil {
|
||||
|
@ -287,3 +159,123 @@ func (p *Proxy) UpdateOptions(o config.Options) error {
|
|||
log.Info().Msg("proxy: updating options")
|
||||
return p.UpdatePolicies(&o)
|
||||
}
|
||||
|
||||
// UpdatePolicies updates the H basedon the configured policies
|
||||
func (p *Proxy) UpdatePolicies(opts *config.Options) error {
|
||||
var err error
|
||||
if len(opts.Policies) == 0 {
|
||||
log.Warn().Msg("proxy: configuration has no policies")
|
||||
}
|
||||
r := httputil.NewRouter()
|
||||
r.SkipClean(true)
|
||||
r.StrictSlash(true)
|
||||
r.HandleFunc("/robots.txt", p.RobotsTxt).Methods(http.MethodGet)
|
||||
r = p.registerHelperHandlers(r)
|
||||
|
||||
for _, policy := range opts.Policies {
|
||||
if err := policy.Validate(); err != nil {
|
||||
return fmt.Errorf("proxy: invalid policy %s", err)
|
||||
}
|
||||
r, err = p.reverseProxyHandler(r, &policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
p.Handler = r
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Proxy) reverseProxyHandler(r *mux.Router, policy *config.Policy) (*mux.Router, error) {
|
||||
// 1. Create the reverse proxy connection
|
||||
proxy := httputil.NewReverseProxy(policy.Destination)
|
||||
// 2. Override any custom transport settings (e.g. TLS settings, etc)
|
||||
proxy.Transport = p.roundTripperFromPolicy(policy)
|
||||
// 3. Create a sub-router for a given route's hostname (`httpbin.corp.example.com`)
|
||||
rp := r.Host(policy.Source.Host).Subrouter()
|
||||
rp.PathPrefix("/").Handler(proxy)
|
||||
|
||||
// Optional: If websockets are enabled, do not set a handler request timeout
|
||||
// websockets cannot use the non-hijackable timeout-handler
|
||||
if !policy.AllowWebsockets {
|
||||
timeout := p.defaultUpstreamTimeout
|
||||
if policy.UpstreamTimeout != 0 {
|
||||
timeout = policy.UpstreamTimeout
|
||||
}
|
||||
timeoutMsg := fmt.Sprintf("%s timed out in %s", policy.Destination.Host, timeout)
|
||||
rp.Use(middleware.TimeoutHandlerFunc(timeout, timeoutMsg))
|
||||
}
|
||||
|
||||
// Optional: a cors preflight check, skip access control middleware
|
||||
if policy.CORSAllowPreflight {
|
||||
log.Warn().Str("route", policy.String()).Msg("proxy: cors preflight enabled")
|
||||
rp.Use(middleware.CorsBypass(proxy))
|
||||
}
|
||||
|
||||
// Optional: if a public route, skip access control middleware
|
||||
if policy.AllowPublicUnauthenticatedAccess {
|
||||
log.Warn().Str("route", policy.String()).Msg("proxy: all access control disabled")
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// 4. Retrieve the user session and add it to the request context
|
||||
rp.Use(sessions.RetrieveSession(p.sessionStore))
|
||||
// 5. Strip the user session cookie from the downstream request
|
||||
rp.Use(middleware.StripCookie(p.cookieName))
|
||||
// 6. AuthN - Verify the user is authenticated. Set email, group, & id headers
|
||||
rp.Use(p.AuthenticateSession)
|
||||
// 7. AuthZ - Verify the user is authorized for route
|
||||
rp.Use(p.AuthorizeSession)
|
||||
// Optional: Add a signed JWT attesting to the user's id, email, and group
|
||||
if len(p.signingKey) != 0 {
|
||||
signer, err := cryptutil.NewES256Signer(p.signingKey, policy.Source.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rp.Use(p.SignRequest(signer))
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// roundTripperFromPolicy adjusts the std library's `DefaultTransport RoundTripper`
|
||||
// for a given route. A route's `RoundTripper` establishes network connections
|
||||
// as needed and caches them for reuse by subsequent calls.
|
||||
func (p *Proxy) roundTripperFromPolicy(policy *config.Policy) http.RoundTripper {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
c := tripper.NewChain()
|
||||
c = c.Append(metrics.HTTPMetricsRoundTripper("proxy", policy.Destination.Host))
|
||||
|
||||
var tlsClientConfig tls.Config
|
||||
var isCustomClientConfig bool
|
||||
|
||||
if policy.TLSSkipVerify {
|
||||
tlsClientConfig.InsecureSkipVerify = true
|
||||
isCustomClientConfig = true
|
||||
log.Warn().Str("policy", policy.String()).Msg("proxy: tls skip verify")
|
||||
}
|
||||
|
||||
if policy.RootCAs != nil {
|
||||
tlsClientConfig.RootCAs = policy.RootCAs
|
||||
isCustomClientConfig = true
|
||||
log.Debug().Str("policy", policy.String()).Msg("proxy: custom root ca")
|
||||
}
|
||||
|
||||
if policy.ClientCertificate != nil {
|
||||
tlsClientConfig.Certificates = []tls.Certificate{*policy.ClientCertificate}
|
||||
isCustomClientConfig = true
|
||||
log.Debug().Str("policy", policy.String()).Msg("proxy: client certs enabled")
|
||||
}
|
||||
|
||||
if policy.TLSServerName != "" {
|
||||
tlsClientConfig.ServerName = policy.TLSServerName
|
||||
isCustomClientConfig = true
|
||||
log.Debug().Str("policy", policy.String()).Msgf("proxy: tls override hostname: %s", policy.TLSServerName)
|
||||
}
|
||||
|
||||
// We avoid setting a custom client config unless we have to as
|
||||
// if TLSClientConfig is nil, the default configuration is used.
|
||||
if isCustomClientConfig {
|
||||
transport.TLSClientConfig = &tlsClientConfig
|
||||
}
|
||||
return c.Then(transport)
|
||||
}
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
|
@ -21,72 +19,6 @@ func newTestOptions(t *testing.T) *config.Options {
|
|||
return opts
|
||||
}
|
||||
|
||||
func TestNewReverseProxy(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
hostname, _, _ := net.SplitHostPort(r.Host)
|
||||
w.Write([]byte(hostname))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
backendURL, _ := url.Parse(backend.URL)
|
||||
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
|
||||
backendHost := net.JoinHostPort(backendHostname, backendPort)
|
||||
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
|
||||
|
||||
proxyHandler := NewReverseProxy(proxyURL)
|
||||
frontend := httptest.NewServer(proxyHandler)
|
||||
defer frontend.Close()
|
||||
|
||||
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
|
||||
res, _ := http.DefaultClient.Do(getReq)
|
||||
bodyBytes, _ := ioutil.ReadAll(res.Body)
|
||||
if g, e := string(bodyBytes), backendHostname; g != e {
|
||||
t.Errorf("got body %q; expected %q", g, e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewReverseProxyHandler(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
hostname, _, _ := net.SplitHostPort(r.Host)
|
||||
w.Write([]byte(hostname))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
backendURL, _ := url.Parse(backend.URL)
|
||||
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
|
||||
backendHost := net.JoinHostPort(backendHostname, backendPort)
|
||||
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
|
||||
proxyHandler := NewReverseProxy(proxyURL)
|
||||
opts := newTestOptions(t)
|
||||
opts.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSU0zbXBaSVdYQ1g5eUVneFU2czU3Q2J0YlVOREJTQ0VBdFFGNWZVV0hwY1FvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFaFBRditMQUNQVk5tQlRLMHhTVHpicEVQa1JyazFlVXQxQk9hMzJTRWZVUHpOaTRJV2VaLwpLS0lUdDJxMUlxcFYyS01TYlZEeXI5aWp2L1hoOThpeUV3PT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
|
||||
testPolicy := config.Policy{From: "https://corp.example.com", To: "https://example.com", UpstreamTimeout: 1 * time.Second}
|
||||
if err := testPolicy.Validate(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p, err := New(*opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
handle, err := p.newReverseProxyHandler(proxyHandler, &testPolicy)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
frontend := httptest.NewServer(handle)
|
||||
|
||||
defer frontend.Close()
|
||||
|
||||
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
|
||||
|
||||
res, _ := http.DefaultClient.Do(getReq)
|
||||
bodyBytes, _ := ioutil.ReadAll(res.Body)
|
||||
if g, e := string(bodyBytes), backendHostname; g != e {
|
||||
t.Errorf("got body %q; expected %q", g, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testOptions(t *testing.T) config.Options {
|
||||
authenticateService, _ := url.Parse("https://authenticate.corp.beyondperimeter.com")
|
||||
authorizeService, _ := url.Parse("https://authorize.corp.beyondperimeter.com")
|
||||
|
@ -106,61 +38,9 @@ func testOptions(t *testing.T) config.Options {
|
|||
return *opts
|
||||
}
|
||||
|
||||
func testOptionsTestServer(t *testing.T, uri string) config.Options {
|
||||
authenticateService, _ := url.Parse("https://authenticate.corp.beyondperimeter.com")
|
||||
authorizeService, _ := url.Parse("https://authorize.corp.beyondperimeter.com")
|
||||
testPolicy := config.Policy{
|
||||
From: "https://httpbin.corp.example",
|
||||
To: uri,
|
||||
}
|
||||
if err := testPolicy.Validate(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
opts := newTestOptions(t)
|
||||
opts.Policies = []config.Policy{testPolicy}
|
||||
opts.AuthenticateURL = authenticateService
|
||||
opts.AuthorizeURL = authorizeService
|
||||
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
|
||||
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
|
||||
opts.CookieName = "pomerium"
|
||||
return *opts
|
||||
}
|
||||
|
||||
func testOptionsWithCORS(t *testing.T, uri string) config.Options {
|
||||
testPolicy := config.Policy{
|
||||
From: "https://httpbin.corp.example",
|
||||
To: uri,
|
||||
CORSAllowPreflight: true,
|
||||
}
|
||||
if err := testPolicy.Validate(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
opts := testOptionsTestServer(t, uri)
|
||||
opts.Policies = []config.Policy{testPolicy}
|
||||
return opts
|
||||
}
|
||||
|
||||
func testOptionsWithPublicAccess(t *testing.T, uri string) config.Options {
|
||||
testPolicy := config.Policy{
|
||||
From: "https://httpbin.corp.example",
|
||||
To: uri,
|
||||
AllowPublicUnauthenticatedAccess: true,
|
||||
}
|
||||
if err := testPolicy.Validate(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
opts := testOptions(t)
|
||||
opts.Policies = []config.Policy{testPolicy}
|
||||
return opts
|
||||
}
|
||||
|
||||
func testOptionsWithEmptyPolicies(t *testing.T, uri string) config.Options {
|
||||
opts := testOptionsTestServer(t, uri)
|
||||
opts.Policies = []config.Policy{}
|
||||
return opts
|
||||
}
|
||||
|
||||
func TestOptions_Validate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
good := testOptions(t)
|
||||
badAuthURL := testOptions(t)
|
||||
badAuthURL.AuthenticateURL = nil
|
||||
|
@ -215,22 +95,33 @@ func TestOptions_Validate(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
good := testOptions(t)
|
||||
shortCookieLength := testOptions(t)
|
||||
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="
|
||||
badRoutedProxy := testOptions(t)
|
||||
badRoutedProxy.SigningKey = "YmFkIGtleQo="
|
||||
badCookie := testOptions(t)
|
||||
badCookie.CookieName = ""
|
||||
badPolicyURL := config.Policy{To: "http://", From: "http://bar.example"}
|
||||
badNewPolicy := testOptions(t)
|
||||
badNewPolicy.Policies = []config.Policy{
|
||||
badPolicyURL,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts config.Options
|
||||
wantProxy bool
|
||||
numRoutes int
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", good, true, 1, false},
|
||||
{"empty options", config.Options{}, false, 0, true},
|
||||
{"short secret/validate sanity check", shortCookieLength, false, 0, true},
|
||||
{"invalid ec key, valid base64 though", badRoutedProxy, false, 0, true},
|
||||
{"good", good, true, false},
|
||||
{"empty options", config.Options{}, false, true},
|
||||
{"short secret/validate sanity check", shortCookieLength, false, true},
|
||||
{"invalid ec key, valid base64 though", badRoutedProxy, false, true},
|
||||
{"invalid cookie name, empty", badCookie, false, true},
|
||||
{"bad policy, bad policy url", badNewPolicy, false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -242,21 +133,17 @@ func TestNew(t *testing.T) {
|
|||
if got == nil && tt.wantProxy == true {
|
||||
t.Errorf("New() expected valid proxy struct")
|
||||
}
|
||||
if got != nil && len(got.routeConfigs) != tt.numRoutes {
|
||||
t.Errorf("New() = num routeConfigs \n%+v, want \n%+v \nfrom %+v", got, tt.numRoutes, tt.opts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_UpdateOptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
good := testOptions(t)
|
||||
newPolicy := config.Policy{To: "http://foo.example", From: "http://bar.example"}
|
||||
newPolicies := testOptions(t)
|
||||
newPolicies.Policies = []config.Policy{
|
||||
newPolicy,
|
||||
}
|
||||
newPolicies.Policies = []config.Policy{newPolicy}
|
||||
err := newPolicy.Validate()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -268,26 +155,33 @@ func Test_UpdateOptions(t *testing.T) {
|
|||
}
|
||||
disableTLSPolicy := config.Policy{To: "http://foo.example", From: "http://bar.example", TLSSkipVerify: true}
|
||||
disableTLSPolicies := testOptions(t)
|
||||
disableTLSPolicies.Policies = []config.Policy{
|
||||
disableTLSPolicy,
|
||||
}
|
||||
disableTLSPolicies.Policies = []config.Policy{disableTLSPolicy}
|
||||
|
||||
customCAPolicies := testOptions(t)
|
||||
customCAPolicies.Policies = []config.Policy{
|
||||
{To: "http://foo.example", From: "http://bar.example", TLSCustomCA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURlVENDQW1HZ0F3SUJBZ0lKQUszMmhoR0JIcmFtTUEwR0NTcUdTSWIzRFFFQkN3VUFNR0l4Q3pBSkJnTlYKQkFZVEFsVlRNUk13RVFZRFZRUUlEQXBEWVd4cFptOXlibWxoTVJZd0ZBWURWUVFIREExVFlXNGdSbkpoYm1OcApjMk52TVE4d0RRWURWUVFLREFaQ1lXUlRVMHd4RlRBVEJnTlZCQU1NRENvdVltRmtjM05zTG1OdmJUQWVGdzB4Ck9UQTJNVEl4TlRNeE5UbGFGdzB5TVRBMk1URXhOVE14TlRsYU1HSXhDekFKQmdOVkJBWVRBbFZUTVJNd0VRWUQKVlFRSURBcERZV3hwWm05eWJtbGhNUll3RkFZRFZRUUhEQTFUWVc0Z1JuSmhibU5wYzJOdk1ROHdEUVlEVlFRSwpEQVpDWVdSVFUwd3hGVEFUQmdOVkJBTU1EQ291WW1Ga2MzTnNMbU52YlRDQ0FTSXdEUVlKS29aSWh2Y05BUUVCCkJRQURnZ0VQQURDQ0FRb0NnZ0VCQU1JRTdQaU03Z1RDczloUTFYQll6Sk1ZNjF5b2FFbXdJclg1bFo2eEt5eDIKUG16QVMyQk1UT3F5dE1BUGdMYXcrWExKaGdMNVhFRmRFeXQvY2NSTHZPbVVMbEEzcG1jY1lZejJRVUxGUnRNVwpoeWVmZE9zS25SRlNKaUZ6YklSTWVWWGswV3ZvQmoxSUZWS3RzeWpicXY5dS8yQ1ZTbmRyT2ZFazBURzIzVTNBCnhQeFR1VzFDcmJWOC9xNzFGZEl6U09jaWNjZkNGSHBzS09vM1N0L3FiTFZ5dEg1YW9oYmNhYkZYUk5zS0VxdmUKd3c5SGRGeEJJdUdhK1J1VDVxMGlCaWt1c2JwSkhBd25ucVA3aS9kQWNnQ3NrZ2paakZlRVU0RUZ5K2IrYTFTWQpRQ2VGeHhDN2MzRHZhUmhCQjBWVmZQbGtQejBzdzZsODY1TWFUSWJSeW9VQ0F3RUFBYU15TURBd0NRWURWUjBUCkJBSXdBREFqQmdOVkhSRUVIREFhZ2d3cUxtSmhaSE56YkM1amIyMkNDbUpoWkhOemJDNWpiMjB3RFFZSktvWkkKaHZjTkFRRUxCUUFEZ2dFQkFJaTV1OXc4bWdUNnBwQ2M3eHNHK0E5ZkkzVzR6K3FTS2FwaHI1bHM3MEdCS2JpWQpZTEpVWVpoUGZXcGgxcXRra1UwTEhGUG04M1ZhNTJlSUhyalhUMFZlNEt0TzFuMElBZkl0RmFXNjJDSmdoR1luCmp6dzByeXpnQzRQeUZwTk1uTnRCcm9QdS9iUGdXaU1nTE9OcEVaaGlneDRROHdmMVkvVTlzK3pDQ3hvSmxhS1IKTVhidVE4N1g3bS85VlJueHhvNk56NVpmN09USFRwTk9JNlZqYTBCeGJtSUFVNnlyaXc5VXJnaWJYZk9qM2o2bgpNVExCdWdVVklCMGJCYWFzSnNBTUsrdzRMQU52YXBlWjBET1NuT1I0S0syNEowT3lvRjVmSG1wNTllTTE3SW9GClFxQmh6cG1RVWd1bmVjRVc4QlRxck5wRzc5UjF1K1YrNHd3Y2tQYz0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="},
|
||||
}
|
||||
customCAPolicies.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", TLSCustomCA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURlVENDQW1HZ0F3SUJBZ0lKQUszMmhoR0JIcmFtTUEwR0NTcUdTSWIzRFFFQkN3VUFNR0l4Q3pBSkJnTlYKQkFZVEFsVlRNUk13RVFZRFZRUUlEQXBEWVd4cFptOXlibWxoTVJZd0ZBWURWUVFIREExVFlXNGdSbkpoYm1OcApjMk52TVE4d0RRWURWUVFLREFaQ1lXUlRVMHd4RlRBVEJnTlZCQU1NRENvdVltRmtjM05zTG1OdmJUQWVGdzB4Ck9UQTJNVEl4TlRNeE5UbGFGdzB5TVRBMk1URXhOVE14TlRsYU1HSXhDekFKQmdOVkJBWVRBbFZUTVJNd0VRWUQKVlFRSURBcERZV3hwWm05eWJtbGhNUll3RkFZRFZRUUhEQTFUWVc0Z1JuSmhibU5wYzJOdk1ROHdEUVlEVlFRSwpEQVpDWVdSVFUwd3hGVEFUQmdOVkJBTU1EQ291WW1Ga2MzTnNMbU52YlRDQ0FTSXdEUVlKS29aSWh2Y05BUUVCCkJRQURnZ0VQQURDQ0FRb0NnZ0VCQU1JRTdQaU03Z1RDczloUTFYQll6Sk1ZNjF5b2FFbXdJclg1bFo2eEt5eDIKUG16QVMyQk1UT3F5dE1BUGdMYXcrWExKaGdMNVhFRmRFeXQvY2NSTHZPbVVMbEEzcG1jY1lZejJRVUxGUnRNVwpoeWVmZE9zS25SRlNKaUZ6YklSTWVWWGswV3ZvQmoxSUZWS3RzeWpicXY5dS8yQ1ZTbmRyT2ZFazBURzIzVTNBCnhQeFR1VzFDcmJWOC9xNzFGZEl6U09jaWNjZkNGSHBzS09vM1N0L3FiTFZ5dEg1YW9oYmNhYkZYUk5zS0VxdmUKd3c5SGRGeEJJdUdhK1J1VDVxMGlCaWt1c2JwSkhBd25ucVA3aS9kQWNnQ3NrZ2paakZlRVU0RUZ5K2IrYTFTWQpRQ2VGeHhDN2MzRHZhUmhCQjBWVmZQbGtQejBzdzZsODY1TWFUSWJSeW9VQ0F3RUFBYU15TURBd0NRWURWUjBUCkJBSXdBREFqQmdOVkhSRUVIREFhZ2d3cUxtSmhaSE56YkM1amIyMkNDbUpoWkhOemJDNWpiMjB3RFFZSktvWkkKaHZjTkFRRUxCUUFEZ2dFQkFJaTV1OXc4bWdUNnBwQ2M3eHNHK0E5ZkkzVzR6K3FTS2FwaHI1bHM3MEdCS2JpWQpZTEpVWVpoUGZXcGgxcXRra1UwTEhGUG04M1ZhNTJlSUhyalhUMFZlNEt0TzFuMElBZkl0RmFXNjJDSmdoR1luCmp6dzByeXpnQzRQeUZwTk1uTnRCcm9QdS9iUGdXaU1nTE9OcEVaaGlneDRROHdmMVkvVTlzK3pDQ3hvSmxhS1IKTVhidVE4N1g3bS85VlJueHhvNk56NVpmN09USFRwTk9JNlZqYTBCeGJtSUFVNnlyaXc5VXJnaWJYZk9qM2o2bgpNVExCdWdVVklCMGJCYWFzSnNBTUsrdzRMQU52YXBlWjBET1NuT1I0S0syNEowT3lvRjVmSG1wNTllTTE3SW9GClFxQmh6cG1RVWd1bmVjRVc4QlRxck5wRzc5UjF1K1YrNHd3Y2tQYz0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="}}
|
||||
|
||||
badCustomCAPolicies := testOptions(t)
|
||||
badCustomCAPolicies.Policies = []config.Policy{
|
||||
{To: "http://foo.example", From: "http://bar.example", TLSCustomCA: "=@@"},
|
||||
}
|
||||
badCustomCAPolicies.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", TLSCustomCA: "=@@"}}
|
||||
|
||||
goodClientCertPolicies := testOptions(t)
|
||||
goodClientCertPolicies.Policies = []config.Policy{
|
||||
{To: "http://foo.example", From: "http://bar.example",
|
||||
TLSClientKey: "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=", TLSClientCert: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="},
|
||||
}
|
||||
goodClientCertPolicies.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", TLSClientKey: "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=", TLSClientCert: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="}}
|
||||
goodClientCertPolicies.Validate()
|
||||
|
||||
customServerName := testOptions(t)
|
||||
customServerName.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", TLSServerName: "test"}}
|
||||
|
||||
emptyPolicies := testOptions(t)
|
||||
emptyPolicies.Policies = nil
|
||||
|
||||
allowWebSockets := testOptions(t)
|
||||
allowWebSockets.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", AllowWebsockets: true}}
|
||||
customTimeout := testOptions(t)
|
||||
customTimeout.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", UpstreamTimeout: 10 * time.Second}}
|
||||
corsPreflight := testOptions(t)
|
||||
corsPreflight.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", CORSAllowPreflight: true}}
|
||||
disableAuth := testOptions(t)
|
||||
disableAuth.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", AllowPublicUnauthenticatedAccess: true}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
originalOptions config.Options
|
||||
|
@ -301,12 +195,18 @@ func Test_UpdateOptions(t *testing.T) {
|
|||
{"changed", good, newPolicies, "", "https://bar.example", false, true},
|
||||
{"changed and missing", good, newPolicies, "", "https://corp.example.example", false, false},
|
||||
{"bad signing key", good, newPolicies, "^bad base 64", "https://corp.example.example", true, false},
|
||||
{"good signing key", good, newPolicies, "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSU0zbXBaSVdYQ1g5eUVneFU2czU3Q2J0YlVOREJTQ0VBdFFGNWZVV0hwY1FvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFaFBRditMQUNQVk5tQlRLMHhTVHpicEVQa1JyazFlVXQxQk9hMzJTRWZVUHpOaTRJV2VaLwpLS0lUdDJxMUlxcFYyS01TYlZEeXI5aWp2L1hoOThpeUV3PT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQ==", "https://corp.example.example", false, true},
|
||||
{"bad change bad policy url", good, badNewPolicy, "", "https://bar.example", true, false},
|
||||
{"disable tls verification", good, disableTLSPolicies, "", "https://bar.example", false, true},
|
||||
{"custom root ca", good, customCAPolicies, "", "https://bar.example", false, true},
|
||||
{"bad custom root ca base64", good, badCustomCAPolicies, "", "https://bar.example", true, false},
|
||||
{"good client certs", good, goodClientCertPolicies, "", "https://bar.example", false, true},
|
||||
{"custom server name", customServerName, customServerName, "", "https://bar.example", false, true},
|
||||
{"good no policies to start", emptyPolicies, good, "", "https://corp.example.example", false, true},
|
||||
{"allow websockets", good, allowWebSockets, "", "https://corp.example.example", false, true},
|
||||
{"no websockets, custom timeout", good, customTimeout, "", "https://corp.example.example", false, true},
|
||||
{"enable cors preflight", good, corsPreflight, "", "https://corp.example.example", false, true},
|
||||
{"disable auth", good, disableAuth, "", "https://corp.example.example", false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -324,9 +224,10 @@ func Test_UpdateOptions(t *testing.T) {
|
|||
|
||||
// This is only safe if we actually can load policies
|
||||
if err == nil {
|
||||
req := httptest.NewRequest("GET", tt.host, nil)
|
||||
_, ok := p.router(req)
|
||||
if ok != tt.wantRoute {
|
||||
r := httptest.NewRequest("GET", tt.host, nil)
|
||||
w := httptest.NewRecorder()
|
||||
p.Handler.ServeHTTP(w, r)
|
||||
if tt.wantRoute && w.Code != http.StatusNotFound {
|
||||
t.Errorf("Failed to find route handler")
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue