package httputil

import (
	"errors"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/google/go-cmp/cmp"
)

func TestHealthCheck(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name   string
		method string

		wantStatus int
	}{
		{"good - Get", http.MethodGet, http.StatusOK},
		{"good - Head", http.MethodHead, http.StatusOK},
		{"bad - Options", http.MethodOptions, http.StatusMethodNotAllowed},
		{"bad - Put", http.MethodPut, http.StatusMethodNotAllowed},
		{"bad - Post", http.MethodPost, http.StatusMethodNotAllowed},
		{"bad - route miss", http.MethodGet, http.StatusOK},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {

			r := httptest.NewRequest(tt.method, "/", nil)
			w := httptest.NewRecorder()

			HealthCheck(w, r)
			if w.Code != tt.wantStatus {
				t.Errorf("code differs. got %d want %d body: %s", w.Code, tt.wantStatus, w.Body.String())
			}
		})
	}
}

func TestRedirect(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name   string
		method string

		url  string
		code int

		wantStatus int
	}{
		{"good", http.MethodGet, "https://pomerium.io", http.StatusFound, http.StatusFound},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {

			r := httptest.NewRequest(tt.method, "/", nil)
			w := httptest.NewRecorder()

			Redirect(w, r, tt.url, tt.code)
			if w.Code != tt.wantStatus {
				t.Errorf("code differs. got %d want %d body: %s", w.Code, tt.wantStatus, w.Body.String())
			}
			if w.Result().Header.Get(HeaderPomeriumResponse) == "" {
				t.Errorf("pomerium header not found")
			}
		})
	}
}

func TestHandlerFunc_ServeHTTP(t *testing.T) {

	tests := []struct {
		name     string
		f        HandlerFunc
		wantBody string
	}{
		{"good http error", func(w http.ResponseWriter, r *http.Request) error { return NewError(404, errors.New("404")) }, "{\"Status\":404,\"Error\":\"Not Found: 404\"}\n"},
		{"good std error", func(w http.ResponseWriter, r *http.Request) error { return errors.New("404") }, "{\"Status\":500,\"Error\":\"Internal Server Error: 404\"}\n"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			r := httptest.NewRequest("GET", "/", nil)
			r.Header.Set("Accept", "application/json")
			w := httptest.NewRecorder()
			tt.f.ServeHTTP(w, r)
			if diff := cmp.Diff(tt.wantBody, w.Body.String()); diff != "" {
				t.Errorf("ErrorResponse status:\n %s", diff)
			}
		})
	}
}