mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-24 14:37:12 +02:00
add cors_allow_preflight option to route policy
This commit is contained in:
parent
c18f7d89ae
commit
45bb2e0a4d
8 changed files with 146 additions and 58 deletions
|
@ -117,10 +117,10 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler
|
|||
}
|
||||
|
||||
// ValidateHost ensures that each request's host is valid
|
||||
func ValidateHost(mux map[string]http.Handler) func(next http.Handler) http.Handler {
|
||||
func ValidateHost(validHost func(host string) bool) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if _, ok := mux[r.Host]; !ok {
|
||||
if !validHost(r.Host) {
|
||||
httputil.ErrorResponse(w, r, "Unknown route", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -292,18 +292,22 @@ func handlerHelp(msg string) http.Handler {
|
|||
return &handlerHelper{msg}
|
||||
}
|
||||
func TestValidateHost(t *testing.T) {
|
||||
m := make(map[string]http.Handler)
|
||||
m["google.com"] = handlerHelp("google")
|
||||
validHostFunc := func(host string) bool {
|
||||
return host == "google.com"
|
||||
}
|
||||
|
||||
validHostHandler := handlerHelp("google")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
validHosts map[string]http.Handler
|
||||
isValidHost func(string) bool
|
||||
validHostHandler http.Handler
|
||||
clientPath string
|
||||
expected []byte
|
||||
status int
|
||||
}{
|
||||
{"good", m, "google.com", []byte("google"), 200},
|
||||
{"no route", m, "googles.com", []byte("google"), 404},
|
||||
{"good", validHostFunc, validHostHandler, "google.com", []byte("google"), 200},
|
||||
{"no route", validHostFunc, validHostHandler, "googles.com", []byte("google"), 404},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -315,13 +319,13 @@ func TestValidateHost(t *testing.T) {
|
|||
rr := httptest.NewRecorder()
|
||||
|
||||
var testHandler http.Handler
|
||||
if tt.validHosts[tt.clientPath] != nil {
|
||||
tt.validHosts[tt.clientPath].ServeHTTP(rr, req)
|
||||
testHandler = tt.validHosts[tt.clientPath]
|
||||
if tt.isValidHost(tt.clientPath) {
|
||||
tt.validHostHandler.ServeHTTP(rr, req)
|
||||
testHandler = tt.validHostHandler
|
||||
} else {
|
||||
testHandler = handlerHelp("ok")
|
||||
}
|
||||
handler := ValidateHost(tt.validHosts)(testHandler)
|
||||
handler := ValidateHost(tt.isValidHost)(testHandler)
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tt.status {
|
||||
|
|
|
@ -26,6 +26,10 @@ type Policy struct {
|
|||
|
||||
Source *url.URL
|
||||
Destination *url.URL
|
||||
|
||||
// Allow unauthenticated HTTP OPTIONS requests as per the CORS spec
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#Preflighted_requests
|
||||
CORSAllowPreflight bool `yaml:"cors_allow_preflight"`
|
||||
}
|
||||
|
||||
func (p *Policy) validate() (err error) {
|
||||
|
|
|
@ -17,3 +17,8 @@
|
|||
to: http://hello:8080
|
||||
allowed_groups:
|
||||
- admins
|
||||
- from: cross-origin.corp.beyondperimeter.com
|
||||
to: httpbin.org
|
||||
allowed_domains:
|
||||
- gmail.com
|
||||
cors_allow_preflight: true
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/policy"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
@ -68,7 +69,10 @@ func (p *Proxy) Handler() http.Handler {
|
|||
c = c.Append(middleware.UserAgentHandler("user_agent"))
|
||||
c = c.Append(middleware.RefererHandler("referer"))
|
||||
c = c.Append(middleware.RequestIDHandler("req_id", "Request-Id"))
|
||||
c = c.Append(middleware.ValidateHost(p.mux))
|
||||
c = c.Append(middleware.ValidateHost(func(host string) bool {
|
||||
_, ok := p.routeConfigs[host]
|
||||
return ok
|
||||
}))
|
||||
|
||||
// serve the middleware and mux
|
||||
return c.Then(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -223,9 +227,33 @@ func (p *Proxy) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
|||
http.Redirect(w, r, stateParameter.RedirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// shouldSkipAuthentication contains conditions for skipping authentication.
|
||||
// Conditions should be few in number and have strong justifications.
|
||||
func (p *Proxy) shouldSkipAuthentication(r *http.Request) bool {
|
||||
pol, foundPolicy := p.policy(r)
|
||||
|
||||
if isCORSPreflight(r) && foundPolicy && pol.CORSAllowPreflight {
|
||||
log.FromRequest(r).Debug().Msg("proxy: skipping authentication for valid CORS preflight request")
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isCORSPreflight inspects the request to see if this is a valid CORS preflight request.
|
||||
// These checks are not exhaustive, because the proxied server should be verifying it as well.
|
||||
//
|
||||
// See https://www.html5rocks.com/static/images/cors_server_flowchart.png for more info.
|
||||
func isCORSPreflight(r *http.Request) bool {
|
||||
return r.Method == http.MethodOptions &&
|
||||
r.Header.Get("Access-Control-Request-Method") != "" &&
|
||||
r.Header.Get("Origin") != ""
|
||||
}
|
||||
|
||||
// Proxy authenticates a request, either proxying the request if it is authenticated,
|
||||
// or starting the authenticate service for validation if not.
|
||||
func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
||||
if !p.shouldSkipAuthentication(r) {
|
||||
err := p.Authenticate(w, r)
|
||||
// If the authenticate is not successful we proceed to start the OAuth Flow with
|
||||
// OAuthStart. If successful, we proceed to proxy to the configured upstream.
|
||||
|
@ -253,6 +281,8 @@ func (p *Proxy) Proxy(w http.ResponseWriter, r *http.Request) {
|
|||
httputil.ErrorResponse(w, r, "Access unauthorized", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// We have validated the users request and now proxy their request to the provided upstream.
|
||||
route, ok := p.router(r)
|
||||
if !ok {
|
||||
|
@ -325,17 +355,31 @@ func (p *Proxy) Authenticate(w http.ResponseWriter, r *http.Request) (err error)
|
|||
}
|
||||
|
||||
// Handle constructs a route from the given host string and matches it to the provided http.Handler and UpstreamConfig
|
||||
func (p *Proxy) Handle(host string, handler http.Handler) {
|
||||
p.mux[host] = handler
|
||||
func (p *Proxy) Handle(host string, handler http.Handler, pol *policy.Policy) {
|
||||
p.routeConfigs[host] = &routeConfig{
|
||||
mux: handler,
|
||||
policy: pol,
|
||||
}
|
||||
}
|
||||
|
||||
// router attempts to find a route for a request. If a route is successfully matched,
|
||||
// it returns the route information and a bool value of `true`. If a route can not be matched,
|
||||
// a nil value for the route and false bool value is returned.
|
||||
func (p *Proxy) router(r *http.Request) (http.Handler, bool) {
|
||||
route, ok := p.mux[r.Host]
|
||||
config, ok := p.routeConfigs[r.Host]
|
||||
if ok {
|
||||
return route, true
|
||||
return config.mux, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// policy attempts to find a policy for a request. If a policy is successfully matched,
|
||||
// it returns the policy information and a bool value of `true`. If a policy can not be matched,
|
||||
// a nil value for the policy and false bool value is returned.
|
||||
func (p *Proxy) policy(r *http.Request) (*policy.Policy, bool) {
|
||||
config, ok := p.routeConfigs[r.Host]
|
||||
if ok {
|
||||
return config.policy, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
|
|
@ -352,8 +352,22 @@ func TestProxy_Proxy(t *testing.T) {
|
|||
RefreshDeadline: time.Now().Add(10 * time.Second),
|
||||
}
|
||||
|
||||
opts := testOptions()
|
||||
optsCORS := testOptionsWithCORS()
|
||||
|
||||
defaultHeaders, goodCORSHeaders, badCORSHeaders := http.Header{}, http.Header{}, http.Header{}
|
||||
|
||||
goodCORSHeaders.Set("origin", "anything")
|
||||
goodCORSHeaders.Set("access-control-request-method", "anything")
|
||||
|
||||
// missing "Origin"
|
||||
badCORSHeaders.Set("access-control-request-method", "anything")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
options *Options
|
||||
method string
|
||||
header http.Header
|
||||
host string
|
||||
session sessions.SessionStore
|
||||
authenticator clients.Authenticator
|
||||
|
@ -361,16 +375,20 @@ func TestProxy_Proxy(t *testing.T) {
|
|||
wantStatus int
|
||||
}{
|
||||
// weirdly, we want 503 here because that means proxy is trying to route a domain (example.com) that we dont control. Weird. I know.
|
||||
{"good", "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusServiceUnavailable},
|
||||
{"unexpected error", "https://corp.example.com/test", &sessions.MockSessionStore{LoadError: errors.New("ok")}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusInternalServerError},
|
||||
{"good", opts, http.MethodGet, defaultHeaders, "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusServiceUnavailable},
|
||||
{"good cors preflight", optsCORS, http.MethodOptions, goodCORSHeaders, "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusServiceUnavailable},
|
||||
// same request as above, but with cors_allow_preflight=false in the policy
|
||||
{"valid cors, but not allowed", opts, http.MethodOptions, goodCORSHeaders, "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
// cors allowed, but the request is missing proper headers
|
||||
{"invalid cors headers", optsCORS, http.MethodOptions, badCORSHeaders, "https://corp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
{"unexpected error", opts, http.MethodGet, defaultHeaders, "https://corp.example.com/test", &sessions.MockSessionStore{LoadError: errors.New("ok")}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusInternalServerError},
|
||||
// redirect to start auth process
|
||||
{"unknown host", "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
|
||||
{"user forbidden", "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
{"unknown host", opts, http.MethodGet, defaultHeaders, "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusNotFound},
|
||||
{"user forbidden", opts, http.MethodGet, defaultHeaders, "https://notcorp.example.com/test", &sessions.MockSessionStore{Session: goodSession}, clients.MockAuthenticate{}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
opts := testOptions()
|
||||
p, err := New(opts)
|
||||
p, err := New(tt.options)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -379,7 +397,8 @@ func TestProxy_Proxy(t *testing.T) {
|
|||
p.AuthenticateClient = tt.authenticator
|
||||
p.AuthorizeClient = tt.authorizer
|
||||
|
||||
r := httptest.NewRequest("GET", tt.host, nil)
|
||||
r := httptest.NewRequest(tt.method, tt.host, nil)
|
||||
r.Header = tt.header
|
||||
w := httptest.NewRecorder()
|
||||
p.Proxy(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
|
|
|
@ -174,7 +174,12 @@ type Proxy struct {
|
|||
|
||||
redirectURL *url.URL
|
||||
templates *template.Template
|
||||
mux map[string]http.Handler
|
||||
routeConfigs map[string]*routeConfig
|
||||
}
|
||||
|
||||
type routeConfig struct {
|
||||
mux http.Handler
|
||||
policy *policy.Policy
|
||||
}
|
||||
|
||||
// New takes a Proxy service from options and a validation function.
|
||||
|
@ -208,7 +213,7 @@ func New(opts *Options) (*Proxy, error) {
|
|||
}
|
||||
|
||||
p := &Proxy{
|
||||
mux: make(map[string]http.Handler),
|
||||
routeConfigs: make(map[string]*routeConfig),
|
||||
// services
|
||||
AuthenticateURL: opts.AuthenticateURL,
|
||||
// session state
|
||||
|
@ -232,7 +237,7 @@ func New(opts *Options) (*Proxy, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.Handle(route.Source.Host, handler)
|
||||
p.Handle(route.Source.Host, handler, &route)
|
||||
log.Info().Str("src", route.Source.Host).Str("dst", route.Destination.Host).Msg("proxy: new route")
|
||||
}
|
||||
|
||||
|
|
|
@ -127,6 +127,13 @@ func testOptions() *Options {
|
|||
}
|
||||
}
|
||||
|
||||
func testOptionsWithCORS() *Options {
|
||||
configBlob := `[{"from":"corp.example.com","to":"example.com","cors_allow_preflight":true}]` //valid yaml
|
||||
opts := testOptions()
|
||||
opts.Policy = base64.URLEncoding.EncodeToString([]byte(configBlob))
|
||||
return opts
|
||||
}
|
||||
|
||||
func TestOptions_Validate(t *testing.T) {
|
||||
good := testOptions()
|
||||
badFromRoute := testOptions()
|
||||
|
@ -204,7 +211,7 @@ func TestNew(t *testing.T) {
|
|||
opts *Options
|
||||
optFuncs []func(*Proxy) error
|
||||
wantProxy bool
|
||||
numMuxes int
|
||||
numRoutes int
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", good, nil, true, 1, false},
|
||||
|
@ -223,8 +230,8 @@ func TestNew(t *testing.T) {
|
|||
if got == nil && tt.wantProxy == true {
|
||||
t.Errorf("New() expected valid proxy struct")
|
||||
}
|
||||
if got != nil && len(got.mux) != tt.numMuxes {
|
||||
t.Errorf("New() = num muxes \n%+v, want \n%+v", got, tt.numMuxes)
|
||||
if got != nil && len(got.routeConfigs) != tt.numRoutes {
|
||||
t.Errorf("New() = num routeConfigs \n%+v, want \n%+v", got, tt.numRoutes)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue