mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-04 01:09:36 +02:00
cmd/pomerium: move middleware for all http handlers to global context (#117)
This commit is contained in:
parent
04a653f694
commit
cfac5f10ff
9 changed files with 110 additions and 173 deletions
|
@ -30,7 +30,7 @@ func main() {
|
|||
fmt.Println(version.FullVersion())
|
||||
os.Exit(0)
|
||||
}
|
||||
opt, err := parseOptions()
|
||||
opt, err := optionsFromEnvConfig()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("cmd/pomerium: options")
|
||||
}
|
||||
|
@ -40,10 +40,6 @@ func main() {
|
|||
grpcServer := grpc.NewServer(grpcOpts...)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ping", func(rw http.ResponseWriter, _ *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(rw, version.UserAgent())
|
||||
})
|
||||
|
||||
_, err = newAuthenticateService(opt.Services, mux, grpcServer)
|
||||
if err != nil {
|
||||
|
@ -80,7 +76,8 @@ func main() {
|
|||
} else {
|
||||
defer srv.Close()
|
||||
}
|
||||
if err := https.ListenAndServeTLS(httpOpts, mux, grpcServer); err != nil {
|
||||
|
||||
if err := https.ListenAndServeTLS(httpOpts, wrapMiddleware(opt, mux), grpcServer); err != nil {
|
||||
log.Fatal().Err(err).Msg("cmd/pomerium: https server")
|
||||
}
|
||||
}
|
||||
|
@ -154,16 +151,31 @@ func newProxyService(s string, mux *http.ServeMux) (*proxy.Proxy, error) {
|
|||
return service, nil
|
||||
}
|
||||
|
||||
func parseOptions() (*Options, error) {
|
||||
o, err := optionsFromEnvConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func wrapMiddleware(o *Options, mux *http.ServeMux) http.Handler {
|
||||
c := middleware.NewChain()
|
||||
c = c.Append(middleware.NewHandler(log.Logger))
|
||||
c = c.Append(middleware.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
||||
middleware.FromRequest(r).Debug().
|
||||
Str("service", o.Services).
|
||||
Str("method", r.Method).
|
||||
Str("url", r.URL.String()).
|
||||
Int("status", status).
|
||||
Int("size", size).
|
||||
Dur("duration", duration).
|
||||
Str("user", r.Header.Get(proxy.HeaderUserID)).
|
||||
Str("email", r.Header.Get(proxy.HeaderEmail)).
|
||||
Str("group", r.Header.Get(proxy.HeaderGroups)).
|
||||
// Str("sig", r.Header.Get(proxy.HeaderJWT)).
|
||||
Msg("http-request")
|
||||
}))
|
||||
if o != nil && len(o.Headers) != 0 {
|
||||
c = c.Append(middleware.SetHeaders(o.Headers))
|
||||
}
|
||||
if o.Debug {
|
||||
log.SetDebugMode()
|
||||
}
|
||||
if o.LogLevel != "" {
|
||||
log.SetLevel(o.LogLevel)
|
||||
}
|
||||
return o, nil
|
||||
c = c.Append(middleware.ForwardedAddrHandler("fwd_ip"))
|
||||
c = c.Append(middleware.RemoteAddrHandler("ip"))
|
||||
c = c.Append(middleware.UserAgentHandler("user_agent"))
|
||||
c = c.Append(middleware.RefererHandler("referer"))
|
||||
c = c.Append(middleware.RequestIDHandler("req_id", "Request-Id"))
|
||||
c = c.Append(middleware.Healthcheck("/ping", version.UserAgent()))
|
||||
return c.Then(mux)
|
||||
}
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
@ -166,34 +167,33 @@ func Test_newProxyeService(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_parseOptions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envKey string
|
||||
envValue string
|
||||
func Test_wrapMiddleware(t *testing.T) {
|
||||
o := &Options{
|
||||
Services: "all",
|
||||
Headers: map[string]string{
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Frame-Options": "SAMEORIGIN",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
||||
"Content-Security-Policy": "default-src 'none'; style-src 'self' 'sha256-pSTVzZsFAqd2U3QYu+BoBDtuJWaPM/+qMy/dBRrhb5Y='; img-src 'self';",
|
||||
"Referrer-Policy": "Same-origin",
|
||||
}}
|
||||
mux := http.NewServeMux()
|
||||
req := httptest.NewRequest(http.MethodGet, "/404", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
io.WriteString(w, `OK`)
|
||||
})
|
||||
|
||||
want *Options
|
||||
wantErr bool
|
||||
}{
|
||||
{"no shared secret", "", "", nil, true},
|
||||
{"good", "SHARED_SECRET", "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", &Options{Services: "all", SharedKey: "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", LogLevel: "debug"}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
os.Setenv(tt.envKey, tt.envValue)
|
||||
defer os.Unsetenv(tt.envKey)
|
||||
mux.Handle("/404", h)
|
||||
out := wrapMiddleware(o, mux)
|
||||
out.ServeHTTP(rr, req)
|
||||
expected := fmt.Sprintf("OK")
|
||||
body := rr.Body.String()
|
||||
|
||||
got, err := parseOptions()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseOptions() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("parseOptions()\n")
|
||||
t.Errorf("got: %+v\n", got)
|
||||
t.Errorf("want: %+v\n", tt.want)
|
||||
|
||||
}
|
||||
})
|
||||
if body != expected {
|
||||
t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,8 +6,12 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pomerium/envconfig"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// DisableHeaderKey is the key used to check whether to disable setting header
|
||||
const DisableHeaderKey = "disable"
|
||||
|
||||
// Options are the global environmental flags used to set up pomerium's services.
|
||||
// If a base64 encoded certificate and key are not provided as environmental variables,
|
||||
// or if a file location is not provided, the server will attempt to find a matching keypair
|
||||
|
@ -45,6 +49,9 @@ type Options struct {
|
|||
// on port 80. If empty, no redirect server is started.
|
||||
HTTPRedirectAddr string `envconfig:"HTTP_REDIRECT_ADDR"`
|
||||
|
||||
// Headers to set on all proxied requests. Add a 'disable' key map to turn off.
|
||||
Headers map[string]string `envconfig:"HEADERS"`
|
||||
|
||||
// Timeout settings : https://github.com/pomerium/pomerium/issues/40
|
||||
ReadTimeout time.Duration `envconfig:"TIMEOUT_READ"`
|
||||
WriteTimeout time.Duration `envconfig:"TIMEOUT_WRITE"`
|
||||
|
@ -56,6 +63,14 @@ var defaultOptions = &Options{
|
|||
Debug: false,
|
||||
LogLevel: "debug",
|
||||
Services: "all",
|
||||
Headers: map[string]string{
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Frame-Options": "SAMEORIGIN",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
||||
"Content-Security-Policy": "default-src 'none'; style-src 'self' 'sha256-pSTVzZsFAqd2U3QYu+BoBDtuJWaPM/+qMy/dBRrhb5Y='; img-src 'self';",
|
||||
"Referrer-Policy": "Same-origin",
|
||||
},
|
||||
}
|
||||
|
||||
// optionsFromEnvConfig builds the main binary's configuration
|
||||
|
@ -71,6 +86,15 @@ func optionsFromEnvConfig() (*Options, error) {
|
|||
if o.SharedKey == "" {
|
||||
return nil, errors.New("shared-key cannot be empty")
|
||||
}
|
||||
if o.Debug {
|
||||
log.SetDebugMode()
|
||||
}
|
||||
if o.LogLevel != "" {
|
||||
log.SetLevel(o.LogLevel)
|
||||
}
|
||||
if _, disable := o.Headers[DisableHeaderKey]; disable {
|
||||
o.Headers = make(map[string]string)
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue