diff --git a/config/metrics.go b/config/metrics.go index 226a4e789..ab817a214 100644 --- a/config/metrics.go +++ b/config/metrics.go @@ -6,6 +6,7 @@ import ( "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/telemetry" "github.com/pomerium/pomerium/internal/telemetry/metrics" ) @@ -89,7 +90,7 @@ func (mgr *MetricsManager) updateServer(cfg *Config) { } if username, password, ok := cfg.Options.GetMetricsBasicAuth(); ok { - handler = httputil.RequireBasicAuth(handler, username, password) + handler = middleware.RequireBasicAuth(username, password)(handler) } mgr.srv, err = httputil.NewServer(&httputil.ServerOptions{ diff --git a/internal/httputil/handlers.go b/internal/httputil/handlers.go index 75c5ae5e9..bd86d23ea 100644 --- a/internal/httputil/handlers.go +++ b/internal/httputil/handlers.go @@ -2,7 +2,6 @@ package httputil import ( "bytes" - "crypto/subtle" "encoding/json" "errors" "fmt" @@ -65,25 +64,3 @@ func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) { e.ErrorResponse(w, r) } } - -// RequireBasicAuth creates a new handler that requires basic auth from the client before -// calling the underlying handler. -func RequireBasicAuth(handler http.Handler, username, password string) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - - u, p, ok := r.BasicAuth() - if !ok { - http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - return - } - - if subtle.ConstantTimeCompare([]byte(u), []byte(username)) != 1 || - subtle.ConstantTimeCompare([]byte(p), []byte(password)) != 1 { - http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - return - } - - handler.ServeHTTP(w, r) - }) -} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index ac4b5a92f..4222c0c7a 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -1,6 +1,8 @@ package middleware import ( + "crypto/sha256" + "crypto/subtle" "net/http" "strings" "time" @@ -75,3 +77,31 @@ func TimeoutHandlerFunc(timeout time.Duration, timeoutError string) func(next ht }) } } + +// RequireBasicAuth creates a new handler that requires basic auth from the client before +// calling the underlying handler. +func RequireBasicAuth(username, password string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + u, p, ok := r.BasicAuth() + if !ok { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + givenUser := sha256.Sum256([]byte(u)) + givenPass := sha256.Sum256([]byte(p)) + requiredUser := sha256.Sum256([]byte(username)) + requiredPass := sha256.Sum256([]byte(password)) + + if subtle.ConstantTimeCompare(givenUser[:], requiredUser[:]) != 1 || + subtle.ConstantTimeCompare(givenPass[:], requiredPass[:]) != 1 { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index cb0fc72d6..5eb309f48 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -155,3 +155,46 @@ func TestValidateSignature(t *testing.T) { }) } } + +func TestRequireBasicAuth(t *testing.T) { + t.Parallel() + + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, http.StatusText(http.StatusOK)) + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + givenUser string + givenPass string + wantUser string + wantPass string + wantStatus int + }{ + {"good", "foo", "bar", "foo", "bar", 200}, + {"bad pass", "foo", "bar", "foo", "buzz", 401}, + {"bad user", "foo", "bar", "buzz", "bar", 401}, + {"empty", "", "", "", "", 401}, // don't add auth + {"empty user", "", "bar", "", "bar", 200}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + if tt.givenUser != "" || tt.givenPass != "" { + req.SetBasicAuth(tt.givenUser, tt.givenPass) + } + + rr := httptest.NewRecorder() + handler := RequireBasicAuth(tt.wantUser, tt.wantPass)(fn) + handler.ServeHTTP(rr, req) + if status := rr.Code; status != tt.wantStatus { + t.Errorf("RequireBasicAuth() error = %v, wantErr %v\n%v", rr.Result().StatusCode, tt.wantStatus, rr.Body.String()) + } + }) + } +}