diff --git a/docs/docs/CHANGELOG.md b/docs/docs/CHANGELOG.md index 43dac5921..f778e156a 100644 --- a/docs/docs/CHANGELOG.md +++ b/docs/docs/CHANGELOG.md @@ -19,6 +19,7 @@ ### Changed +- The healthcheck endpoints (`/ping`) now returns the http status `405` StatusMethodNotAllowed for non-`GET` requests. [GH-319](https://github.com/pomerium/pomerium/issues/319) - Authenticate service no longer uses gRPC. ### Removed diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index d9d46f2d4..ce3f22511 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -149,11 +149,17 @@ func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { ctx, span := trace.StartSpan(r.Context(), "middleware.Healthcheck") defer span.End() - - if r.Method == "GET" && strings.EqualFold(r.URL.Path, endpoint) { + if strings.EqualFold(r.URL.Path, endpoint) { + // https://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html + if r.Method != http.MethodGet && r.Method != http.MethodHead { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return + } w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(msg)) + if r.Method == http.MethodGet { + w.Write([]byte(msg)) + } return } next.ServeHTTP(w, r.WithContext(ctx)) diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index b2864d908..000f768ca 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -11,7 +11,7 @@ import ( ) func Test_SameDomain(t *testing.T) { - + t.Parallel() tests := []struct { name string uri string @@ -41,6 +41,7 @@ func Test_SameDomain(t *testing.T) { } func Test_ValidSignature(t *testing.T) { + t.Parallel() goodURL := "https://example.com/redirect" secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A=" now := fmt.Sprint(time.Now().Unix()) @@ -123,6 +124,7 @@ func TestSetHeaders(t *testing.T) { } func TestValidateRedirectURI(t *testing.T) { + t.Parallel() tests := []struct { name string rootDomain string @@ -166,6 +168,7 @@ func TestValidateRedirectURI(t *testing.T) { } func TestValidateClientSecret(t *testing.T) { + t.Parallel() tests := []struct { name string sharedSecret string @@ -202,6 +205,7 @@ func TestValidateClientSecret(t *testing.T) { } func TestValidateSignature(t *testing.T) { + t.Parallel() secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A=" now := fmt.Sprint(time.Now().Unix()) goodURL := "https://example.com/redirect" @@ -251,33 +255,36 @@ func TestValidateSignature(t *testing.T) { } func TestHealthCheck(t *testing.T) { + t.Parallel() + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hi")) + }) tests := []struct { name string method string clientPath string - expected []byte + serverPath string + + wantStatus int }{ - {"good", http.MethodGet, "/ping", []byte("OK")}, - //tood(bdd): miss? + {"good - Get", http.MethodGet, "/ping", "/ping", http.StatusOK}, + {"good - Head", http.MethodHead, "/ping", "/ping", http.StatusOK}, + {"bad - Options", http.MethodOptions, "/ping", "/ping", http.StatusMethodNotAllowed}, + {"bad - Put", http.MethodPut, "/ping", "/ping", http.StatusMethodNotAllowed}, + {"bad - Post", http.MethodPost, "/ping", "/ping", http.StatusMethodNotAllowed}, + {"bad - route miss", http.MethodGet, "/not-ping", "/ping", http.StatusOK}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req, err := http.NewRequest(http.MethodGet, tt.clientPath, nil) - if err != nil { - t.Fatal(err) - } + r := httptest.NewRequest(tt.method, tt.clientPath, nil) + w := httptest.NewRecorder() - testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("Hi")) - }) - rr := httptest.NewRecorder() - handler := Healthcheck(tt.clientPath, string(tt.expected))(testHandler) - handler.ServeHTTP(rr, req) - if rr.Body.String() != string(tt.expected) { - t.Errorf("body differs. got %ss want %ss", rr.Body, tt.expected) - t.Errorf("%s", rr.Body) + handler := Healthcheck(tt.serverPath, string("OK"))(testHandler) + handler.ServeHTTP(w, r) + if w.Code != tt.wantStatus { + t.Errorf("code differs. got %d want %d body: %s", w.Code, tt.wantStatus, w.Body.String()) } }) }