http refactor.

This commit is contained in:
Miroslav Šedivý 2021-09-17 00:24:33 +02:00
parent 4fa11e6a2a
commit 5a7cdd31fe
6 changed files with 300 additions and 138 deletions

View file

@ -1,86 +1,129 @@
package http
import (
"context"
"fmt"
"net/http"
"time"
"github.com/go-chi/chi/middleware"
"github.com/rs/zerolog/log"
"demodesk/neko/internal/http/auth"
"demodesk/neko/internal/types"
"demodesk/neko/internal/utils"
)
func Logger(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
req := map[string]interface{}{}
type logEntryKey int
// exclude healthcheck from logs
if r.RequestURI == "/api/health" {
next.ServeHTTP(w, r)
const logEntryKeyCtx logEntryKey = iota
func setLogEntry(r *http.Request, data logEntry) *http.Request {
ctx := context.WithValue(r.Context(), logEntryKeyCtx, data)
return r.WithContext(ctx)
}
func getLogEntry(r *http.Request) logEntry {
return r.Context().Value(logEntryKeyCtx).(logEntry)
}
func LoggerMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, setLogEntry(r, newLogEntry(w, r)))
})
}
type logEntry struct {
req struct {
time time.Time
id string
scheme string
proto string
method string
remote string
agent string
uri string
}
res struct {
time time.Time
code int
bytes int
}
err error
elapsed time.Duration
hasSession bool
session types.Session
}
func newLogEntry(w http.ResponseWriter, r *http.Request) logEntry {
e := logEntry{}
e.req.time = time.Now()
if reqID := middleware.GetReqID(r.Context()); reqID != "" {
e.req.id = reqID
}
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
e.req.scheme = scheme
e.req.proto = r.Proto
e.req.method = r.Method
e.req.remote = r.RemoteAddr
e.req.agent = r.UserAgent()
e.req.uri = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)
return e
}
func (e *logEntry) SetResponse(w http.ResponseWriter, r *http.Request) {
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
e.res.time = time.Now()
e.res.code = ww.Status()
e.res.bytes = ww.BytesWritten()
e.elapsed = e.res.time.Sub(e.req.time)
e.session, e.hasSession = auth.GetSession(r)
}
func (e *logEntry) SetError(err error) {
e.err = err
}
func (e *logEntry) Write() {
logger := log.With().
Str("module", "http").
Float64("elapsed", float64(e.elapsed.Nanoseconds())/1000000.0).
Interface("req", e.req).
Interface("res", e.res).
Logger()
if e.hasSession {
logger = logger.With().Str("session_id", e.session.ID()).Logger()
}
if e.err != nil {
httpErr, ok := e.err.(*utils.HTTPError)
if !ok {
logger.Err(e.err).Msgf("request failed (%d)", e.res.code)
return
}
if reqID := middleware.GetReqID(r.Context()); reqID != "" {
req["id"] = reqID
if httpErr.Message == "" {
httpErr.Message = http.StatusText(httpErr.Code)
}
scheme := "http"
if r.TLS != nil {
scheme = "https"
logger := logger.Error().Err(httpErr.InternalErr)
message := httpErr.Message
if httpErr.InternalMsg != "" {
message = httpErr.InternalMsg
}
req["scheme"] = scheme
req["proto"] = r.Proto
req["method"] = r.Method
req["remote"] = r.RemoteAddr
req["agent"] = r.UserAgent()
req["uri"] = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)
fields := map[string]interface{}{}
fields["req"] = req
entry := &entry{
fields: fields,
}
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
t1 := time.Now()
defer func() {
entry.Write(ww.Status(), ww.BytesWritten(), time.Since(t1))
}()
next.ServeHTTP(ww, r)
logger.Msgf("request failed (%d): %s", e.res.code, message)
return
}
return http.HandlerFunc(fn)
}
type entry struct {
fields map[string]interface{}
errors []map[string]interface{}
}
func (e *entry) Write(status, bytes int, elapsed time.Duration) {
res := map[string]interface{}{}
res["time"] = time.Now().UTC().Format(time.RFC1123)
res["status"] = status
res["bytes"] = bytes
res["elapsed"] = float64(elapsed.Nanoseconds()) / 1000000.0
e.fields["res"] = res
e.fields["module"] = "http"
if len(e.errors) > 0 {
e.fields["errors"] = e.errors
log.Error().Fields(e.fields).Msgf("request failed (%d)", status)
} else {
log.Debug().Fields(e.fields).Msgf("request complete (%d)", status)
}
}
func (e *entry) Panic(v interface{}, stack []byte) {
err := map[string]interface{}{}
err["message"] = fmt.Sprintf("%+v", v)
err["stack"] = string(stack)
e.errors = append(e.errors, err)
logger.Debug().Msgf("request complete (%d)", e.res.code)
}

View file

@ -5,8 +5,6 @@ import (
"net/http"
"os"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/go-chi/cors"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@ -18,16 +16,15 @@ import (
type HttpManagerCtx struct {
logger zerolog.Logger
config *config.Server
router *chi.Mux
router *RouterCtx
http *http.Server
}
func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, config *config.Server) *HttpManagerCtx {
logger := log.With().Str("module", "http").Logger()
router := chi.NewRouter()
router.Use(middleware.Recoverer) // Recover from panics without crashing server
router.Use(cors.Handler(cors.Options{
router := newRouter()
router.UseBypass(cors.Handler(cors.Options{
AllowOriginFunc: func(r *http.Request, origin string) bool {
return config.AllowOrigin(origin)
},
@ -37,32 +34,28 @@ func New(WebSocketManager types.WebSocketManager, ApiManager types.ApiManager, c
AllowCredentials: true,
MaxAge: 300, // Maximum value not ignored by any of major browsers
}))
router.Use(middleware.RequestID) // Create a request ID for each request
router.Use(Logger) // Log API request calls using custom logger function
router.Route("/api", ApiManager.Route)
router.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
WebSocketManager.Upgrade(w, r, func(r *http.Request) bool {
router.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) error {
return WebSocketManager.Upgrade(w, r, func(r *http.Request) bool {
return config.AllowOrigin(r.Header.Get("Origin"))
})
})
if config.Static != "" {
fs := http.FileServer(http.Dir(config.Static))
router.Get("/*", func(w http.ResponseWriter, r *http.Request) {
if _, err := os.Stat(config.Static + r.URL.Path); !os.IsNotExist(err) {
router.Get("/*", func(w http.ResponseWriter, r *http.Request) error {
_, err := os.Stat(config.Static + r.URL.Path)
if !os.IsNotExist(err) {
fs.ServeHTTP(w, r)
} else {
http.NotFound(w, r)
}
return err
})
}
router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}))
return &HttpManagerCtx{
logger: logger,
config: config,

112
internal/http/router.go Normal file
View file

@ -0,0 +1,112 @@
package http
import (
"demodesk/neko/internal/types"
"demodesk/neko/internal/utils"
"net/http"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
)
type RouterCtx struct {
chi chi.Router
}
func newRouter() *RouterCtx {
router := chi.NewRouter()
router.Use(middleware.Recoverer) // Recover from panics without crashing server
router.Use(middleware.RequestID) // Create a request ID for each request
router.Use(LoggerMiddleware)
return &RouterCtx{router}
}
func (r *RouterCtx) Group(fn func(types.Router)) {
r.chi.Group(func(c chi.Router) {
fn(&RouterCtx{c})
})
}
func (r *RouterCtx) Route(pattern string, fn func(types.Router)) {
r.chi.Route(pattern, func(c chi.Router) {
fn(&RouterCtx{c})
})
}
func (r *RouterCtx) Get(pattern string, fn types.RouterHandler) {
r.chi.Get(pattern, routeHandler(fn))
}
func (r *RouterCtx) Post(pattern string, fn types.RouterHandler) {
r.chi.Post(pattern, routeHandler(fn))
}
func (r *RouterCtx) Put(pattern string, fn types.RouterHandler) {
r.chi.Put(pattern, routeHandler(fn))
}
func (r *RouterCtx) Delete(pattern string, fn types.RouterHandler) {
r.chi.Delete(pattern, routeHandler(fn))
}
func (r *RouterCtx) With(fn types.MiddlewareHandler) types.Router {
c := r.chi.With(middlewareHandler(fn))
return &RouterCtx{c}
}
func (r *RouterCtx) WithBypass(fn func(next http.Handler) http.Handler) types.Router {
c := r.chi.With(fn)
return &RouterCtx{c}
}
func (r *RouterCtx) Use(fn types.MiddlewareHandler) {
r.chi.Use(middlewareHandler(fn))
}
func (r *RouterCtx) UseBypass(fn func(next http.Handler) http.Handler) {
r.chi.Use(fn)
}
func (r *RouterCtx) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.chi.ServeHTTP(w, req)
}
func errorHandler(err error, w http.ResponseWriter, r *http.Request) {
httpErr, ok := err.(*utils.HTTPError)
if !ok {
httpErr = utils.HttpInternalServerError().WithInternalErr(err)
}
utils.HttpJsonResponse(w, httpErr.Code, httpErr)
}
func routeHandler(fn types.RouterHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
logEntry := getLogEntry(r)
if err := fn(w, r); err != nil {
logEntry.SetError(err)
errorHandler(err, w, r)
}
logEntry.SetResponse(w, r)
logEntry.Write()
}
}
func middlewareHandler(fn types.MiddlewareHandler) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logEntry := getLogEntry(r)
ctx, err := fn(w, r)
if err != nil {
logEntry.SetError(err)
errorHandler(err, w, r)
logEntry.SetResponse(w, r)
logEntry.Write()
return
}
if ctx != nil {
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r)
})
}
}