diff --git a/internal/websocket/handler/handler.go b/internal/websocket/handler/handler.go index e6b01a49..6f8c438d 100644 --- a/internal/websocket/handler/handler.go +++ b/internal/websocket/handler/handler.go @@ -3,7 +3,6 @@ package handler import ( "encoding/json" - "github.com/pkg/errors" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -38,10 +37,11 @@ type MessageHandlerCtx struct { capture types.CaptureManager } -func (h *MessageHandlerCtx) Message(session types.Session, raw []byte) error { +func (h *MessageHandlerCtx) Message(session types.Session, raw []byte) bool { header := message.Message{} if err := json.Unmarshal(raw, &header); err != nil { - return err + h.logger.Error().Err(err).Msg("message parsing has failed") + return false } var err error @@ -87,8 +87,12 @@ func (h *MessageHandlerCtx) Message(session types.Session, raw []byte) error { return h.keyboardModifiers(session, payload) }) default: - return errors.Errorf("unknown message event %s", header.Event) + return false } - return errors.Wrapf(err, "%s failed", header.Event) + if err != nil { + h.logger.Error().Err(err).Msg("message handler has failed") + } + + return true } diff --git a/internal/websocket/manager.go b/internal/websocket/manager.go index 92aef70b..8aa3faa8 100644 --- a/internal/websocket/manager.go +++ b/internal/websocket/manager.go @@ -15,6 +15,8 @@ import ( "demodesk/neko/internal/types" ) +type HandlerFunction func(types.Session, []byte) bool + func New( sessions types.SessionManager, desktop types.DesktopManager, @@ -32,7 +34,8 @@ func New( return true }, }, - handler: handler.New(sessions, desktop, capture, webrtc), + handler: handler.New(sessions, desktop, capture, webrtc), + handlers: []HandlerFunction{}, } } @@ -45,6 +48,7 @@ type WebSocketManagerCtx struct { sessions types.SessionManager desktop types.DesktopManager handler *handler.MessageHandlerCtx + handlers []HandlerFunction shutdown chan bool } @@ -145,6 +149,10 @@ func (ws *WebSocketManagerCtx) Shutdown() error { return nil } +func (ws *WebSocketManagerCtx) AddHandler(handler HandlerFunction) { + ws.handlers = append(ws.handlers, handler) +} + func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) error { ws.logger.Debug().Msg("attempting to upgrade connection") @@ -260,8 +268,17 @@ func (ws *WebSocketManagerCtx) handle(connection *websocket.Conn, session types. Str("raw", string(raw)). Msg("received message from client") - if err := ws.handler.Message(session, raw); err != nil { - ws.logger.Error().Err(err).Msg("message handler has failed") + handled := ws.handler.Message(session, raw) + for _, handler := range ws.handlers { + if handled { + break + } + + handled = handler(session, raw) + } + + if !handled { + ws.logger.Warn().Msg("unhandled message") } case <-cancel: return