mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-07 05:12:45 +02:00
authorize: allow CORS preflight requests (#672)
* proxy: implement preserve host header option * authorize: allow CORS preflight requests
This commit is contained in:
parent
d92ee8d2a0
commit
98d2f194a0
6 changed files with 42 additions and 85 deletions
|
@ -11,6 +11,16 @@ allow {
|
||||||
route_policies[route].AllowPublicUnauthenticatedAccess == true
|
route_policies[route].AllowPublicUnauthenticatedAccess == true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# allow cors preflight
|
||||||
|
allow {
|
||||||
|
route := first_allowed_route(input.url)
|
||||||
|
route_policies[route].CORSAllowPreflight == true
|
||||||
|
input.method == "OPTIONS"
|
||||||
|
count(object.get(input.headers, "Access-Control-Request-Method", [])) > 0
|
||||||
|
count(object.get(input.headers, "Origin", [])) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# allow by email
|
# allow by email
|
||||||
allow {
|
allow {
|
||||||
route := first_allowed_route(input.url)
|
route := first_allowed_route(input.url)
|
||||||
|
@ -62,7 +72,6 @@ allow {
|
||||||
token.valid
|
token.valid
|
||||||
count(deny)==0
|
count(deny)==0
|
||||||
}
|
}
|
||||||
|
|
||||||
# allow pomerium urls
|
# allow pomerium urls
|
||||||
allow {
|
allow {
|
||||||
contains(input.url, "/.pomerium/")
|
contains(input.url, "/.pomerium/")
|
||||||
|
|
|
@ -110,6 +110,36 @@ test_pomerium_denied {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test_cors_preflight_allowed {
|
||||||
|
allow with data.route_policies as [{
|
||||||
|
"source": "example.com",
|
||||||
|
"allowed_users": ["bob@example.com"],
|
||||||
|
"CORSAllowPreflight": true
|
||||||
|
}] with input as {
|
||||||
|
"url": "http://example.com/",
|
||||||
|
"host": "example.com",
|
||||||
|
"method": "OPTIONS",
|
||||||
|
"headers": {
|
||||||
|
"Origin": ["someorigin"],
|
||||||
|
"Access-Control-Request-Method": ["GET"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
test_cors_preflight_denied {
|
||||||
|
not allow with data.route_policies as [{
|
||||||
|
"source": "example.com",
|
||||||
|
"allowed_users": ["bob@example.com"]
|
||||||
|
}] with input as {
|
||||||
|
"url": "http://example.com/",
|
||||||
|
"host": "example.com",
|
||||||
|
"method": "OPTIONS",
|
||||||
|
"headers": {
|
||||||
|
"Origin": ["someorigin"],
|
||||||
|
"Access-Control-Request-Method": ["GET"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
test_parse_url {
|
test_parse_url {
|
||||||
url := parse_url("http://example.com/some/path?qs")
|
url := parse_url("http://example.com/some/path?qs")
|
||||||
url.scheme == "http"
|
url.scheme == "http"
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -55,6 +55,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe
|
||||||
// request
|
// request
|
||||||
evt = evt.Str("request-id", hattrs.GetId())
|
evt = evt.Str("request-id", hattrs.GetId())
|
||||||
evt = evt.Str("method", hattrs.GetMethod())
|
evt = evt.Str("method", hattrs.GetMethod())
|
||||||
|
evt = evt.Interface("headers", hdrs)
|
||||||
evt = evt.Str("path", hattrs.GetPath())
|
evt = evt.Str("path", hattrs.GetPath())
|
||||||
evt = evt.Str("host", hattrs.GetHost())
|
evt = evt.Str("host", hattrs.GetHost())
|
||||||
evt = evt.Str("query", hattrs.GetQuery())
|
evt = evt.Str("query", hattrs.GetQuery())
|
||||||
|
|
|
@ -1,26 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CorsBypass is middleware that takes a target handler as a paramater,
|
|
||||||
// if the request is determined to be a CORS preflight request, that handler
|
|
||||||
// is called instead of the normal handler chain.
|
|
||||||
func CorsBypass(target http.Handler) func(next http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx, span := trace.StartSpan(r.Context(), "middleware.CorsBypass")
|
|
||||||
defer span.End()
|
|
||||||
if r.Method == http.MethodOptions &&
|
|
||||||
r.Header.Get("Access-Control-Request-Method") != "" &&
|
|
||||||
r.Header.Get("Origin") != "" {
|
|
||||||
target.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,56 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func someOtherMiddleware(s string) func(next http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Some-Other-Middleware", s)
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func TestCorsBypass(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
method string
|
|
||||||
header http.Header
|
|
||||||
wantStatus int
|
|
||||||
wantHeader string
|
|
||||||
}{
|
|
||||||
{"good", http.MethodOptions, http.Header{"Access-Control-Request-Method": []string{"GET"}, "Origin": []string{"localhost"}}, 200, ""},
|
|
||||||
{"invalid cors - non options request", http.MethodGet, http.Header{"Access-Control-Request-Method": []string{"GET"}, "Origin": []string{"localhost"}}, 200, "BAD"},
|
|
||||||
{"invalid cors - Origin not set", http.MethodOptions, http.Header{"Access-Control-Request-Method": []string{"GET"}, "Origin": []string{""}}, 200, "BAD"},
|
|
||||||
{"invalid cors - Access-Control-Request-Method not set", http.MethodOptions, http.Header{"Access-Control-Request-Method": []string{""}, "Origin": []string{"*"}}, 200, "BAD"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r := &http.Request{
|
|
||||||
Method: tt.method,
|
|
||||||
Header: tt.header,
|
|
||||||
}
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
target := fn
|
|
||||||
got := CorsBypass(target)(someOtherMiddleware("BAD")(target))
|
|
||||||
got.ServeHTTP(w, r)
|
|
||||||
if status := w.Code; status != tt.wantStatus {
|
|
||||||
t.Errorf("TestCorsBypass() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
|
||||||
}
|
|
||||||
if header := w.Header().Get("Some-Other-Middleware"); tt.wantHeader != header {
|
|
||||||
t.Errorf("TestCorsBypass() header = %v, want %v", header, tt.wantHeader)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Add table
Add a link
Reference in a new issue