proxy: use middleware to manage request flow

proxy: remove duplicate error handling in New
proxy: remove routeConfigs in favor of using gorilla/mux
proxy: add proxy specific middleware
proxy: no longer need to use middleware / handler to check if valid route. Can use build in 404 mux.
internal/middleware: add cors bypass middleware

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2019-09-24 18:56:35 -07:00
parent 70c5553d3c
commit 782ffbeb3e
No known key found for this signature in database
GPG key ID: AEE4CF12FE86D07E
17 changed files with 834 additions and 839 deletions

View file

@ -112,7 +112,7 @@ func newProxyService(opt config.Options, r *mux.Router) (*proxy.Proxy, error) {
if err != nil {
return nil, err
}
r.PathPrefix("/").Handler(service.Handler())
r.PathPrefix("/").Handler(service.Handler)
return service, nil
}

4
go.sum
View file

@ -70,6 +70,7 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/mux v1.6.2 h1:Pgr17XVTNXAk3q/r4CpKzC5xBM/qW1uVLV+IhRZpIIk=
github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
@ -124,8 +125,6 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30 h1:jggCv6hZvcxjGa3gqkYY2EUuOkITI9Znugz/f3QJfRQ=
github.com/pomerium/csrf v1.6.2-0.20190911035354-d4d212209a30/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0=
github.com/pomerium/csrf v1.6.2-0.20190918035251-f3318380bad3 h1:FmzFXnCAepHZwl6QPhTFqBHcbcGevdiEQjutK+M5bj4=
github.com/pomerium/csrf v1.6.2-0.20190918035251-f3318380bad3/go.mod h1:UE2U4JOsjXNeq+MX/lqhZpUFsNAxbXERuYsWK2iULh0=
github.com/pomerium/go-oidc v2.0.0+incompatible h1:gVvG/ExWsHQqatV+uceROnGmbVYF44mDNx5nayBhC0o=
@ -198,7 +197,6 @@ golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmV
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8 h1:1wopBVtVdWnn03fZelqdXTqk7U7zPQCb+T4rbU9ZEoU=
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 h1:0hQKqeLdqlt5iIwVOBErRisrHJAN57yOiPRQItI20fU=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=

View file

@ -0,0 +1,31 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
stdlog "log"
"net/http"
"net/http/httputil"
"net/url"
"github.com/pomerium/pomerium/internal/log"
)
// HeaderForwardHost is the header key the identifies the originating
// IP addresses of a client connecting to a web server through an HTTP proxy
// or a load balancer.
const HeaderForwardHost = "X-Forwarded-Host"
// NewReverseProxy returns a new ReverseProxy that routes
// URLs to the scheme, host, and base path provided in target,
// rewrites the Host header, and sets `X-Forwarded-Host`.
func NewReverseProxy(target *url.URL) *httputil.ReverseProxy {
reverseProxy := httputil.NewSingleHostReverseProxy(target)
sublogger := log.With().Str("reverse-proxy", target.Host).Logger()
reverseProxy.ErrorLog = stdlog.New(&log.StdLogWrapper{Logger: &sublogger}, "", 0)
director := reverseProxy.Director
reverseProxy.Director = func(req *http.Request) {
req.Header.Add(HeaderForwardHost, req.Host)
director(req)
req.Host = target.Host
}
return reverseProxy
}

View file

@ -0,0 +1,35 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
func TestNewReverseProxy(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
hostname, _, _ := net.SplitHostPort(r.Host)
w.Write([]byte(hostname))
}))
defer backend.Close()
backendURL, _ := url.Parse(backend.URL)
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
backendHost := net.JoinHostPort(backendHostname, backendPort)
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
proxyHandler := NewReverseProxy(proxyURL)
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
res, _ := http.DefaultClient.Do(getReq)
bodyBytes, _ := ioutil.ReadAll(res.Body)
if g, e := string(bodyBytes), backendHostname; g != e {
t.Errorf("got body %q; expected %q", g, e)
}
}

View file

@ -1,4 +1,4 @@
package httputil
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
import (
"encoding/base64"

View file

@ -0,0 +1,26 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import (
"net/http"
"github.com/pomerium/pomerium/internal/telemetry/trace"
)
// CorsBypass is middleware that takes a target handler as a paramater,
// if the request is determined to be a CORS preflight request, that handler
// is called instead of the normal handler chain.
func CorsBypass(target http.Handler) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.CorsBypass")
defer span.End()
if r.Method == http.MethodOptions &&
r.Header.Get("Access-Control-Request-Method") != "" &&
r.Header.Get("Origin") != "" {
target.ServeHTTP(w, r.WithContext(ctx))
return
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

View file

@ -0,0 +1,57 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
)
func someOtherMiddleware(s string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Some-Other-Middleware", s)
next.ServeHTTP(w, r)
})
}
}
func TestCorsBypass(t *testing.T) {
t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
method string
header http.Header
wantStatus int
wantHeader string
}{
{"good", http.MethodOptions, http.Header{"Access-Control-Request-Method": []string{"GET"}, "Origin": []string{"localhost"}}, 200, ""},
{"invalid cors - non options request", http.MethodGet, http.Header{"Access-Control-Request-Method": []string{"GET"}, "Origin": []string{"localhost"}}, 200, "BAD"},
{"invalid cors - Origin not set", http.MethodOptions, http.Header{"Access-Control-Request-Method": []string{"GET"}, "Origin": []string{""}}, 200, "BAD"},
{"invalid cors - Access-Control-Request-Method not set", http.MethodOptions, http.Header{"Access-Control-Request-Method": []string{""}, "Origin": []string{"*"}}, 200, "BAD"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &http.Request{
Method: tt.method,
Header: tt.header,
}
w := httptest.NewRecorder()
target := fn
got := CorsBypass(target)(someOtherMiddleware("BAD")(target))
got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("TestCorsBypass() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
if header := w.Header().Get("Some-Other-Middleware"); tt.wantHeader != header {
t.Errorf("TestCorsBypass() header = %v, want %v", header, tt.wantHeader)
}
})
}
}

View file

@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
"strings"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil"
@ -121,22 +122,6 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler
}
}
// ValidateHost ensures that each request's host is valid
func ValidateHost(validHost func(host string) bool) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateHost")
defer span.End()
if !validHost(r.Host) {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusNotFound, nil))
return
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// Healthcheck endpoint middleware useful to setting up a path like
// `/ping` that load balancers or uptime testing external services
// can make a request before hitting any routes. It's also convenient
@ -185,3 +170,33 @@ func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool {
}
return cryptutil.CheckHMAC([]byte(fmt.Sprint(redirectURI, timestamp)), requestSig, secret)
}
// StripCookie strips the cookie from the downstram request.
func StripCookie(cookieName string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.StripCookie")
defer span.End()
headers := make([]string, 0, len(r.Cookies()))
for _, cookie := range r.Cookies() {
if !strings.HasPrefix(cookie.Name, cookieName) {
headers = append(headers, cookie.String())
}
}
r.Header.Set("Cookie", strings.Join(headers, ";"))
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// TimeoutHandlerFunc wraps http.TimeoutHandler
func TimeoutHandlerFunc(timeout time.Duration, timeoutError string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.TimeoutHandlerFunc")
defer span.End()
http.TimeoutHandler(next, timeout, timeoutError).ServeHTTP(w, r.WithContext(ctx))
})
}
}

View file

@ -1,4 +1,4 @@
package middleware
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import (
"encoding/base64"
@ -276,60 +276,80 @@ func TestHealthCheck(t *testing.T) {
}
}
// Redirect to a fixed URL
type handlerHelper struct {
msg string
}
func (rh *handlerHelper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(rh.msg))
}
func handlerHelp(msg string) http.Handler {
return &handlerHelper{msg}
}
func TestValidateHost(t *testing.T) {
validHostFunc := func(host string) bool {
return host == "google.com"
}
validHostHandler := handlerHelp("google")
func TestStripCookie(t *testing.T) {
tests := []struct {
name string
isValidHost func(string) bool
validHostHandler http.Handler
clientPath string
expected []byte
status int
name string
pomeriumCookie string
otherCookies []string
}{
{"good", validHostFunc, validHostHandler, "google.com", []byte("google"), 200},
{"no route", validHostFunc, validHostHandler, "googles.com", []byte("google"), 404},
{"good", "pomerium", []string{"x", "y", "z"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, tt.clientPath, nil)
if err != nil {
t.Fatal(err)
}
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, cookie := range r.Cookies() {
if cookie.Name == tt.pomeriumCookie {
t.Errorf("cookie not stripped %s", r.Cookies())
}
}
})
rr := httptest.NewRecorder()
var testHandler http.Handler
if tt.isValidHost(tt.clientPath) {
tt.validHostHandler.ServeHTTP(rr, req)
testHandler = tt.validHostHandler
} else {
testHandler = handlerHelp("ok")
for _, cn := range tt.otherCookies {
http.SetCookie(rr, &http.Cookie{
Name: cn,
Value: "some other cookie",
})
}
handler := ValidateHost(tt.isValidHost)(testHandler)
http.SetCookie(rr, &http.Cookie{
Name: tt.pomeriumCookie,
Value: "pomerium cookie!",
})
http.SetCookie(rr, &http.Cookie{
Name: tt.pomeriumCookie + "_csrf",
Value: "pomerium csrf cookie!",
})
req := &http.Request{Header: http.Header{"Cookie": rr.Header()["Set-Cookie"]}}
handler := StripCookie(tt.pomeriumCookie)(testHandler)
handler.ServeHTTP(rr, req)
if rr.Code != tt.status {
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
t.Errorf("%s", rr.Body)
}
})
}
}
func TestTimeoutHandlerFunc(t *testing.T) {
t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
timeout time.Duration
timeoutError string
wantStatus int
wantBody string
}{
{"good", 1 * time.Second, "good timed out!?", http.StatusOK, http.StatusText(http.StatusOK)},
{"timeout!", 1 * time.Nanosecond, "ruh roh", http.StatusServiceUnavailable, "ruh roh"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
got := TimeoutHandlerFunc(tt.timeout, tt.timeoutError)(fn)
got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
if body := w.Body.String(); tt.wantBody != body {
t.Errorf("SignRequest() body = %v, want %v", body, tt.wantBody)
}
})
}
}

View file

@ -1,48 +0,0 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import (
"net/http"
"strings"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/trace"
)
func SignRequest(signer cryptutil.JWTSigner, id, email, groups, header string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.SignRequest")
defer span.End()
jwt, err := signer.SignJWT(
r.Header.Get(id),
r.Header.Get(email),
r.Header.Get(groups))
if err != nil {
log.Warn().Err(err).Msg("internal/middleware: failed signing request")
} else {
r.Header.Set(header, jwt)
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// StripPomeriumCookie ensures that every response includes some basic security headers
func StripPomeriumCookie(cookieName string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.StripPomeriumCookie")
defer span.End()
headers := make([]string, 0, len(r.Cookies()))
for _, cookie := range r.Cookies() {
if !strings.HasPrefix(cookie.Name, cookieName) {
headers = append(headers, cookie.String())
}
}
r.Header.Set("Cookie", strings.Join(headers, ";"))
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

View file

@ -1,100 +0,0 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import (
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/pomerium/pomerium/internal/cryptutil"
)
const exampleKey = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIM3mpZIWXCX9yEgxU6s57CbtbUNDBSCEAtQF5fUWHpcQoAoGCCqGSM49
AwEHoUQDQgAEhPQv+LACPVNmBTK0xSTzbpEPkRrk1eUt1BOa32SEfUPzNi4IWeZ/
KKITt2q1IqpV2KMSbVDyr9ijv/Xh98iyEw==
-----END EC PRIVATE KEY-----
`
func TestSignRequest(t *testing.T) {
tests := []struct {
name string
id string
email string
groups string
header string
}{
{"good", "id", "email", "group", "Jwt"},
}
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Header.Set(fmt.Sprintf("%s-header", tt.id), tt.id)
r.Header.Set(fmt.Sprintf("%s-header", tt.email), tt.email)
r.Header.Set(fmt.Sprintf("%s-header", tt.groups), tt.groups)
})
rr := httptest.NewRecorder()
signer, err := cryptutil.NewES256Signer(base64.StdEncoding.EncodeToString([]byte(exampleKey)), "audience")
if err != nil {
t.Fatal(err)
}
handler := SignRequest(signer, tt.id, tt.email, tt.groups, tt.header)(testHandler)
handler.ServeHTTP(rr, req)
jwt := req.Header["Jwt"]
if len(jwt) != 1 {
t.Errorf("no jwt found %v", req.Header)
}
})
}
}
func TestStripPomeriumCookie(t *testing.T) {
tests := []struct {
name string
pomeriumCookie string
otherCookies []string
}{
{"good", "pomerium", []string{"x", "y", "z"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, cookie := range r.Cookies() {
if cookie.Name == tt.pomeriumCookie {
t.Errorf("cookie not stripped %s", r.Cookies())
}
}
})
rr := httptest.NewRecorder()
for _, cn := range tt.otherCookies {
http.SetCookie(rr, &http.Cookie{
Name: cn,
Value: "some other cookie",
})
}
http.SetCookie(rr, &http.Cookie{
Name: tt.pomeriumCookie,
Value: "pomerium cookie!",
})
http.SetCookie(rr, &http.Cookie{
Name: tt.pomeriumCookie + "_csrf",
Value: "pomerium csrf cookie!",
})
req := &http.Request{Header: http.Header{"Cookie": rr.Header()["Set-Cookie"]}}
handler := StripPomeriumCookie(tt.pomeriumCookie)(testHandler)
handler.ServeHTTP(rr, req)
})
}
}

View file

@ -7,67 +7,34 @@ import (
"strings"
"time"
"github.com/gorilla/mux"
"github.com/pomerium/csrf"
"github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/templates"
"github.com/pomerium/pomerium/internal/urlutil"
)
// Handler returns the proxy service's ServeMux
func (p *Proxy) Handler() http.Handler {
r := httputil.NewRouter()
r.SkipClean(true)
r.StrictSlash(true)
r.Use(middleware.ValidateHost(func(host string) bool {
_, ok := p.routeConfigs[host]
return ok
}))
r.HandleFunc("/robots.txt", p.RobotsTxt)
// requires authN not authZ
r.Use(sessions.RetrieveSession(p.sessionStore))
r.Use(p.VerifySession)
// Proxy service endpoints
v := r.PathPrefix("/.pomerium").Subrouter()
v.Use(csrf.Protect(
// registerHelperHandlers returns the proxy service's ServeMux
func (p *Proxy) registerHelperHandlers(r *mux.Router) *mux.Router {
h := r.PathPrefix(dashboardURL).Subrouter()
h.Use(sessions.RetrieveSession(p.sessionStore))
h.Use(p.AuthenticateSession)
h.Use(csrf.Protect(
p.cookieSecret,
csrf.Path("/"),
csrf.Domain(p.cookieDomain),
csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieName)),
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
))
v.HandleFunc("/", p.UserDashboard).Methods(http.MethodGet)
v.HandleFunc("/impersonate", p.Impersonate).Methods(http.MethodPost)
v.HandleFunc("/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost)
v.HandleFunc("/refresh", p.ForceRefresh).Methods(http.MethodPost)
r.PathPrefix("/").HandlerFunc(p.Proxy)
h.HandleFunc("/", p.UserDashboard).Methods(http.MethodGet)
h.HandleFunc("/impersonate", p.Impersonate).Methods(http.MethodPost)
h.HandleFunc("/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost)
h.HandleFunc("/refresh", p.ForceRefresh).Methods(http.MethodPost)
return r
}
// VerifySession is the middleware used to enforce a valid authentication
// session state is attached to the users's request context.
func (p *Proxy) VerifySession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
state, err := sessions.FromContext(r.Context())
if err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to session state error")
p.authenticate(w, r)
return
}
if err := state.Valid(); err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to invalid session")
p.authenticate(w, r)
return
}
next.ServeHTTP(w, r)
})
}
// RobotsTxt sets the User-Agent header in the response to be "Disallow"
func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
@ -88,76 +55,6 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, uri.String(), http.StatusFound)
}
// Authenticate begins the authenticate flow, encrypting the redirect url
// in a request to the provider's sign in endpoint.
func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request) {
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
http.Redirect(w, r, uri.String(), http.StatusFound)
}
// shouldSkipAuthentication contains conditions for skipping authentication.
// Conditions should be few in number and have strong justifications.
func (p *Proxy) shouldSkipAuthentication(r *http.Request) bool {
policy, policyExists := p.policy(r)
if isCORSPreflight(r) && policyExists && policy.CORSAllowPreflight {
log.FromRequest(r).Debug().Msg("proxy: skipping authentication for valid CORS preflight request")
return true
}
if policyExists && policy.AllowPublicUnauthenticatedAccess {
log.FromRequest(r).Debug().Msg("proxy: skipping authentication for public route")
return true
}
return false
}
// isCORSPreflight inspects the request to see if this is a valid CORS preflight request.
// These checks are not exhaustive, because the proxied server should be verifying it as well.
//
// See https://www.html5rocks.com/static/images/cors_server_flowchart.png for more info.
func isCORSPreflight(r *http.Request) bool {
return r.Method == http.MethodOptions &&
r.Header.Get("Access-Control-Request-Method") != "" &&
r.Header.Get("Origin") != ""
}
// Proxy authenticates a request, either proxying the request if it is authenticated,
// or starting the authenticate service for validation if not.
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
route, ok := p.router(r)
if !ok {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusNotFound, nil))
return
}
if p.shouldSkipAuthentication(r) {
log.FromRequest(r).Debug().Msg("proxy: access control skipped")
route.ServeHTTP(w, r)
return
}
s, err := sessions.FromContext(r.Context())
if err != nil || s == nil {
log.Debug().Err(err).Msg("proxy: couldn't get session from context")
p.authenticate(w, r)
return
}
authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s)
if err != nil {
httputil.ErrorResponse(w, r, err)
return
} else if !authorized {
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not authorized for this route", s.RequestEmail()), http.StatusForbidden, nil))
return
}
r.Header.Set(HeaderUserID, s.User)
r.Header.Set(HeaderEmail, s.RequestEmail())
r.Header.Set(HeaderGroups, s.RequestGroups())
route.ServeHTTP(w, r)
}
// UserDashboard lets users investigate, and refresh their current session.
// It also contains certain administrative actions like user impersonation.
// Nota bene: This endpoint does authentication, not authorization.
@ -207,10 +104,9 @@ func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) {
// reject a refresh if it's been less than the refresh cooldown to prevent abuse
if time.Since(iss) < p.refreshCooldown {
httputil.ErrorResponse(w, r,
httputil.Error(
fmt.Sprintf("Session must be %s old before refreshing", p.refreshCooldown),
http.StatusBadRequest, nil))
errStr := fmt.Sprintf("Session must be %s old before refreshing", p.refreshCooldown)
httpErr := httputil.Error(errStr, http.StatusBadRequest, nil)
httputil.ErrorResponse(w, r, httpErr)
return
}
session.ForceRefresh()
@ -218,7 +114,7 @@ func (p *Proxy) ForceRefresh(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, err)
return
}
http.Redirect(w, r, "/.pomerium", http.StatusFound)
http.Redirect(w, r, dashboardURL, http.StatusFound)
}
// Impersonate takes the result of a form and adds user impersonation details
@ -232,7 +128,9 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
}
isAdmin, err := p.AuthorizeClient.IsAdmin(r.Context(), session)
if err != nil || !isAdmin {
httputil.ErrorResponse(w, r, httputil.Error(fmt.Sprintf("%s is not an administrator", session.RequestEmail()), http.StatusForbidden, err))
errStr := fmt.Sprintf("%s is not an administrator", session.RequestEmail())
httpErr := httputil.Error(errStr, http.StatusForbidden, err)
httputil.ErrorResponse(w, r, httpErr)
return
}
// OK to impersonation
@ -247,27 +145,5 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
return
}
http.Redirect(w, r, "/.pomerium", http.StatusFound)
}
// router attempts to find a route for a request. If a route is successfully matched,
// it returns the route information and a bool value of `true`. If a route can
// not be matched, a nil value for the route and false bool value is returned.
func (p *Proxy) router(r *http.Request) (http.Handler, bool) {
config, ok := p.routeConfigs[r.Host]
if ok {
return config.mux, true
}
return nil, false
}
// policy attempts to find a policy for a request. If a policy is successfully matched,
// it returns the policy information and a bool value of `true`. If a policy can not be matched,
// a nil value for the policy and false bool value is returned.
func (p *Proxy) policy(r *http.Request) (*config.Policy, bool) {
config, ok := p.routeConfigs[r.Host]
if ok {
return &config.policy, true
}
return nil, false
http.Redirect(w, r, dashboardURL, http.StatusFound)
}

View file

@ -1,4 +1,4 @@
package proxy
package proxy // import "github.com/pomerium/pomerium/proxy"
import (
"bytes"
@ -13,7 +13,6 @@ import (
"github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/proxy/clients"
)
@ -76,153 +75,114 @@ func TestProxy_authenticate(t *testing.T) {
}
}
func TestProxy_Handler(t *testing.T) {
proxy, err := New(testOptions(t))
if err != nil {
t.Fatal(err)
}
h := proxy.Handler()
if h == nil {
t.Error("handler cannot be nil")
}
mux := http.NewServeMux()
mux.Handle("/", h)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
mux.ServeHTTP(rr, req)
if rr.Code != http.StatusNotFound {
t.Errorf("expected 404 route not found for empty route")
}
}
// func TestProxy_PomeriumHandler(t *testing.T) {
// proxy, err := New(testOptions(t))
// if err != nil {
// t.Fatal(err)
// }
// h := proxy.registerHelperHandlers()
// if h == nil {
// t.Error("handler cannot be nil")
// }
// mux := http.NewServeMux()
// mux.Handle("/", h)
// req := httptest.NewRequest(http.MethodGet, "/", nil)
// rr := httptest.NewRecorder()
// mux.ServeHTTP(rr, req)
// if rr.Code != http.StatusNotFound {
// t.Errorf("expected 404 route not found for empty route")
// }
// }
func TestProxy_router(t *testing.T) {
testPolicy := config.Policy{From: "https://corp.example.com", To: "https://example.com"}
if err := testPolicy.Validate(); err != nil {
t.Fatal(err)
}
policies := []config.Policy{testPolicy}
tests := []struct {
name string
host string
mux []config.Policy
route http.Handler
wantOk bool
}{
{"good corp", "https://corp.example.com", policies, nil, true},
{"good with slash", "https://corp.example.com/", policies, nil, true},
{"good with path", "https://corp.example.com/123", policies, nil, true},
{"no policies", "https://notcorp.example.com/123", []config.Policy{}, nil, false},
{"bad corp", "https://notcorp.example.com/123", policies, nil, false},
{"bad sub-sub", "https://notcorp.corp.example.com/123", policies, nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := testOptions(t)
opts.Policies = tt.mux
p, err := New(opts)
if err != nil {
t.Fatal(err)
}
p.encoder = &cryptutil.MockEncoder{MarshalResponse: "foo"}
// func TestProxy_Proxy(t *testing.T) {
// goodSession := &sessions.State{
// AccessToken: "AccessToken",
// RefreshToken: "RefreshToken",
// RefreshDeadline: time.Now().Add(20 * time.Second),
// }
req := httptest.NewRequest(http.MethodGet, tt.host, nil)
_, ok := p.router(req)
if ok != tt.wantOk {
t.Errorf("Proxy.router() ok = %v, want %v", ok, tt.wantOk)
}
})
}
}
// ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// w.Header().Set("Content-Type", "text/plain; charset=utf-8")
// w.Header().Set("X-Content-Type-Options", "nosniff")
// fmt.Fprintln(w, "RVSI FILIVS CAISAR")
// w.WriteHeader(http.StatusOK)
func TestProxy_Proxy(t *testing.T) {
goodSession := &sessions.State{
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
RefreshDeadline: time.Now().Add(20 * time.Second),
}
// }))
// defer ts.Close()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprintln(w, "RVSI FILIVS CAISAR")
w.WriteHeader(http.StatusOK)
// opts := testOptionsTestServer(t, ts.URL)
// optsCORS := testOptionsWithCORS(t, ts.URL)
// optsPublic := testOptionsWithPublicAccess(t, ts.URL)
// optsNoPolicies := testOptionsWithEmptyPolicies(t, ts.URL)
}))
defer ts.Close()
// defaultHeaders, goodCORSHeaders, badCORSHeaders, headersWs := http.Header{}, http.Header{}, http.Header{}, http.Header{}
// goodCORSHeaders.Set("origin", "anything")
// goodCORSHeaders.Set("access-control-request-method", "anything")
// // missing "Origin"
// badCORSHeaders.Set("access-control-request-method", "anything")
// headersWs.Set("Connection", "Upgrade")
// headersWs.Set("Upgrade", "websocket")
opts := testOptionsTestServer(t, ts.URL)
optsCORS := testOptionsWithCORS(t, ts.URL)
optsPublic := testOptionsWithPublicAccess(t, ts.URL)
optsNoPolicies := testOptionsWithEmptyPolicies(t, ts.URL)
// tests := []struct {
// name string
// options config.Options
// method string
// header http.Header
// host string
// session sessions.SessionStore
// authorizer clients.Authorizer
// wantStatus int
// }{
// {"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(20 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
// {"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
// {"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
// {"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
// // same request as above, but with cors_allow_preflight=false in the policy
// {"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
// // cors allowed, but the request is missing proper headers
// {"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
// // redirect to start auth process
// {"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
// {"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
// {"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError},
// // authenticate errors
// {"session expired,redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: sessions.ErrExpired}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
// {"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
// {"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound},
// {"no http found (no session),redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
// {"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
// }
defaultHeaders, goodCORSHeaders, badCORSHeaders, headersWs := http.Header{}, http.Header{}, http.Header{}, http.Header{}
goodCORSHeaders.Set("origin", "anything")
goodCORSHeaders.Set("access-control-request-method", "anything")
// missing "Origin"
badCORSHeaders.Set("access-control-request-method", "anything")
headersWs.Set("Connection", "Upgrade")
headersWs.Set("Upgrade", "websocket")
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// err := ValidateOptions(tt.options)
// if err != nil {
// t.Fatal(err)
// }
// p, err := New(tt.options)
// if err != nil {
// t.Fatal(err)
// }
// p.encoder = &cryptutil.MockEncoder{MarshalResponse: "foo"}
// p.sessionStore = tt.session
// p.AuthorizeClient = tt.authorizer
// r := httptest.NewRequest(tt.method, tt.host, nil)
// r.Header = tt.header
// r.Header.Set("Accept", "application/json")
// state, _ := tt.session.LoadSession(r)
// ctx := r.Context()
// ctx = sessions.NewContext(ctx, state, nil)
// r = r.WithContext(ctx)
tests := []struct {
name string
options config.Options
method string
header http.Header
host string
session sessions.SessionStore
authorizer clients.Authorizer
wantStatus int
}{
{"good", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(20 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
{"good email impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateEmail: "test@user.example"}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"good group impersonation", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: &sessions.State{AccessToken: "AccessToken", RefreshToken: "RefreshToken", RefreshDeadline: time.Now().Add(10 * time.Second), ImpersonateGroups: []string{"group1", "group2"}}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
// same request as above, but with cors_allow_preflight=false in the policy
{"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
// cors allowed, but the request is missing proper headers
{"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
// redirect to start auth process
{"unknown host", opts, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
{"user not authorized", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
{"authorization call failed", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeError: errors.New("error")}, http.StatusInternalServerError},
// authenticate errors
{"session expired,redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: sessions.ErrExpired}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
{"public access", optsPublic, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusOK},
{"public access, but unknown host", optsPublic, http.MethodGet, defaultHeaders, "https://nothttpbin.corp.example", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusNotFound},
{"no http found (no session),redirect to authn", opts, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{LoadError: http.ErrNoCookie}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
{"No policies", optsNoPolicies, http.MethodGet, defaultHeaders, "https://httpbin.corp.example/", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
}
// w := httptest.NewRecorder()
// p.Proxy(w, r)
// if status := w.Code; status != tt.wantStatus {
// t.Errorf("handler returned wrong status code: got %v want %v \n body %s", status, tt.wantStatus, w.Body.String())
// }
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateOptions(tt.options)
if err != nil {
t.Fatal(err)
}
p, err := New(tt.options)
if err != nil {
t.Fatal(err)
}
p.encoder = &cryptutil.MockEncoder{MarshalResponse: "foo"}
p.sessionStore = tt.session
p.AuthorizeClient = tt.authorizer
r := httptest.NewRequest(tt.method, tt.host, nil)
r.Header = tt.header
r.Header.Set("Accept", "application/json")
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, nil)
r = r.WithContext(ctx)
w := httptest.NewRecorder()
p.Proxy(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("handler returned wrong status code: got %v want %v \n body %s", status, tt.wantStatus, w.Body.String())
}
})
}
}
// })
// }
// }
func TestProxy_UserDashboard(t *testing.T) {
opts := testOptions(t)
@ -427,55 +387,9 @@ func TestProxy_SignOut(t *testing.T) {
}
}
func uriParseHelper(s string) *url.URL {
uri, _ := url.Parse(s)
uri, err := url.Parse(s)
if err != nil {
panic(err)
}
return uri
}
func TestProxy_VerifySession(t *testing.T) {
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprintln(w, "RVSI FILIVS CAISAR")
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
session sessions.SessionStore
ctxError error
provider identity.Authenticator
wantStatus int
}{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := Proxy{
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
sessionStore: tt.session,
}
r := httptest.NewRequest("GET", "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
got := a.VerifySession(fn)
got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
})
}
}

103
proxy/middleware.go Normal file
View file

@ -0,0 +1,103 @@
package proxy // import "github.com/pomerium/pomerium/proxy"
import (
"fmt"
"net/http"
"strings"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil"
)
const (
// HeaderJWT is the header key containing JWT signed user details.
HeaderJWT = "x-pomerium-jwt-assertion"
// HeaderUserID is the header key containing the user's id.
HeaderUserID = "x-pomerium-authenticated-user-id"
// HeaderEmail is the header key containing the user's email.
HeaderEmail = "x-pomerium-authenticated-user-email"
// HeaderGroups is the header key containing the user's groups.
HeaderGroups = "x-pomerium-authenticated-user-groups"
)
// AuthenticateSession is middleware to enforce a valid authentication
// session state is retrieved from the users's request context.
func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.AuthenticateSession")
defer span.End()
s, err := sessions.FromContext(r.Context())
if err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to session state error")
p.authenticate(w, r)
return
}
if err := s.Valid(); err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to invalid session")
p.authenticate(w, r)
return
}
r.Header.Set(HeaderUserID, s.User)
r.Header.Set(HeaderEmail, s.RequestEmail())
r.Header.Set(HeaderGroups, s.RequestGroups())
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// AuthorizeSession is middleware to enforce a user is authorized for a request
// session state is retrieved from the users's request context.
func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.AuthorizeSession")
defer span.End()
s, err := sessions.FromContext(r.Context())
if err != nil {
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err))
return
}
authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, s)
if err != nil {
httputil.ErrorResponse(w, r.WithContext(ctx), err)
return
} else if !authorized {
errMsg := fmt.Sprintf("%s is not authorized for this route", s.RequestEmail())
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error(errMsg, http.StatusForbidden, nil))
return
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// SignRequest is middleware that signs a JWT that contains a user's id,
// email, and group. Session state is retrieved from the users's request context
func (p *Proxy) SignRequest(signer cryptutil.JWTSigner) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.SignRequest")
defer span.End()
s, err := sessions.FromContext(r.Context())
if err != nil {
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err))
return
}
jwt, err := signer.SignJWT(s.User, s.Email, strings.Join(s.Groups, ","))
if err != nil {
log.Warn().Err(err).Msg("proxy: failed signing jwt")
} else {
r.Header.Set(HeaderJWT, jwt)
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// Authenticate begins the authenticate flow, encrypting the redirect url
// in a request to the provider's sign in endpoint.
func (p *Proxy) authenticate(w http.ResponseWriter, r *http.Request) {
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
http.Redirect(w, r, uri.String(), http.StatusFound)
}

175
proxy/middleware_test.go Normal file
View file

@ -0,0 +1,175 @@
package proxy
import (
"encoding/base64"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/proxy/clients"
)
func TestProxy_AuthenticateSession(t *testing.T) {
t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
session sessions.SessionStore
ctxError error
provider identity.Authenticator
wantStatus int
}{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusOK},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, nil, identity.MockProvider{}, http.StatusFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := Proxy{
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
sessionStore: tt.session,
}
r := httptest.NewRequest(http.MethodGet, "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
got := a.AuthenticateSession(fn)
got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("AuthenticateSession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
})
}
}
func TestProxy_AuthorizeSession(t *testing.T) {
t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
session sessions.SessionStore
authzClient clients.Authorizer
ctxError error
provider identity.Authenticator
wantStatus int
}{
{"user is authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK},
{"user is not authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusForbidden},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusForbidden},
{"authz client error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := Proxy{
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
sessionStore: tt.session,
AuthorizeClient: tt.authzClient,
}
r := httptest.NewRequest(http.MethodGet, "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
got := a.AuthorizeSession(fn)
got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("AuthorizeSession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
})
}
}
type mockJWTSigner struct {
SignError error
}
// Sign implements the JWTSigner interface from the cryptutil package, but just
// base64's the inputs instead for stesting.
func (s *mockJWTSigner) SignJWT(user, email, groups string) (string, error) {
return base64.StdEncoding.EncodeToString([]byte(fmt.Sprint(user, email, groups))), s.SignError
}
func TestProxy_SignRequest(t *testing.T) {
t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
session sessions.SessionStore
signerError error
ctxError error
wantStatus int
wantHeaders string
}{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "dGVzdA=="},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, errors.New("err"), http.StatusForbidden, ""},
{"signature failure, warn but ok", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, errors.New("err"), nil, http.StatusOK, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := Proxy{
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
sessionStore: tt.session,
}
r := httptest.NewRequest(http.MethodGet, "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
signer := &mockJWTSigner{SignError: tt.signerError}
got := a.SignRequest(signer)(fn)
got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
if headers := r.Header.Get(HeaderJWT); tt.wantHeaders != headers {
t.Errorf("SignRequest() headers = %v, want %v", headers, tt.wantHeaders)
}
})
}
}

View file

@ -5,15 +5,14 @@ import (
"encoding/base64"
"fmt"
"html/template"
stdlog "log"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/gorilla/mux"
"github.com/pomerium/pomerium/internal/config"
"github.com/pomerium/pomerium/internal/cryptutil"
pom_httputil "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions"
@ -25,15 +24,8 @@ import (
)
const (
// HeaderJWT is the header key containing JWT signed user details.
HeaderJWT = "x-pomerium-jwt-assertion"
// HeaderUserID is the header key containing the user's id.
HeaderUserID = "x-pomerium-authenticated-user-id"
// HeaderEmail is the header key containing the user's email.
HeaderEmail = "x-pomerium-authenticated-user-email"
// HeaderGroups is the header key containing the user's groups.
HeaderGroups = "x-pomerium-authenticated-user-groups"
// dashboardURL is the path to authenticate's sign in endpoint
dashboardURL = "/.pomerium/"
// signinURL is the path to authenticate's sign in endpoint
signinURL = "/.pomerium/sign_in"
// signoutURL is the path to authenticate's sign out endpoint
@ -78,38 +70,29 @@ type Proxy struct {
AuthorizeClient clients.Authorizer
// cipher cipher.AEAD
encoder cryptutil.SecureEncoder
cookieName string
cookieDomain string
cookieSecret []byte
defaultUpstreamTimeout time.Duration
refreshCooldown time.Duration
routeConfigs map[string]*routeConfig
Handler http.Handler
sessionStore sessions.SessionStore
signingKey string
templates *template.Template
}
type routeConfig struct {
mux http.Handler
policy config.Policy
}
// New takes a Proxy service from options and a validation function.
// Function returns an error if options fail to validate.
func New(opts config.Options) (*Proxy, error) {
if err := ValidateOptions(opts); err != nil {
return nil, err
}
decodedCookieSecret, err := base64.StdEncoding.DecodeString(opts.CookieSecret)
if err != nil {
return nil, err
}
cipher, err := cryptutil.NewAEADCipherFromBase64(opts.CookieSecret)
if err != nil {
return nil, err
}
// errors checked in ValidateOptions
decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret)
cipher, _ := cryptutil.NewAEADCipherFromBase64(opts.CookieSecret)
encoder := cryptutil.NewSecureJSONEncoder(cipher)
if opts.CookieDomain == "" {
@ -132,7 +115,6 @@ func New(opts config.Options) (*Proxy, error) {
p := &Proxy{
SharedKey: opts.SharedKey,
routeConfigs: make(map[string]*routeConfig),
encoder: encoder,
cookieSecret: decodedCookieSecret,
cookieDomain: opts.CookieDomain,
@ -143,18 +125,18 @@ func New(opts config.Options) (*Proxy, error) {
signingKey: opts.SigningKey,
templates: templates.New(),
}
// DeepCopy urls to avoid accidental mutation, err checked in validate func
// errors checked in ValidateOptions
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL})
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
if err := p.UpdatePolicies(&opts); err != nil {
return nil, err
}
metrics.AddPolicyCountCallback("proxy", func() int64 {
return int64(len(p.routeConfigs))
return int64(len(opts.Policies))
})
p.AuthorizeClient, err = clients.NewAuthorizeClient("grpc",
&clients.Options{
@ -169,116 +151,6 @@ func New(opts config.Options) (*Proxy, error) {
return p, err
}
// UpdatePolicies updates the handlers based on the configured policies
func (p *Proxy) UpdatePolicies(opts *config.Options) error {
routeConfigs := make(map[string]*routeConfig, len(opts.Policies))
if len(opts.Policies) == 0 {
log.Warn().Msg("proxy: configuration has no policies")
}
for _, policy := range opts.Policies {
if err := policy.Validate(); err != nil {
return fmt.Errorf("proxy: couldn't update policies %s", err)
}
proxy := NewReverseProxy(policy.Destination)
// build http transport (roundtripper) middleware chain
transport := http.DefaultTransport.(*http.Transport).Clone()
c := tripper.NewChain()
c = c.Append(metrics.HTTPMetricsRoundTripper("proxy", policy.Destination.Host))
var tlsClientConfig tls.Config
var isCustomClientConfig bool
if policy.TLSSkipVerify {
tlsClientConfig.InsecureSkipVerify = true
isCustomClientConfig = true
log.Warn().Str("policy", policy.String()).Msg("proxy: tls skip verify")
}
if policy.RootCAs != nil {
tlsClientConfig.RootCAs = policy.RootCAs
isCustomClientConfig = true
log.Debug().Str("policy", policy.String()).Msg("proxy: custom root ca")
}
if policy.ClientCertificate != nil {
tlsClientConfig.Certificates = []tls.Certificate{*policy.ClientCertificate}
isCustomClientConfig = true
log.Debug().Str("policy", policy.String()).Msg("proxy: client certs enabled")
}
if policy.TLSServerName != "" {
tlsClientConfig.ServerName = policy.TLSServerName
isCustomClientConfig = true
log.Debug().Str("policy", policy.String()).Msgf("proxy: tls hostname override to: %s", policy.TLSServerName)
}
// We avoid setting a custom client config unless we have to as
// if TLSClientConfig is nil, the default configuration is used.
if isCustomClientConfig {
transport.TLSClientConfig = &tlsClientConfig
}
proxy.Transport = c.Then(transport)
handler, err := p.newReverseProxyHandler(proxy, &policy)
if err != nil {
return err
}
routeConfigs[policy.Source.Host] = &routeConfig{
mux: handler,
policy: policy,
}
}
p.routeConfigs = routeConfigs
return nil
}
// NewReverseProxy returns a new ReverseProxy that routes URLs to the scheme, host, and
// base path provided in target. NewReverseProxy rewrites the Host header.
func NewReverseProxy(to *url.URL) *httputil.ReverseProxy {
proxy := httputil.NewSingleHostReverseProxy(to)
sublogger := log.With().Str("proxy", to.Host).Logger()
proxy.ErrorLog = stdlog.New(&log.StdLogWrapper{Logger: &sublogger}, "", 0)
director := proxy.Director
proxy.Director = func(req *http.Request) {
// Identifies the originating IP addresses of a client connecting to
// a web server through an HTTP proxy or a load balancer.
req.Header.Add("X-Forwarded-Host", req.Host)
director(req)
req.Host = to.Host
}
return proxy
}
// each route has a custom set of middleware applied to the reverse proxy
func (p *Proxy) newReverseProxyHandler(rp http.Handler, route *config.Policy) (http.Handler, error) {
r := pom_httputil.NewRouter()
r.SkipClean(true)
r.StrictSlash(true)
r.Use(middleware.StripPomeriumCookie(p.cookieName))
// if signing key is set, add signer to middleware
if len(p.signingKey) != 0 {
signer, err := cryptutil.NewES256Signer(p.signingKey, route.Source.Host)
if err != nil {
return nil, err
}
r.Use(middleware.SignRequest(signer, HeaderUserID, HeaderEmail, HeaderGroups, HeaderJWT))
}
// websockets cannot use the non-hijackable timeout-handler
if !route.AllowWebsockets {
timeout := p.defaultUpstreamTimeout
if route.UpstreamTimeout != 0 {
timeout = route.UpstreamTimeout
}
timeoutMsg := fmt.Sprintf("%s timed out in %s", route.Destination.Host, timeout)
rp = http.TimeoutHandler(rp, timeout, timeoutMsg)
}
// todo(bdd) : fix cors
// if route.CORSAllowPreflight {
// r.Use(cors.Default().Handler)
// }
r.Host(route.Destination.Host)
r.PathPrefix("/").Handler(rp)
return r, nil
}
// UpdateOptions updates internal structures based on config.Options
func (p *Proxy) UpdateOptions(o config.Options) error {
if p == nil {
@ -287,3 +159,123 @@ func (p *Proxy) UpdateOptions(o config.Options) error {
log.Info().Msg("proxy: updating options")
return p.UpdatePolicies(&o)
}
// UpdatePolicies updates the H basedon the configured policies
func (p *Proxy) UpdatePolicies(opts *config.Options) error {
var err error
if len(opts.Policies) == 0 {
log.Warn().Msg("proxy: configuration has no policies")
}
r := httputil.NewRouter()
r.SkipClean(true)
r.StrictSlash(true)
r.HandleFunc("/robots.txt", p.RobotsTxt).Methods(http.MethodGet)
r = p.registerHelperHandlers(r)
for _, policy := range opts.Policies {
if err := policy.Validate(); err != nil {
return fmt.Errorf("proxy: invalid policy %s", err)
}
r, err = p.reverseProxyHandler(r, &policy)
if err != nil {
return err
}
}
p.Handler = r
return nil
}
func (p *Proxy) reverseProxyHandler(r *mux.Router, policy *config.Policy) (*mux.Router, error) {
// 1. Create the reverse proxy connection
proxy := httputil.NewReverseProxy(policy.Destination)
// 2. Override any custom transport settings (e.g. TLS settings, etc)
proxy.Transport = p.roundTripperFromPolicy(policy)
// 3. Create a sub-router for a given route's hostname (`httpbin.corp.example.com`)
rp := r.Host(policy.Source.Host).Subrouter()
rp.PathPrefix("/").Handler(proxy)
// Optional: If websockets are enabled, do not set a handler request timeout
// websockets cannot use the non-hijackable timeout-handler
if !policy.AllowWebsockets {
timeout := p.defaultUpstreamTimeout
if policy.UpstreamTimeout != 0 {
timeout = policy.UpstreamTimeout
}
timeoutMsg := fmt.Sprintf("%s timed out in %s", policy.Destination.Host, timeout)
rp.Use(middleware.TimeoutHandlerFunc(timeout, timeoutMsg))
}
// Optional: a cors preflight check, skip access control middleware
if policy.CORSAllowPreflight {
log.Warn().Str("route", policy.String()).Msg("proxy: cors preflight enabled")
rp.Use(middleware.CorsBypass(proxy))
}
// Optional: if a public route, skip access control middleware
if policy.AllowPublicUnauthenticatedAccess {
log.Warn().Str("route", policy.String()).Msg("proxy: all access control disabled")
return r, nil
}
// 4. Retrieve the user session and add it to the request context
rp.Use(sessions.RetrieveSession(p.sessionStore))
// 5. Strip the user session cookie from the downstream request
rp.Use(middleware.StripCookie(p.cookieName))
// 6. AuthN - Verify the user is authenticated. Set email, group, & id headers
rp.Use(p.AuthenticateSession)
// 7. AuthZ - Verify the user is authorized for route
rp.Use(p.AuthorizeSession)
// Optional: Add a signed JWT attesting to the user's id, email, and group
if len(p.signingKey) != 0 {
signer, err := cryptutil.NewES256Signer(p.signingKey, policy.Source.Host)
if err != nil {
return nil, err
}
rp.Use(p.SignRequest(signer))
}
return r, nil
}
// roundTripperFromPolicy adjusts the std library's `DefaultTransport RoundTripper`
// for a given route. A route's `RoundTripper` establishes network connections
// as needed and caches them for reuse by subsequent calls.
func (p *Proxy) roundTripperFromPolicy(policy *config.Policy) http.RoundTripper {
transport := http.DefaultTransport.(*http.Transport).Clone()
c := tripper.NewChain()
c = c.Append(metrics.HTTPMetricsRoundTripper("proxy", policy.Destination.Host))
var tlsClientConfig tls.Config
var isCustomClientConfig bool
if policy.TLSSkipVerify {
tlsClientConfig.InsecureSkipVerify = true
isCustomClientConfig = true
log.Warn().Str("policy", policy.String()).Msg("proxy: tls skip verify")
}
if policy.RootCAs != nil {
tlsClientConfig.RootCAs = policy.RootCAs
isCustomClientConfig = true
log.Debug().Str("policy", policy.String()).Msg("proxy: custom root ca")
}
if policy.ClientCertificate != nil {
tlsClientConfig.Certificates = []tls.Certificate{*policy.ClientCertificate}
isCustomClientConfig = true
log.Debug().Str("policy", policy.String()).Msg("proxy: client certs enabled")
}
if policy.TLSServerName != "" {
tlsClientConfig.ServerName = policy.TLSServerName
isCustomClientConfig = true
log.Debug().Str("policy", policy.String()).Msgf("proxy: tls override hostname: %s", policy.TLSServerName)
}
// We avoid setting a custom client config unless we have to as
// if TLSClientConfig is nil, the default configuration is used.
if isCustomClientConfig {
transport.TLSClientConfig = &tlsClientConfig
}
return c.Then(transport)
}

View file

@ -1,8 +1,6 @@
package proxy // import "github.com/pomerium/pomerium/proxy"
import (
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
@ -21,72 +19,6 @@ func newTestOptions(t *testing.T) *config.Options {
return opts
}
func TestNewReverseProxy(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
hostname, _, _ := net.SplitHostPort(r.Host)
w.Write([]byte(hostname))
}))
defer backend.Close()
backendURL, _ := url.Parse(backend.URL)
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
backendHost := net.JoinHostPort(backendHostname, backendPort)
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
proxyHandler := NewReverseProxy(proxyURL)
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
res, _ := http.DefaultClient.Do(getReq)
bodyBytes, _ := ioutil.ReadAll(res.Body)
if g, e := string(bodyBytes), backendHostname; g != e {
t.Errorf("got body %q; expected %q", g, e)
}
}
func TestNewReverseProxyHandler(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
hostname, _, _ := net.SplitHostPort(r.Host)
w.Write([]byte(hostname))
}))
defer backend.Close()
backendURL, _ := url.Parse(backend.URL)
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
backendHost := net.JoinHostPort(backendHostname, backendPort)
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
proxyHandler := NewReverseProxy(proxyURL)
opts := newTestOptions(t)
opts.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSU0zbXBaSVdYQ1g5eUVneFU2czU3Q2J0YlVOREJTQ0VBdFFGNWZVV0hwY1FvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFaFBRditMQUNQVk5tQlRLMHhTVHpicEVQa1JyazFlVXQxQk9hMzJTRWZVUHpOaTRJV2VaLwpLS0lUdDJxMUlxcFYyS01TYlZEeXI5aWp2L1hoOThpeUV3PT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
testPolicy := config.Policy{From: "https://corp.example.com", To: "https://example.com", UpstreamTimeout: 1 * time.Second}
if err := testPolicy.Validate(); err != nil {
t.Fatal(err)
}
p, err := New(*opts)
if err != nil {
t.Fatal(err)
}
handle, err := p.newReverseProxyHandler(proxyHandler, &testPolicy)
if err != nil {
t.Fatal(err)
}
frontend := httptest.NewServer(handle)
defer frontend.Close()
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
res, _ := http.DefaultClient.Do(getReq)
bodyBytes, _ := ioutil.ReadAll(res.Body)
if g, e := string(bodyBytes), backendHostname; g != e {
t.Errorf("got body %q; expected %q", g, e)
}
}
func testOptions(t *testing.T) config.Options {
authenticateService, _ := url.Parse("https://authenticate.corp.beyondperimeter.com")
authorizeService, _ := url.Parse("https://authorize.corp.beyondperimeter.com")
@ -106,61 +38,9 @@ func testOptions(t *testing.T) config.Options {
return *opts
}
func testOptionsTestServer(t *testing.T, uri string) config.Options {
authenticateService, _ := url.Parse("https://authenticate.corp.beyondperimeter.com")
authorizeService, _ := url.Parse("https://authorize.corp.beyondperimeter.com")
testPolicy := config.Policy{
From: "https://httpbin.corp.example",
To: uri,
}
if err := testPolicy.Validate(); err != nil {
t.Fatal(err)
}
opts := newTestOptions(t)
opts.Policies = []config.Policy{testPolicy}
opts.AuthenticateURL = authenticateService
opts.AuthorizeURL = authorizeService
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
opts.CookieName = "pomerium"
return *opts
}
func testOptionsWithCORS(t *testing.T, uri string) config.Options {
testPolicy := config.Policy{
From: "https://httpbin.corp.example",
To: uri,
CORSAllowPreflight: true,
}
if err := testPolicy.Validate(); err != nil {
t.Fatal(err)
}
opts := testOptionsTestServer(t, uri)
opts.Policies = []config.Policy{testPolicy}
return opts
}
func testOptionsWithPublicAccess(t *testing.T, uri string) config.Options {
testPolicy := config.Policy{
From: "https://httpbin.corp.example",
To: uri,
AllowPublicUnauthenticatedAccess: true,
}
if err := testPolicy.Validate(); err != nil {
t.Fatal(err)
}
opts := testOptions(t)
opts.Policies = []config.Policy{testPolicy}
return opts
}
func testOptionsWithEmptyPolicies(t *testing.T, uri string) config.Options {
opts := testOptionsTestServer(t, uri)
opts.Policies = []config.Policy{}
return opts
}
func TestOptions_Validate(t *testing.T) {
t.Parallel()
good := testOptions(t)
badAuthURL := testOptions(t)
badAuthURL.AuthenticateURL = nil
@ -215,22 +95,33 @@ func TestOptions_Validate(t *testing.T) {
}
func TestNew(t *testing.T) {
t.Parallel()
good := testOptions(t)
shortCookieLength := testOptions(t)
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="
badRoutedProxy := testOptions(t)
badRoutedProxy.SigningKey = "YmFkIGtleQo="
badCookie := testOptions(t)
badCookie.CookieName = ""
badPolicyURL := config.Policy{To: "http://", From: "http://bar.example"}
badNewPolicy := testOptions(t)
badNewPolicy.Policies = []config.Policy{
badPolicyURL,
}
tests := []struct {
name string
opts config.Options
wantProxy bool
numRoutes int
wantErr bool
}{
{"good", good, true, 1, false},
{"empty options", config.Options{}, false, 0, true},
{"short secret/validate sanity check", shortCookieLength, false, 0, true},
{"invalid ec key, valid base64 though", badRoutedProxy, false, 0, true},
{"good", good, true, false},
{"empty options", config.Options{}, false, true},
{"short secret/validate sanity check", shortCookieLength, false, true},
{"invalid ec key, valid base64 though", badRoutedProxy, false, true},
{"invalid cookie name, empty", badCookie, false, true},
{"bad policy, bad policy url", badNewPolicy, false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -242,21 +133,17 @@ func TestNew(t *testing.T) {
if got == nil && tt.wantProxy == true {
t.Errorf("New() expected valid proxy struct")
}
if got != nil && len(got.routeConfigs) != tt.numRoutes {
t.Errorf("New() = num routeConfigs \n%+v, want \n%+v \nfrom %+v", got, tt.numRoutes, tt.opts)
}
})
}
}
func Test_UpdateOptions(t *testing.T) {
t.Parallel()
good := testOptions(t)
newPolicy := config.Policy{To: "http://foo.example", From: "http://bar.example"}
newPolicies := testOptions(t)
newPolicies.Policies = []config.Policy{
newPolicy,
}
newPolicies.Policies = []config.Policy{newPolicy}
err := newPolicy.Validate()
if err != nil {
t.Fatal(err)
@ -268,26 +155,33 @@ func Test_UpdateOptions(t *testing.T) {
}
disableTLSPolicy := config.Policy{To: "http://foo.example", From: "http://bar.example", TLSSkipVerify: true}
disableTLSPolicies := testOptions(t)
disableTLSPolicies.Policies = []config.Policy{
disableTLSPolicy,
}
disableTLSPolicies.Policies = []config.Policy{disableTLSPolicy}
customCAPolicies := testOptions(t)
customCAPolicies.Policies = []config.Policy{
{To: "http://foo.example", From: "http://bar.example", TLSCustomCA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURlVENDQW1HZ0F3SUJBZ0lKQUszMmhoR0JIcmFtTUEwR0NTcUdTSWIzRFFFQkN3VUFNR0l4Q3pBSkJnTlYKQkFZVEFsVlRNUk13RVFZRFZRUUlEQXBEWVd4cFptOXlibWxoTVJZd0ZBWURWUVFIREExVFlXNGdSbkpoYm1OcApjMk52TVE4d0RRWURWUVFLREFaQ1lXUlRVMHd4RlRBVEJnTlZCQU1NRENvdVltRmtjM05zTG1OdmJUQWVGdzB4Ck9UQTJNVEl4TlRNeE5UbGFGdzB5TVRBMk1URXhOVE14TlRsYU1HSXhDekFKQmdOVkJBWVRBbFZUTVJNd0VRWUQKVlFRSURBcERZV3hwWm05eWJtbGhNUll3RkFZRFZRUUhEQTFUWVc0Z1JuSmhibU5wYzJOdk1ROHdEUVlEVlFRSwpEQVpDWVdSVFUwd3hGVEFUQmdOVkJBTU1EQ291WW1Ga2MzTnNMbU52YlRDQ0FTSXdEUVlKS29aSWh2Y05BUUVCCkJRQURnZ0VQQURDQ0FRb0NnZ0VCQU1JRTdQaU03Z1RDczloUTFYQll6Sk1ZNjF5b2FFbXdJclg1bFo2eEt5eDIKUG16QVMyQk1UT3F5dE1BUGdMYXcrWExKaGdMNVhFRmRFeXQvY2NSTHZPbVVMbEEzcG1jY1lZejJRVUxGUnRNVwpoeWVmZE9zS25SRlNKaUZ6YklSTWVWWGswV3ZvQmoxSUZWS3RzeWpicXY5dS8yQ1ZTbmRyT2ZFazBURzIzVTNBCnhQeFR1VzFDcmJWOC9xNzFGZEl6U09jaWNjZkNGSHBzS09vM1N0L3FiTFZ5dEg1YW9oYmNhYkZYUk5zS0VxdmUKd3c5SGRGeEJJdUdhK1J1VDVxMGlCaWt1c2JwSkhBd25ucVA3aS9kQWNnQ3NrZ2paakZlRVU0RUZ5K2IrYTFTWQpRQ2VGeHhDN2MzRHZhUmhCQjBWVmZQbGtQejBzdzZsODY1TWFUSWJSeW9VQ0F3RUFBYU15TURBd0NRWURWUjBUCkJBSXdBREFqQmdOVkhSRUVIREFhZ2d3cUxtSmhaSE56YkM1amIyMkNDbUpoWkhOemJDNWpiMjB3RFFZSktvWkkKaHZjTkFRRUxCUUFEZ2dFQkFJaTV1OXc4bWdUNnBwQ2M3eHNHK0E5ZkkzVzR6K3FTS2FwaHI1bHM3MEdCS2JpWQpZTEpVWVpoUGZXcGgxcXRra1UwTEhGUG04M1ZhNTJlSUhyalhUMFZlNEt0TzFuMElBZkl0RmFXNjJDSmdoR1luCmp6dzByeXpnQzRQeUZwTk1uTnRCcm9QdS9iUGdXaU1nTE9OcEVaaGlneDRROHdmMVkvVTlzK3pDQ3hvSmxhS1IKTVhidVE4N1g3bS85VlJueHhvNk56NVpmN09USFRwTk9JNlZqYTBCeGJtSUFVNnlyaXc5VXJnaWJYZk9qM2o2bgpNVExCdWdVVklCMGJCYWFzSnNBTUsrdzRMQU52YXBlWjBET1NuT1I0S0syNEowT3lvRjVmSG1wNTllTTE3SW9GClFxQmh6cG1RVWd1bmVjRVc4QlRxck5wRzc5UjF1K1YrNHd3Y2tQYz0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="},
}
customCAPolicies.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", TLSCustomCA: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSURlVENDQW1HZ0F3SUJBZ0lKQUszMmhoR0JIcmFtTUEwR0NTcUdTSWIzRFFFQkN3VUFNR0l4Q3pBSkJnTlYKQkFZVEFsVlRNUk13RVFZRFZRUUlEQXBEWVd4cFptOXlibWxoTVJZd0ZBWURWUVFIREExVFlXNGdSbkpoYm1OcApjMk52TVE4d0RRWURWUVFLREFaQ1lXUlRVMHd4RlRBVEJnTlZCQU1NRENvdVltRmtjM05zTG1OdmJUQWVGdzB4Ck9UQTJNVEl4TlRNeE5UbGFGdzB5TVRBMk1URXhOVE14TlRsYU1HSXhDekFKQmdOVkJBWVRBbFZUTVJNd0VRWUQKVlFRSURBcERZV3hwWm05eWJtbGhNUll3RkFZRFZRUUhEQTFUWVc0Z1JuSmhibU5wYzJOdk1ROHdEUVlEVlFRSwpEQVpDWVdSVFUwd3hGVEFUQmdOVkJBTU1EQ291WW1Ga2MzTnNMbU52YlRDQ0FTSXdEUVlKS29aSWh2Y05BUUVCCkJRQURnZ0VQQURDQ0FRb0NnZ0VCQU1JRTdQaU03Z1RDczloUTFYQll6Sk1ZNjF5b2FFbXdJclg1bFo2eEt5eDIKUG16QVMyQk1UT3F5dE1BUGdMYXcrWExKaGdMNVhFRmRFeXQvY2NSTHZPbVVMbEEzcG1jY1lZejJRVUxGUnRNVwpoeWVmZE9zS25SRlNKaUZ6YklSTWVWWGswV3ZvQmoxSUZWS3RzeWpicXY5dS8yQ1ZTbmRyT2ZFazBURzIzVTNBCnhQeFR1VzFDcmJWOC9xNzFGZEl6U09jaWNjZkNGSHBzS09vM1N0L3FiTFZ5dEg1YW9oYmNhYkZYUk5zS0VxdmUKd3c5SGRGeEJJdUdhK1J1VDVxMGlCaWt1c2JwSkhBd25ucVA3aS9kQWNnQ3NrZ2paakZlRVU0RUZ5K2IrYTFTWQpRQ2VGeHhDN2MzRHZhUmhCQjBWVmZQbGtQejBzdzZsODY1TWFUSWJSeW9VQ0F3RUFBYU15TURBd0NRWURWUjBUCkJBSXdBREFqQmdOVkhSRUVIREFhZ2d3cUxtSmhaSE56YkM1amIyMkNDbUpoWkhOemJDNWpiMjB3RFFZSktvWkkKaHZjTkFRRUxCUUFEZ2dFQkFJaTV1OXc4bWdUNnBwQ2M3eHNHK0E5ZkkzVzR6K3FTS2FwaHI1bHM3MEdCS2JpWQpZTEpVWVpoUGZXcGgxcXRra1UwTEhGUG04M1ZhNTJlSUhyalhUMFZlNEt0TzFuMElBZkl0RmFXNjJDSmdoR1luCmp6dzByeXpnQzRQeUZwTk1uTnRCcm9QdS9iUGdXaU1nTE9OcEVaaGlneDRROHdmMVkvVTlzK3pDQ3hvSmxhS1IKTVhidVE4N1g3bS85VlJueHhvNk56NVpmN09USFRwTk9JNlZqYTBCeGJtSUFVNnlyaXc5VXJnaWJYZk9qM2o2bgpNVExCdWdVVklCMGJCYWFzSnNBTUsrdzRMQU52YXBlWjBET1NuT1I0S0syNEowT3lvRjVmSG1wNTllTTE3SW9GClFxQmh6cG1RVWd1bmVjRVc4QlRxck5wRzc5UjF1K1YrNHd3Y2tQYz0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="}}
badCustomCAPolicies := testOptions(t)
badCustomCAPolicies.Policies = []config.Policy{
{To: "http://foo.example", From: "http://bar.example", TLSCustomCA: "=@@"},
}
badCustomCAPolicies.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", TLSCustomCA: "=@@"}}
goodClientCertPolicies := testOptions(t)
goodClientCertPolicies.Policies = []config.Policy{
{To: "http://foo.example", From: "http://bar.example",
TLSClientKey: "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=", TLSClientCert: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="},
}
goodClientCertPolicies.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", TLSClientKey: "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=", TLSClientCert: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="}}
goodClientCertPolicies.Validate()
customServerName := testOptions(t)
customServerName.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", TLSServerName: "test"}}
emptyPolicies := testOptions(t)
emptyPolicies.Policies = nil
allowWebSockets := testOptions(t)
allowWebSockets.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", AllowWebsockets: true}}
customTimeout := testOptions(t)
customTimeout.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", UpstreamTimeout: 10 * time.Second}}
corsPreflight := testOptions(t)
corsPreflight.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", CORSAllowPreflight: true}}
disableAuth := testOptions(t)
disableAuth.Policies = []config.Policy{{To: "http://foo.example", From: "http://bar.example", AllowPublicUnauthenticatedAccess: true}}
tests := []struct {
name string
originalOptions config.Options
@ -301,12 +195,18 @@ func Test_UpdateOptions(t *testing.T) {
{"changed", good, newPolicies, "", "https://bar.example", false, true},
{"changed and missing", good, newPolicies, "", "https://corp.example.example", false, false},
{"bad signing key", good, newPolicies, "^bad base 64", "https://corp.example.example", true, false},
{"good signing key", good, newPolicies, "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSU0zbXBaSVdYQ1g5eUVneFU2czU3Q2J0YlVOREJTQ0VBdFFGNWZVV0hwY1FvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFaFBRditMQUNQVk5tQlRLMHhTVHpicEVQa1JyazFlVXQxQk9hMzJTRWZVUHpOaTRJV2VaLwpLS0lUdDJxMUlxcFYyS01TYlZEeXI5aWp2L1hoOThpeUV3PT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQ==", "https://corp.example.example", false, true},
{"bad change bad policy url", good, badNewPolicy, "", "https://bar.example", true, false},
{"disable tls verification", good, disableTLSPolicies, "", "https://bar.example", false, true},
{"custom root ca", good, customCAPolicies, "", "https://bar.example", false, true},
{"bad custom root ca base64", good, badCustomCAPolicies, "", "https://bar.example", true, false},
{"good client certs", good, goodClientCertPolicies, "", "https://bar.example", false, true},
{"custom server name", customServerName, customServerName, "", "https://bar.example", false, true},
{"good no policies to start", emptyPolicies, good, "", "https://corp.example.example", false, true},
{"allow websockets", good, allowWebSockets, "", "https://corp.example.example", false, true},
{"no websockets, custom timeout", good, customTimeout, "", "https://corp.example.example", false, true},
{"enable cors preflight", good, corsPreflight, "", "https://corp.example.example", false, true},
{"disable auth", good, disableAuth, "", "https://corp.example.example", false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -324,9 +224,10 @@ func Test_UpdateOptions(t *testing.T) {
// This is only safe if we actually can load policies
if err == nil {
req := httptest.NewRequest("GET", tt.host, nil)
_, ok := p.router(req)
if ok != tt.wantRoute {
r := httptest.NewRequest("GET", tt.host, nil)
w := httptest.NewRecorder()
p.Handler.ServeHTTP(w, r)
if tt.wantRoute && w.Code != http.StatusNotFound {
t.Errorf("Failed to find route handler")
return
}