middleware: equalize lengths of input (#1934)

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
bobby 2021-02-23 08:31:17 -08:00 committed by GitHub
parent e56fb38cb5
commit 9c7958b66f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 75 additions and 24 deletions

View file

@ -6,6 +6,7 @@ import (
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/telemetry" "github.com/pomerium/pomerium/internal/telemetry"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
) )
@ -89,7 +90,7 @@ func (mgr *MetricsManager) updateServer(cfg *Config) {
} }
if username, password, ok := cfg.Options.GetMetricsBasicAuth(); ok { if username, password, ok := cfg.Options.GetMetricsBasicAuth(); ok {
handler = httputil.RequireBasicAuth(handler, username, password) handler = middleware.RequireBasicAuth(username, password)(handler)
} }
mgr.srv, err = httputil.NewServer(&httputil.ServerOptions{ mgr.srv, err = httputil.NewServer(&httputil.ServerOptions{

View file

@ -2,7 +2,6 @@ package httputil
import ( import (
"bytes" "bytes"
"crypto/subtle"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -65,25 +64,3 @@ func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
e.ErrorResponse(w, r) e.ErrorResponse(w, r)
} }
} }
// RequireBasicAuth creates a new handler that requires basic auth from the client before
// calling the underlying handler.
func RequireBasicAuth(handler http.Handler, username, password string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
u, p, ok := r.BasicAuth()
if !ok {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
if subtle.ConstantTimeCompare([]byte(u), []byte(username)) != 1 ||
subtle.ConstantTimeCompare([]byte(p), []byte(password)) != 1 {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
handler.ServeHTTP(w, r)
})
}

View file

@ -1,6 +1,8 @@
package middleware package middleware
import ( import (
"crypto/sha256"
"crypto/subtle"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@ -75,3 +77,31 @@ func TimeoutHandlerFunc(timeout time.Duration, timeoutError string) func(next ht
}) })
} }
} }
// RequireBasicAuth creates a new handler that requires basic auth from the client before
// calling the underlying handler.
func RequireBasicAuth(username, password string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
u, p, ok := r.BasicAuth()
if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
givenUser := sha256.Sum256([]byte(u))
givenPass := sha256.Sum256([]byte(p))
requiredUser := sha256.Sum256([]byte(username))
requiredPass := sha256.Sum256([]byte(password))
if subtle.ConstantTimeCompare(givenUser[:], requiredUser[:]) != 1 ||
subtle.ConstantTimeCompare(givenPass[:], requiredPass[:]) != 1 {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
}

View file

@ -155,3 +155,46 @@ func TestValidateSignature(t *testing.T) {
}) })
} }
} }
func TestRequireBasicAuth(t *testing.T) {
t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
givenUser string
givenPass string
wantUser string
wantPass string
wantStatus int
}{
{"good", "foo", "bar", "foo", "bar", 200},
{"bad pass", "foo", "bar", "foo", "buzz", 401},
{"bad user", "foo", "bar", "buzz", "bar", 401},
{"empty", "", "", "", "", 401}, // don't add auth
{"empty user", "", "bar", "", "bar", 200},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
if tt.givenUser != "" || tt.givenPass != "" {
req.SetBasicAuth(tt.givenUser, tt.givenPass)
}
rr := httptest.NewRecorder()
handler := RequireBasicAuth(tt.wantUser, tt.wantPass)(fn)
handler.ServeHTTP(rr, req)
if status := rr.Code; status != tt.wantStatus {
t.Errorf("RequireBasicAuth() error = %v, wantErr %v\n%v", rr.Result().StatusCode, tt.wantStatus, rr.Body.String())
}
})
}
}