diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 1e845fb4e..2c1f2e2ef 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -117,10 +117,10 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler } // ValidateHost ensures that each request's host is valid -func ValidateHost(mux map[string]http.Handler) func(next http.Handler) http.Handler { +func ValidateHost(validHost func(host string) bool) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if _, ok := mux[r.Host]; !ok { + if !validHost(r.Host) { httputil.ErrorResponse(w, r, "Unknown route", http.StatusNotFound) return } diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index 1ba0f841b..3f859722f 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -292,18 +292,22 @@ func handlerHelp(msg string) http.Handler { return &handlerHelper{msg} } func TestValidateHost(t *testing.T) { - m := make(map[string]http.Handler) - m["google.com"] = handlerHelp("google") + validHostFunc := func(host string) bool { + return host == "google.com" + } + + validHostHandler := handlerHelp("google") tests := []struct { - name string - validHosts map[string]http.Handler - clientPath string - expected []byte - status int + name string + isValidHost func(string) bool + validHostHandler http.Handler + clientPath string + expected []byte + status int }{ - {"good", m, "google.com", []byte("google"), 200}, - {"no route", m, "googles.com", []byte("google"), 404}, + {"good", validHostFunc, validHostHandler, "google.com", []byte("google"), 200}, + {"no route", validHostFunc, validHostHandler, "googles.com", []byte("google"), 404}, } for _, tt := range tests { @@ -315,13 +319,13 @@ func TestValidateHost(t *testing.T) { rr := httptest.NewRecorder() var testHandler http.Handler - if tt.validHosts[tt.clientPath] != nil { - tt.validHosts[tt.clientPath].ServeHTTP(rr, req) - testHandler = tt.validHosts[tt.clientPath] + if tt.isValidHost(tt.clientPath) { + tt.validHostHandler.ServeHTTP(rr, req) + testHandler = tt.validHostHandler } else { testHandler = handlerHelp("ok") } - handler := ValidateHost(tt.validHosts)(testHandler) + handler := ValidateHost(tt.isValidHost)(testHandler) handler.ServeHTTP(rr, req) if rr.Code != tt.status { diff --git a/internal/policy/policy.go b/internal/policy/policy.go index b5f16e5d1..52672d123 100644 --- a/internal/policy/policy.go +++ b/internal/policy/policy.go @@ -26,6 +26,10 @@ type Policy struct { Source *url.URL Destination *url.URL + + // Allow unauthenticated HTTP OPTIONS requests as per the CORS spec + // https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#Preflighted_requests + CORSAllowPreflight bool `yaml:"cors_allow_preflight"` } func (p *Policy) validate() (err error) { diff --git a/policy.example.yaml b/policy.example.yaml index 210e14fcc..ec1aa7e88 100644 --- a/policy.example.yaml +++ b/policy.example.yaml @@ -17,3 +17,8 @@ to: http://hello:8080 allowed_groups: - admins +- from: cross-origin.corp.beyondperimeter.com + to: httpbin.org + allowed_domains: + - gmail.com + cors_allow_preflight: true diff --git a/proxy/handlers.go b/proxy/handlers.go index d50f66e32..b3fc26512 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -14,6 +14,7 @@ import ( "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/middleware" + "github.com/pomerium/pomerium/internal/policy" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/version" ) @@ -68,7 +69,10 @@ func (p *Proxy) Handler() http.Handler { c = c.Append(middleware.UserAgentHandler("user_agent")) c = c.Append(middleware.RefererHandler("referer")) c = c.Append(middleware.RequestIDHandler("req_id", "Request-Id")) - c = c.Append(middleware.ValidateHost(p.mux)) + c = c.Append(middleware.ValidateHost(func(host string) bool { + _, ok := p.routeConfigs[host] + return ok + })) // serve the middleware and mux return c.Then(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -223,36 +227,62 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, stateParameter.RedirectURI, http.StatusFound) } +// shouldSkipAuthentication contains conditions for skipping authentication. +// Conditions should be few in number and have strong justifications. +func (p *Proxy) shouldSkipAuthentication(r *http.Request) bool { + pol, foundPolicy := p.policy(r) + + if isCORSPreflight(r) && foundPolicy && pol.CORSAllowPreflight { + log.FromRequest(r).Debug().Msg("proxy: skipping authentication for valid CORS preflight request") + return true + } + + return false +} + +// isCORSPreflight inspects the request to see if this is a valid CORS preflight request. +// These checks are not exhaustive, because the proxied server should be verifying it as well. +// +// See https://www.html5rocks.com/static/images/cors_server_flowchart.png for more info. +func isCORSPreflight(r *http.Request) bool { + return r.Method == http.MethodOptions && + r.Header.Get("Access-Control-Request-Method") != "" && + r.Header.Get("Origin") != "" +} + // Proxy authenticates a request, either proxying the request if it is authenticated, // or starting the authenticate service for validation if not. func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) { - err := p.Authenticate(w, r) - // If the authenticate is not successful we proceed to start the OAuth Flow with - // OAuthStart. If successful, we proceed to proxy to the configured upstream. - if err != nil { - switch err { - case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession: - log.FromRequest(r).Debug().Err(err).Msg("proxy: starting auth flow") - p.OAuthStart(w, r) + if !p.shouldSkipAuthentication(r) { + err := p.Authenticate(w, r) + // If the authenticate is not successful we proceed to start the OAuth Flow with + // OAuthStart. If successful, we proceed to proxy to the configured upstream. + if err != nil { + switch err { + case http.ErrNoCookie, sessions.ErrLifetimeExpired, sessions.ErrInvalidSession: + log.FromRequest(r).Debug().Err(err).Msg("proxy: starting auth flow") + p.OAuthStart(w, r) + return + default: + log.FromRequest(r).Error().Err(err).Msg("proxy: unexpected error") + httputil.ErrorResponse(w, r, "An unexpected error occurred", http.StatusInternalServerError) + return + } + } + // remove dupe session call + session, err := p.sessionStore.LoadSession(r) + if err != nil { + p.sessionStore.ClearSession(w, r) return - default: - log.FromRequest(r).Error().Err(err).Msg("proxy: unexpected error") - httputil.ErrorResponse(w, r, "An unexpected error occurred", http.StatusInternalServerError) + } + authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, session) + if !authorized || err != nil { + log.FromRequest(r).Warn().Err(err).Msg("proxy: user unauthorized") + httputil.ErrorResponse(w, r, "Access unauthorized", http.StatusForbidden) return } } - // remove dupe session call - session, err := p.sessionStore.LoadSession(r) - if err != nil { - p.sessionStore.ClearSession(w, r) - return - } - authorized, err := p.AuthorizeClient.Authorize(r.Context(), r.Host, session) - if !authorized || err != nil { - log.FromRequest(r).Warn().Err(err).Msg("proxy: user unauthorized") - httputil.ErrorResponse(w, r, "Access unauthorized", http.StatusForbidden) - return - } + // We have validated the users request and now proxy their request to the provided upstream. route, ok := p.router(r) if !ok { @@ -325,17 +355,31 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error) } // Handle constructs a route from the given host string and matches it to the provided http.Handler and UpstreamConfig -func (p *Proxy) Handle(host string, handler http.Handler) { - p.mux[host] = handler +func (p *Proxy) Handle(host string, handler http.Handler, pol *policy.Policy) { + p.routeConfigs[host] = &routeConfig{ + mux: handler, + policy: pol, + } } // router attempts to find a route for a request. If a route is successfully matched, // it returns the route information and a bool value of `true`. If a route can not be matched, // a nil value for the route and false bool value is returned. func (p *Proxy) router(r *http.Request) (http.Handler, bool) { - route, ok := p.mux[r.Host] + config, ok := p.routeConfigs[r.Host] if ok { - return route, true + return config.mux, true + } + return nil, false +} + +// policy attempts to find a policy for a request. If a policy is successfully matched, +// it returns the policy information and a bool value of `true`. If a policy can not be matched, +// a nil value for the policy and false bool value is returned. +func (p *Proxy) policy(r *http.Request) (*policy.Policy, bool) { + config, ok := p.routeConfigs[r.Host] + if ok { + return config.policy, true } return nil, false } diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index e4995da90..bc1be23fe 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -352,8 +352,22 @@ func TestProxy_Proxy(t *testing.T) { RefreshDeadline: time.Now().Add(10 * time.Second), } + opts := testOptions() + optsCORS := testOptionsWithCORS() + + defaultHeaders, goodCORSHeaders, badCORSHeaders := http.Header{}, http.Header{}, http.Header{} + + goodCORSHeaders.Set("origin", "anything") + goodCORSHeaders.Set("access-control-request-method", "anything") + + // missing "Origin" + badCORSHeaders.Set("access-control-request-method", "anything") + tests := []struct { name string + options *Options + method string + header http.Header host string session sessions.SessionStore authenticator clients.Authenticator @@ -361,16 +375,20 @@ func TestProxy_Proxy(t *testing.T) { wantStatus int }{ // weirdly, we want 503 here because that means proxy is trying to route a domain (example.com) that we dont control. Weird. I know. - {"good", "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusServiceUnavailable}, - {"unexpected error", "https://corp.example.com/test", &sessions.MockSessionStore{LoadError: errors.New("ok")}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusInternalServerError}, + {"good", opts, http.MethodGet, defaultHeaders, "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusServiceUnavailable}, + {"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusServiceUnavailable}, + // same request as above, but with cors_allow_preflight=false in the policy + {"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, + // cors allowed, but the request is missing proper headers + {"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, + {"unexpected error", opts, http.MethodGet, defaultHeaders, "https://corp.example.com/test", &sessions.MockSessionStore{LoadError: errors.New("ok")}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusInternalServerError}, // redirect to start auth process - {"unknown host", "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound}, - {"user forbidden", "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, + {"unknown host", opts, http.MethodGet, defaultHeaders, "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound}, + {"user forbidden", opts, http.MethodGet, defaultHeaders, "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - opts := testOptions() - p, err := New(opts) + p, err := New(tt.options) if err != nil { t.Fatal(err) } @@ -379,7 +397,8 @@ func TestProxy_Proxy(t *testing.T) { p.AuthenticateClient = tt.authenticator p.AuthorizeClient = tt.authorizer - r := httptest.NewRequest("GET", tt.host, nil) + r := httptest.NewRequest(tt.method, tt.host, nil) + r.Header = tt.header w := httptest.NewRecorder() p.Proxy(w, r) if status := w.Code; status != tt.wantStatus { diff --git a/proxy/proxy.go b/proxy/proxy.go index 8cbff7670..e7cbda2a0 100755 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -172,9 +172,14 @@ type Proxy struct { csrfStore sessions.CSRFStore sessionStore sessions.SessionStore - redirectURL *url.URL - templates *template.Template - mux map[string]http.Handler + redirectURL *url.URL + templates *template.Template + routeConfigs map[string]*routeConfig +} + +type routeConfig struct { + mux http.Handler + policy *policy.Policy } // New takes a Proxy service from options and a validation function. @@ -208,7 +213,7 @@ func New(opts *Options) (*Proxy, error) { } p := &Proxy{ - mux: make(map[string]http.Handler), + routeConfigs: make(map[string]*routeConfig), // services AuthenticateURL: opts.AuthenticateURL, // session state @@ -232,7 +237,7 @@ func New(opts *Options) (*Proxy, error) { if err != nil { return nil, err } - p.Handle(route.Source.Host, handler) + p.Handle(route.Source.Host, handler, &route) log.Info().Str("src", route.Source.Host).Str("dst", route.Destination.Host).Msg("proxy: new route") } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index bb3638eee..9cd511300 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -127,6 +127,13 @@ func testOptions() *Options { } } +func testOptionsWithCORS() *Options { + configBlob := `[{"from":"corp.example.com","to":"example.com","cors_allow_preflight":true}]` //valid yaml + opts := testOptions() + opts.Policy = base64.URLEncoding.EncodeToString([]byte(configBlob)) + return opts +} + func TestOptions_Validate(t *testing.T) { good := testOptions() badFromRoute := testOptions() @@ -204,7 +211,7 @@ func TestNew(t *testing.T) { opts *Options optFuncs []func(*Proxy) error wantProxy bool - numMuxes int + numRoutes int wantErr bool }{ {"good", good, nil, true, 1, false}, @@ -223,8 +230,8 @@ func TestNew(t *testing.T) { if got == nil && tt.wantProxy == true { t.Errorf("New() expected valid proxy struct") } - if got != nil && len(got.mux) != tt.numMuxes { - t.Errorf("New() = num muxes \n%+v, want \n%+v", got, tt.numMuxes) + if got != nil && len(got.routeConfigs) != tt.numRoutes { + t.Errorf("New() = num routeConfigs \n%+v, want \n%+v", got, tt.numRoutes) } }) }