diff --git a/internal/websocket/handler/handler.go b/internal/websocket/handler/handler.go index f2ae9e9a..6e34b8df 100644 --- a/internal/websocket/handler/handler.go +++ b/internal/websocket/handler/handler.go @@ -42,7 +42,7 @@ type MessageHandlerCtx struct { locked bool } -func (h *MessageHandlerCtx) Connected(id string, socket types.WebSocket) (bool, string) { +func (h *MessageHandlerCtx) Connected(session types.Session, socket types.WebSocket) (bool, string) { address := socket.Address() if address != "" { ok, banned := h.banned[address] @@ -54,12 +54,9 @@ func (h *MessageHandlerCtx) Connected(id string, socket types.WebSocket) (bool, h.logger.Debug().Msg("no remote address") } - if h.locked { - session, ok := h.sessions.Get(id) - if !ok || !session.Admin() { - h.logger.Debug().Msg("server locked") - return false, "locked" - } + if h.locked && !session.Admin(){ + h.logger.Debug().Msg("server locked") + return false, "locked" } return true, "" @@ -74,17 +71,12 @@ func (h *MessageHandlerCtx) Disconnected(id string) error { return h.sessions.Destroy(id) } -func (h *MessageHandlerCtx) Message(id string, raw []byte) error { +func (h *MessageHandlerCtx) Message(session types.Session, raw []byte) error { header := message.Message{} if err := json.Unmarshal(raw, &header); err != nil { return err } - session, ok := h.sessions.Get(id) - if !ok { - return errors.Errorf("unknown session id %s", id) - } - switch header.Event { // Signal Events case event.SIGNAL_ANSWER: diff --git a/internal/websocket/manager.go b/internal/websocket/manager.go index 0a727644..0000af69 100644 --- a/internal/websocket/manager.go +++ b/internal/websocket/manager.go @@ -149,13 +149,13 @@ func (ws *WebSocketManagerCtx) Upgrade(w http.ResponseWriter, r *http.Request) e // } socket := &WebSocketCtx{ - id: session.ID(), + session: session, ws: ws, address: ip, connection: connection, } - ok, reason := ws.handler.Connected(session.ID(), socket) + ok, reason := ws.handler.Connected(session, socket) if !ok { // TODO: Refactor if err = connection.WriteJSON(message.Disconnect{ @@ -226,7 +226,7 @@ func (ws *WebSocketManagerCtx) handle(connection *websocket.Conn, session types. Str("raw", string(raw)). Msg("received message from client") - if err := ws.handler.Message(session.ID(), raw); err != nil { + if err := ws.handler.Message(session, raw); err != nil { ws.logger.Error().Err(err).Msg("message handler has failed") } case <-cancel: diff --git a/internal/websocket/websocket.go b/internal/websocket/websocket.go index 6dd01327..11645794 100644 --- a/internal/websocket/websocket.go +++ b/internal/websocket/websocket.go @@ -6,10 +6,12 @@ import ( "sync" "github.com/gorilla/websocket" + + "demodesk/neko/internal/types" ) type WebSocketCtx struct { - id string + session types.Session address string ws *WebSocketManagerCtx connection *websocket.Conn @@ -40,7 +42,7 @@ func (socket *WebSocketCtx) Send(v interface{}) error { } socket.ws.logger.Debug(). - Str("session", socket.id). + Str("session", socket.session.ID()). Str("address", socket.connection.RemoteAddr().String()). Str("raw", string(raw)). Msg("sending message to client")