mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-12 00:27:35 +02:00
middleware: equalize lengths of input (#1934)
Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
e56fb38cb5
commit
9c7958b66f
4 changed files with 75 additions and 24 deletions
|
@ -6,6 +6,7 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry"
|
"github.com/pomerium/pomerium/internal/telemetry"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
)
|
)
|
||||||
|
@ -89,7 +90,7 @@ func (mgr *MetricsManager) updateServer(cfg *Config) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if username, password, ok := cfg.Options.GetMetricsBasicAuth(); ok {
|
if username, password, ok := cfg.Options.GetMetricsBasicAuth(); ok {
|
||||||
handler = httputil.RequireBasicAuth(handler, username, password)
|
handler = middleware.RequireBasicAuth(username, password)(handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
mgr.srv, err = httputil.NewServer(&httputil.ServerOptions{
|
mgr.srv, err = httputil.NewServer(&httputil.ServerOptions{
|
||||||
|
|
|
@ -2,7 +2,6 @@ package httputil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/subtle"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -65,25 +64,3 @@ func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
e.ErrorResponse(w, r)
|
e.ErrorResponse(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequireBasicAuth creates a new handler that requires basic auth from the client before
|
|
||||||
// calling the underlying handler.
|
|
||||||
func RequireBasicAuth(handler http.Handler, username, password string) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
|
|
||||||
|
|
||||||
u, p, ok := r.BasicAuth()
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if subtle.ConstantTimeCompare([]byte(u), []byte(username)) != 1 ||
|
|
||||||
subtle.ConstantTimeCompare([]byte(p), []byte(password)) != 1 {
|
|
||||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
handler.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/subtle"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -75,3 +77,31 @@ func TimeoutHandlerFunc(timeout time.Duration, timeoutError string) func(next ht
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RequireBasicAuth creates a new handler that requires basic auth from the client before
|
||||||
|
// calling the underlying handler.
|
||||||
|
func RequireBasicAuth(username, password string) func(next http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
u, p, ok := r.BasicAuth()
|
||||||
|
if !ok {
|
||||||
|
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
|
||||||
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
givenUser := sha256.Sum256([]byte(u))
|
||||||
|
givenPass := sha256.Sum256([]byte(p))
|
||||||
|
requiredUser := sha256.Sum256([]byte(username))
|
||||||
|
requiredPass := sha256.Sum256([]byte(password))
|
||||||
|
|
||||||
|
if subtle.ConstantTimeCompare(givenUser[:], requiredUser[:]) != 1 ||
|
||||||
|
subtle.ConstantTimeCompare(givenPass[:], requiredPass[:]) != 1 {
|
||||||
|
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
|
||||||
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -155,3 +155,46 @@ func TestValidateSignature(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRequireBasicAuth(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
givenUser string
|
||||||
|
givenPass string
|
||||||
|
wantUser string
|
||||||
|
wantPass string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{"good", "foo", "bar", "foo", "bar", 200},
|
||||||
|
{"bad pass", "foo", "bar", "foo", "buzz", 401},
|
||||||
|
{"bad user", "foo", "bar", "buzz", "bar", 401},
|
||||||
|
{"empty", "", "", "", "", 401}, // don't add auth
|
||||||
|
{"empty user", "", "bar", "", "bar", 200},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if tt.givenUser != "" || tt.givenPass != "" {
|
||||||
|
req.SetBasicAuth(tt.givenUser, tt.givenPass)
|
||||||
|
}
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler := RequireBasicAuth(tt.wantUser, tt.wantPass)(fn)
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
if status := rr.Code; status != tt.wantStatus {
|
||||||
|
t.Errorf("RequireBasicAuth() error = %v, wantErr %v\n%v", rr.Result().StatusCode, tt.wantStatus, rr.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue