mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 00:40:25 +02:00
httputil : wrap handlers for additional context (#413)
Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
487fc655d6
commit
b3d3159185
27 changed files with 495 additions and 463 deletions
|
@ -2,110 +2,91 @@ package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/frontend"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
// Error formats creates a HTTP error with code, user friendly (and safe) error
|
||||
// message. If nil or empty, HTTP status code defaults to 500 and message
|
||||
// defaults to the text of the status code.
|
||||
func Error(message string, code int, err error) error {
|
||||
if code == 0 {
|
||||
code = http.StatusInternalServerError
|
||||
}
|
||||
if message == "" {
|
||||
message = http.StatusText(code)
|
||||
}
|
||||
return &httpError{Message: message, Code: code, Err: err}
|
||||
}
|
||||
var errorTemplate = template.Must(frontend.NewTemplates())
|
||||
var fullVersion = version.FullVersion()
|
||||
|
||||
type httpError struct {
|
||||
// Message to present to the end user.
|
||||
Message string
|
||||
// HTTPError contains an HTTP status code and wrapped error.
|
||||
type HTTPError struct {
|
||||
// HTTP status codes as registered with IANA.
|
||||
Code int
|
||||
|
||||
Err error // the cause
|
||||
Status int
|
||||
// Err is the wrapped error
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *httpError) Error() string {
|
||||
s := fmt.Sprintf("%d %s: %s", e.Code, http.StatusText(e.Code), e.Message)
|
||||
if e.Err != nil {
|
||||
return s + ": " + e.Err.Error()
|
||||
}
|
||||
return s
|
||||
}
|
||||
func (e *httpError) Unwrap() error { return e.Err }
|
||||
|
||||
// Timeout reports whether this error represents a user debuggable error.
|
||||
func (e *httpError) Debugable() bool {
|
||||
return e.Code == http.StatusUnauthorized || e.Code == http.StatusForbidden
|
||||
// NewError returns an error that contains a HTTP status and error.
|
||||
func NewError(status int, err error) error {
|
||||
return &HTTPError{Status: status, Err: err}
|
||||
}
|
||||
|
||||
// ErrorResponse renders an error page given an error. If the error is a
|
||||
// http error from this package, a user friendly message is set, http status code,
|
||||
// the ability to debug are also set.
|
||||
func ErrorResponse(w http.ResponseWriter, r *http.Request, e error) {
|
||||
statusCode := http.StatusInternalServerError // default status code to return
|
||||
errorString := e.Error()
|
||||
var canDebug bool
|
||||
var requestID string
|
||||
var httpError *httpError
|
||||
// if this is an HTTPError, we can add some additional useful information
|
||||
if errors.As(e, &httpError) {
|
||||
canDebug = httpError.Debugable()
|
||||
statusCode = httpError.Code
|
||||
errorString = httpError.Message
|
||||
}
|
||||
// Error implements the `error` interface.
|
||||
func (e *HTTPError) Error() string {
|
||||
return http.StatusText(e.Status) + ": " + e.Err.Error()
|
||||
}
|
||||
|
||||
// Unwrap implements the `error` Unwrap interface.
|
||||
func (e *HTTPError) Unwrap() error { return e.Err }
|
||||
|
||||
// Debugable reports whether this error represents a user debuggable error.
|
||||
func (e *HTTPError) Debugable() bool {
|
||||
return e.Status == http.StatusUnauthorized || e.Status == http.StatusForbidden
|
||||
}
|
||||
|
||||
// RetryURL returns the requests intended destination, if any.
|
||||
func (e *HTTPError) RetryURL(r *http.Request) string {
|
||||
return r.FormValue(urlutil.QueryRedirectURI)
|
||||
}
|
||||
|
||||
type errResponse struct {
|
||||
Status int
|
||||
Error string
|
||||
|
||||
StatusText string `json:"-"`
|
||||
RequestID string `json:",omitempty"`
|
||||
CanDebug bool `json:"-"`
|
||||
RetryURL string `json:"-"`
|
||||
Version string `json:"-"`
|
||||
}
|
||||
|
||||
// ErrorResponse replies to the request with the specified error message and HTTP code.
|
||||
// It does not otherwise end the request; the caller should ensure no further
|
||||
// writes are done to w.
|
||||
func (e *HTTPError) ErrorResponse(w http.ResponseWriter, r *http.Request) {
|
||||
// indicate to clients that the error originates from Pomerium, not the app
|
||||
w.Header().Set(HeaderPomeriumResponse, "true")
|
||||
w.WriteHeader(e.Status)
|
||||
|
||||
log.FromRequest(r).Error().Err(e).Str("http-message", errorString).Int("http-code", statusCode).Msg("http-error")
|
||||
|
||||
log.FromRequest(r).Info().Err(e).Msg("httputil: ErrorResponse")
|
||||
var requestID string
|
||||
if id, ok := log.IDFromRequest(r); ok {
|
||||
requestID = id
|
||||
}
|
||||
response := errResponse{
|
||||
Status: e.Status,
|
||||
StatusText: http.StatusText(e.Status),
|
||||
Error: e.Error(),
|
||||
RequestID: requestID,
|
||||
CanDebug: e.Debugable(),
|
||||
RetryURL: e.RetryURL(r),
|
||||
Version: fullVersion,
|
||||
}
|
||||
|
||||
if r.Header.Get("Accept") == "application/json" {
|
||||
var response struct {
|
||||
Error string `json:"error"`
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(response)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
response.Error = errorString
|
||||
writeJSONResponse(w, statusCode, response)
|
||||
} else {
|
||||
w.WriteHeader(statusCode)
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
|
||||
t := struct {
|
||||
Code int
|
||||
Title string
|
||||
Message string
|
||||
RequestID string
|
||||
CanDebug bool
|
||||
}{
|
||||
Code: statusCode,
|
||||
Title: http.StatusText(statusCode),
|
||||
Message: errorString,
|
||||
RequestID: requestID,
|
||||
CanDebug: canDebug,
|
||||
}
|
||||
template.Must(frontend.NewTemplates()).ExecuteTemplate(w, "error.html", t)
|
||||
}
|
||||
}
|
||||
|
||||
// writeJSONResponse is a helper that sets the application/json header and writes a response.
|
||||
func writeJSONResponse(w http.ResponseWriter, code int, response interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
|
||||
err := json.NewEncoder(w).Encode(response)
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
w.Header().Set("Content-Type", "text/html; charset=UTF-8")
|
||||
errorTemplate.ExecuteTemplate(w, "error.html", response)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,68 +9,67 @@ import (
|
|||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestErrorResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rw http.ResponseWriter
|
||||
r *http.Request
|
||||
e *httpError
|
||||
}{
|
||||
{"good", httptest.NewRecorder(), &http.Request{Method: http.MethodGet}, &httpError{Code: http.StatusBadRequest, Message: "missing id token"}},
|
||||
{"good json", httptest.NewRecorder(), &http.Request{Method: http.MethodGet, Header: http.Header{"Accept": []string{"application/json"}}}, &httpError{Code: http.StatusBadRequest, Message: "missing id token"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ErrorResponse(tt.rw, tt.r, tt.e)
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestHTTPError_ErrorResponse(t *testing.T) {
|
||||
|
||||
func TestError_Error(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
Message string
|
||||
Code int
|
||||
InnerErr error
|
||||
want string
|
||||
}{
|
||||
{"good", "short and stout", http.StatusTeapot, nil, "418 I'm a teapot: short and stout"},
|
||||
{"nested error", "short and stout", http.StatusTeapot, errors.New("another error"), "418 I'm a teapot: short and stout: another error"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := httpError{
|
||||
Message: tt.Message,
|
||||
Code: tt.Code,
|
||||
Err: tt.InnerErr,
|
||||
}
|
||||
got := h.Error()
|
||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||
t.Errorf("Error.Error() = %s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_httpError_Error(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
message string
|
||||
code int
|
||||
err error
|
||||
want string
|
||||
Status int
|
||||
Err error
|
||||
reqType string
|
||||
|
||||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{"good", "foobar", 200, nil, "200 OK: foobar"},
|
||||
{"no code", "foobar", 0, nil, "500 Internal Server Error: foobar"},
|
||||
{"no message or code", "", 0, nil, "500 Internal Server Error: Internal Server Error"},
|
||||
{"404 json", http.StatusNotFound, errors.New("route not known"), "application/json", http.StatusNotFound, "{\"Status\":404,\"Error\":\"Not Found: route not known\"}\n"},
|
||||
{"404 html", http.StatusNotFound, errors.New("route not known"), "", http.StatusNotFound, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := Error(tt.message, tt.code, tt.err)
|
||||
if got := e.Error(); got != tt.want {
|
||||
t.Errorf("httpError.Error() = %v, want %v", got, tt.want)
|
||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := NewError(tt.Status, tt.Err)
|
||||
var e *HTTPError
|
||||
if errors.As(err, &e) {
|
||||
e.ErrorResponse(w, r)
|
||||
} else {
|
||||
http.Error(w, "coulnd't convert error type", http.StatusTeapot)
|
||||
}
|
||||
})
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Accept", tt.reqType)
|
||||
w := httptest.NewRecorder()
|
||||
fn(w, r)
|
||||
if diff := cmp.Diff(tt.wantStatus, w.Code); diff != "" {
|
||||
t.Errorf("ErrorResponse status:\n %s", diff)
|
||||
}
|
||||
if tt.reqType == "application/json" {
|
||||
if diff := cmp.Diff(tt.wantBody, w.Body.String()); diff != "" {
|
||||
t.Errorf("ErrorResponse status:\n %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status int
|
||||
err error
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", 404, errors.New("error"), true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := NewError(tt.status, tt.err)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewError() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if err != nil && !errors.Is(err, tt.err) {
|
||||
t.Errorf("NewError() unwrap fail = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
|
@ -14,7 +16,7 @@ func HealthCheck(w http.ResponseWriter, r *http.Request) {
|
|||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if r.Method == http.MethodGet {
|
||||
w.Write([]byte(http.StatusText(http.StatusOK)))
|
||||
fmt.Fprintln(w, http.StatusText(http.StatusOK))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -24,3 +26,22 @@ func Redirect(w http.ResponseWriter, r *http.Request, url string, code int) {
|
|||
w.Header().Set(HeaderPomeriumResponse, "true")
|
||||
http.Redirect(w, r, url, code)
|
||||
}
|
||||
|
||||
// The HandlerFunc type is an adapter to allow the use of
|
||||
// ordinary functions as HTTP handlers. If f is a function
|
||||
// with the appropriate signature, HandlerFunc(f) is a
|
||||
// Handler that calls f.
|
||||
//
|
||||
// adapted from std library to suppport error wrapping
|
||||
type HandlerFunc func(http.ResponseWriter, *http.Request) error
|
||||
|
||||
// ServeHTTP calls f(w, r) error.
|
||||
func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if err := f(w, r); err != nil {
|
||||
var e *HTTPError
|
||||
if !errors.As(err, &e) {
|
||||
e = &HTTPError{http.StatusInternalServerError, err}
|
||||
}
|
||||
e.ErrorResponse(w, r)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestHealthCheck(t *testing.T) {
|
||||
|
@ -66,3 +69,26 @@ func TestRedirect(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerFunc_ServeHTTP(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
f HandlerFunc
|
||||
wantBody string
|
||||
}{
|
||||
{"good http error", func(w http.ResponseWriter, r *http.Request) error { return NewError(404, errors.New("404")) }, "{\"Status\":404,\"Error\":\"Not Found: 404\"}\n"},
|
||||
{"good std error", func(w http.ResponseWriter, r *http.Request) error { return errors.New("404") }, "{\"Status\":500,\"Error\":\"Internal Server Error: 404\"}\n"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
tt.f.ServeHTTP(w, r)
|
||||
if diff := cmp.Diff(tt.wantBody, w.Body.String()); diff != "" {
|
||||
t.Errorf("ErrorResponse status:\n %s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,9 @@ func NewRouter() *mux.Router {
|
|||
|
||||
// CSRFFailureHandler sets a HTTP 403 Forbidden status and writes the
|
||||
// CSRF failure reason to the response.
|
||||
func CSRFFailureHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ErrorResponse(w, r, Error("CSRF Failure", http.StatusForbidden, csrf.FailureReason(r)))
|
||||
func CSRFFailureHandler(w http.ResponseWriter, r *http.Request) error {
|
||||
if err := csrf.FailureReason(r); err != nil {
|
||||
return NewError(http.StatusBadRequest, csrf.FailureReason(r))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,43 +1,12 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func TestCSRFFailureHandler(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
wantBody string
|
||||
wantStatus int
|
||||
}{
|
||||
{"basic csrf failure", "{\"error\":\"CSRF Failure\"}\n", http.StatusForbidden},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
CSRFFailureHandler(w, r)
|
||||
gotBody := w.Body.String()
|
||||
gotStatus := w.Result().StatusCode
|
||||
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRouter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue