add cors_allow_preflight option to route policy

This commit is contained in:
nitper 2019-04-08 15:46:14 -04:00
parent c18f7d89ae
commit 45bb2e0a4d
No known key found for this signature in database
GPG key ID: CD5CDA95578CC143
8 changed files with 146 additions and 58 deletions

View file

@ -117,10 +117,10 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler
}
// ValidateHost ensures that each request's host is valid
func ValidateHost(mux map[string]http.Handler) func(next http.Handler) http.Handler {
func ValidateHost(validHost func(host string) bool) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, ok := mux[r.Host]; !ok {
if !validHost(r.Host) {
httputil.ErrorResponse(w, r, "Unknown route", http.StatusNotFound)
return
}

View file

@ -292,18 +292,22 @@ func handlerHelp(msg string) http.Handler {
return &handlerHelper{msg}
}
func TestValidateHost(t *testing.T) {
m := make(map[string]http.Handler)
m["google.com"] = handlerHelp("google")
validHostFunc := func(host string) bool {
return host == "google.com"
}
validHostHandler := handlerHelp("google")
tests := []struct {
name string
validHosts map[string]http.Handler
isValidHost func(string) bool
validHostHandler http.Handler
clientPath string
expected []byte
status int
}{
{"good", m, "google.com", []byte("google"), 200},
{"no route", m, "googles.com", []byte("google"), 404},
{"good", validHostFunc, validHostHandler, "google.com", []byte("google"), 200},
{"no route", validHostFunc, validHostHandler, "googles.com", []byte("google"), 404},
}
for _, tt := range tests {
@ -315,13 +319,13 @@ func TestValidateHost(t *testing.T) {
rr := httptest.NewRecorder()
var testHandler http.Handler
if tt.validHosts[tt.clientPath] != nil {
tt.validHosts[tt.clientPath].ServeHTTP(rr, req)
testHandler = tt.validHosts[tt.clientPath]
if tt.isValidHost(tt.clientPath) {
tt.validHostHandler.ServeHTTP(rr, req)
testHandler = tt.validHostHandler
} else {
testHandler = handlerHelp("ok")
}
handler := ValidateHost(tt.validHosts)(testHandler)
handler := ValidateHost(tt.isValidHost)(testHandler)
handler.ServeHTTP(rr, req)
if rr.Code != tt.status {

View file

@ -26,6 +26,10 @@ type Policy struct {
Source *url.URL
Destination *url.URL
// Allow unauthenticated HTTP OPTIONS requests as per the CORS spec
// https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#Preflighted_requests
CORSAllowPreflight bool `yaml:"cors_allow_preflight"`
}
func (p *Policy) validate() (err error) {

View file

@ -17,3 +17,8 @@
to: http://hello:8080
allowed_groups:
- admins
- from: cross-origin.corp.beyondperimeter.com
to: httpbin.org
allowed_domains:
- gmail.com
cors_allow_preflight: true

View file

@ -14,6 +14,7 @@ import (
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/policy"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/version"
)
@ -68,7 +69,10 @@ func (p *Proxy) Handler() http.Handler {
c = c.Append(middleware.UserAgentHandler("user_agent"))
c = c.Append(middleware.RefererHandler("referer"))
c = c.Append(middleware.RequestIDHandler("req_id", "Request-Id"))
c = c.Append(middleware.ValidateHost(p.mux))
c = c.Append(middleware.ValidateHost(func(host string) bool {
_, ok := p.routeConfigs[host]
return ok
}))
// serve the middleware and mux
return c.Then(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -223,9 +227,33 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, stateParameter.RedirectURI, http.StatusFound)
}
// shouldSkipAuthentication contains conditions for skipping authentication.
// Conditions should be few in number and have strong justifications.
func (p *Proxy) shouldSkipAuthentication(r *http.Request) bool {
pol, foundPolicy := p.policy(r)
if isCORSPreflight(r) && foundPolicy && pol.CORSAllowPreflight {
log.FromRequest(r).Debug().Msg("proxy: skipping authentication for valid CORS preflight request")
return true
}
return false
}
// isCORSPreflight inspects the request to see if this is a valid CORS preflight request.
// These checks are not exhaustive, because the proxied server should be verifying it as well.
//
// See https://www.html5rocks.com/static/images/cors_server_flowchart.png for more info.
func isCORSPreflight(r *http.Request) bool {
return r.Method == http.MethodOptions &&
r.Header.Get("Access-Control-Request-Method") != "" &&
r.Header.Get("Origin") != ""
}
// Proxy authenticates a request, either proxying the request if it is authenticated,
// or starting the authenticate service for validation if not.
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
if !p.shouldSkipAuthentication(r) {
err := p.Authenticate(w, r)
// If the authenticate is not successful we proceed to start the OAuth Flow with
// OAuthStart. If successful, we proceed to proxy to the configured upstream.
@ -253,6 +281,8 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, "Access unauthorized", http.StatusForbidden)
return
}
}
// We have validated the users request and now proxy their request to the provided upstream.
route, ok := p.router(r)
if !ok {
@ -325,17 +355,31 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
}
// Handle constructs a route from the given host string and matches it to the provided http.Handler and UpstreamConfig
func (p *Proxy) Handle(host string, handler http.Handler) {
p.mux[host] = handler
func (p *Proxy) Handle(host string, handler http.Handler, pol *policy.Policy) {
p.routeConfigs[host] = &routeConfig{
mux: handler,
policy: pol,
}
}
// router attempts to find a route for a request. If a route is successfully matched,
// it returns the route information and a bool value of `true`. If a route can not be matched,
// a nil value for the route and false bool value is returned.
func (p *Proxy) router(r *http.Request) (http.Handler, bool) {
route, ok := p.mux[r.Host]
config, ok := p.routeConfigs[r.Host]
if ok {
return route, true
return config.mux, true
}
return nil, false
}
// policy attempts to find a policy for a request. If a policy is successfully matched,
// it returns the policy information and a bool value of `true`. If a policy can not be matched,
// a nil value for the policy and false bool value is returned.
func (p *Proxy) policy(r *http.Request) (*policy.Policy, bool) {
config, ok := p.routeConfigs[r.Host]
if ok {
return config.policy, true
}
return nil, false
}

View file

@ -352,8 +352,22 @@ func TestProxy_Proxy(t *testing.T) {
RefreshDeadline: time.Now().Add(10 * time.Second),
}
opts := testOptions()
optsCORS := testOptionsWithCORS()
defaultHeaders, goodCORSHeaders, badCORSHeaders := http.Header{}, http.Header{}, http.Header{}
goodCORSHeaders.Set("origin", "anything")
goodCORSHeaders.Set("access-control-request-method", "anything")
// missing "Origin"
badCORSHeaders.Set("access-control-request-method", "anything")
tests := []struct {
name string
options *Options
method string
header http.Header
host string
session sessions.SessionStore
authenticator clients.Authenticator
@ -361,16 +375,20 @@ func TestProxy_Proxy(t *testing.T) {
wantStatus int
}{
// weirdly, we want 503 here because that means proxy is trying to route a domain (example.com) that we dont control. Weird. I know.
{"good", "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusServiceUnavailable},
{"unexpected error", "https://corp.example.com/test", &sessions.MockSessionStore{LoadError: errors.New("ok")}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusInternalServerError},
{"good", opts, http.MethodGet, defaultHeaders, "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusServiceUnavailable},
{"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusServiceUnavailable},
// same request as above, but with cors_allow_preflight=false in the policy
{"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
// cors allowed, but the request is missing proper headers
{"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
{"unexpected error", opts, http.MethodGet, defaultHeaders, "https://corp.example.com/test", &sessions.MockSessionStore{LoadError: errors.New("ok")}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusInternalServerError},
// redirect to start auth process
{"unknown host", "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
{"user forbidden", "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
{"unknown host", opts, http.MethodGet, defaultHeaders, "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
{"user forbidden", opts, http.MethodGet, defaultHeaders, "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := testOptions()
p, err := New(opts)
p, err := New(tt.options)
if err != nil {
t.Fatal(err)
}
@ -379,7 +397,8 @@ func TestProxy_Proxy(t *testing.T) {
p.AuthenticateClient = tt.authenticator
p.AuthorizeClient = tt.authorizer
r := httptest.NewRequest("GET", tt.host, nil)
r := httptest.NewRequest(tt.method, tt.host, nil)
r.Header = tt.header
w := httptest.NewRecorder()
p.Proxy(w, r)
if status := w.Code; status != tt.wantStatus {

View file

@ -174,7 +174,12 @@ type Proxy struct {
redirectURL *url.URL
templates *template.Template
mux map[string]http.Handler
routeConfigs map[string]*routeConfig
}
type routeConfig struct {
mux http.Handler
policy *policy.Policy
}
// New takes a Proxy service from options and a validation function.
@ -208,7 +213,7 @@ func New(opts *Options) (*Proxy, error) {
}
p := &Proxy{
mux: make(map[string]http.Handler),
routeConfigs: make(map[string]*routeConfig),
// services
AuthenticateURL: opts.AuthenticateURL,
// session state
@ -232,7 +237,7 @@ func New(opts *Options) (*Proxy, error) {
if err != nil {
return nil, err
}
p.Handle(route.Source.Host, handler)
p.Handle(route.Source.Host, handler, &route)
log.Info().Str("src", route.Source.Host).Str("dst", route.Destination.Host).Msg("proxy: new route")
}

View file

@ -127,6 +127,13 @@ func testOptions() *Options {
}
}
func testOptionsWithCORS() *Options {
configBlob := `[{"from":"corp.example.com","to":"example.com","cors_allow_preflight":true}]` //valid yaml
opts := testOptions()
opts.Policy = base64.URLEncoding.EncodeToString([]byte(configBlob))
return opts
}
func TestOptions_Validate(t *testing.T) {
good := testOptions()
badFromRoute := testOptions()
@ -204,7 +211,7 @@ func TestNew(t *testing.T) {
opts *Options
optFuncs []func(*Proxy) error
wantProxy bool
numMuxes int
numRoutes int
wantErr bool
}{
{"good", good, nil, true, 1, false},
@ -223,8 +230,8 @@ func TestNew(t *testing.T) {
if got == nil && tt.wantProxy == true {
t.Errorf("New() expected valid proxy struct")
}
if got != nil && len(got.mux) != tt.numMuxes {
t.Errorf("New() = num muxes \n%+v, want \n%+v", got, tt.numMuxes)
if got != nil && len(got.routeConfigs) != tt.numRoutes {
t.Errorf("New() = num routeConfigs \n%+v, want \n%+v", got, tt.numRoutes)
}
})
}