diff --git a/client/client_connection.cpp b/client/client_connection.cpp index 2d1a4269..e046f115 100644 --- a/client/client_connection.cpp +++ b/client/client_connection.cpp @@ -29,13 +29,15 @@ #include #include #include - -// standard headers #include #include #include + +// standard headers #include #include +#include +#include #include #include @@ -137,7 +139,10 @@ void ClientConnection::connect(const ResultHandler& handler) } if (ec) + { LOG(ERROR, LOG_TAG) << "Failed to connect to host '" << server_.host << "', error: " << ec.message() << "\n"; + disconnect(); + } else LOG(NOTICE, LOG_TAG) << "Connected to " << server_.host << "\n"; @@ -392,19 +397,33 @@ ClientConnectionWs::~ClientConnectionWs() } +tcp_websocket& ClientConnectionWs::getWs() +{ + std::lock_guard lock(ws_mutex_); + if (tcp_ws_.has_value()) + return tcp_ws_.value(); + + tcp_ws_.emplace(strand_); + return tcp_ws_.value(); +} + + void ClientConnectionWs::disconnect() { LOG(DEBUG, LOG_TAG) << "Disconnecting\n"; - if (!tcp_ws_.is_open()) - { - LOG(DEBUG, LOG_TAG) << "Not connected\n"; - return; - } boost::system::error_code ec; - tcp_ws_.close(websocket::close_code::normal, ec); - if (ec) - LOG(ERROR, LOG_TAG) << "Error in socket close: " << ec.message() << "\n"; + + if (getWs().is_open()) + getWs().close(websocket::close_code::normal, ec); + // if (ec) + // LOG(ERROR, LOG_TAG) << "Error in socket close: " << ec.message() << "\n"; + if (getWs().next_layer().is_open()) + { + getWs().next_layer().shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec); + getWs().next_layer().close(ec); + } boost::asio::post(strand_, [this]() { pendingRequests_.clear(); }); + tcp_ws_ = std::nullopt; LOG(DEBUG, LOG_TAG) << "Disconnected\n"; } @@ -413,23 +432,23 @@ std::string ClientConnectionWs::getMacAddress() { std::string mac = #ifndef WINDOWS - ::getMacAddress(tcp_ws_.next_layer().native_handle()); + ::getMacAddress(getWs().next_layer().native_handle()); #else - ::getMacAddress(tcp_ws_.next_layer().local_endpoint().address().to_string()); + ::getMacAddress(getWs().next_layer().local_endpoint().address().to_string()); #endif if (mac.empty()) mac = "00:00:00:00:00:00"; - LOG(INFO, LOG_TAG) << "My MAC: \"" << mac << "\", socket: " << tcp_ws_.next_layer().native_handle() << "\n"; + LOG(INFO, LOG_TAG) << "My MAC: \"" << mac << "\", socket: " << getWs().next_layer().native_handle() << "\n"; return mac; } void ClientConnectionWs::getNextMessage(const MessageHandler& handler) { - tcp_ws_.async_read(buffer_, [this, handler](beast::error_code ec, std::size_t bytes_transferred) mutable + getWs().async_read(buffer_, [this, handler](beast::error_code ec, std::size_t bytes_transferred) mutable { tv now; - LOG(DEBUG, LOG_TAG) << "on_read_ws, ec: " << ec << ", bytes_transferred: " << bytes_transferred << "\n"; + LOG(TRACE, LOG_TAG) << "on_read_ws, ec: " << ec << ", bytes_transferred: " << bytes_transferred << "\n"; // This indicates that the session was closed if (ec == websocket::error::closed) @@ -458,7 +477,7 @@ void ClientConnectionWs::getNextMessage(const MessageHandler& if (!response) LOG(WARNING, LOG_TAG) << "Failed to deserialize message of type: " << base_message_.type << "\n"; else - LOG(DEBUG, LOG_TAG) << "getNextMessage: " << response->type << ", size: " << response->size << ", id: " << response->id + LOG(TRACE, LOG_TAG) << "getNextMessage: " << response->type << ", size: " << response->size << ", id: " << response->id << ", refers: " << response->refersTo << "\n"; messageReceived(std::move(response), handler); @@ -469,24 +488,26 @@ void ClientConnectionWs::getNextMessage(const MessageHandler& boost::system::error_code ClientConnectionWs::doConnect(boost::asio::ip::basic_endpoint endpoint) { boost::system::error_code ec; - tcp_ws_.binary(true); - tcp_ws_.next_layer().connect(endpoint, ec); + getWs().binary(true); + getWs().next_layer().connect(endpoint, ec); + if (ec.failed()) + return ec; // Set suggested timeout settings for the websocket - tcp_ws_.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); + getWs().set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); // Set a decorator to change the User-Agent of the handshake - tcp_ws_.set_option(websocket::stream_base::decorator([](websocket::request_type& req) { req.set(http::field::user_agent, WS_CLIENT_NAME); })); + getWs().set_option(websocket::stream_base::decorator([](websocket::request_type& req) { req.set(http::field::user_agent, WS_CLIENT_NAME); })); // Perform the websocket handshake - tcp_ws_.handshake(server_.host + ":" + std::to_string(server_.port), "/stream", ec); + getWs().handshake(server_.host + ":" + std::to_string(server_.port), "/stream", ec); return ec; } void ClientConnectionWs::write(boost::asio::streambuf& buffer, WriteHandler&& write_handler) { - tcp_ws_.async_write(boost::asio::buffer(buffer.data()), write_handler); + getWs().async_write(boost::asio::buffer(buffer.data()), write_handler); } @@ -495,12 +516,23 @@ void ClientConnectionWs::write(boost::asio::streambuf& buffer, WriteHandler&& wr ClientConnectionWss::ClientConnectionWss(boost::asio::io_context& io_context, boost::asio::ssl::context& ssl_context, ClientSettings::Server server) - : ClientConnection(io_context, std::move(server)), ssl_ws_(strand_, ssl_context) + : ClientConnection(io_context, std::move(server)), ssl_context_(ssl_context) { - if (server.certificate.has_value()) + getWs(); +} + + +ssl_websocket& ClientConnectionWss::getWs() +{ + std::lock_guard lock(ws_mutex_); + if (ssl_ws_.has_value()) + return ssl_ws_.value(); + + ssl_ws_.emplace(strand_, ssl_context_); + if (server_.certificate.has_value()) { - ssl_ws_.next_layer().set_verify_mode(boost::asio::ssl::verify_peer); - ssl_ws_.next_layer().set_verify_callback([](bool preverified, boost::asio::ssl::verify_context& ctx) + ssl_ws_->next_layer().set_verify_mode(boost::asio::ssl::verify_peer); + ssl_ws_->next_layer().set_verify_callback([](bool preverified, boost::asio::ssl::verify_context& ctx) { // The verify callback can be used to check whether the certificate that is // being presented is valid for the peer. For example, RFC 2818 describes @@ -518,6 +550,7 @@ ClientConnectionWss::ClientConnectionWss(boost::asio::io_context& io_context, bo return preverified; }); } + return ssl_ws_.value(); } @@ -530,16 +563,19 @@ ClientConnectionWss::~ClientConnectionWss() void ClientConnectionWss::disconnect() { LOG(DEBUG, LOG_TAG) << "Disconnecting\n"; - if (!ssl_ws_.is_open()) - { - LOG(DEBUG, LOG_TAG) << "Not connected\n"; - return; - } boost::system::error_code ec; - ssl_ws_.close(websocket::close_code::normal, ec); - if (ec) - LOG(ERROR, LOG_TAG) << "Error in socket close: " << ec.message() << "\n"; + + if (getWs().is_open()) + getWs().close(websocket::close_code::normal, ec); + // if (ec) + // LOG(ERROR, LOG_TAG) << "Error in socket close: " << ec.message() << "\n"; + if (getWs().next_layer().lowest_layer().is_open()) + { + getWs().next_layer().lowest_layer().shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec); + getWs().next_layer().lowest_layer().close(ec); + } boost::asio::post(strand_, [this]() { pendingRequests_.clear(); }); + ssl_ws_ = std::nullopt; LOG(DEBUG, LOG_TAG) << "Disconnected\n"; } @@ -548,23 +584,23 @@ std::string ClientConnectionWss::getMacAddress() { std::string mac = #ifndef WINDOWS - ::getMacAddress(ssl_ws_.next_layer().lowest_layer().native_handle()); + ::getMacAddress(getWs().next_layer().lowest_layer().native_handle()); #else ::getMacAddress(ssl_ws_.next_layer().lowest_layer().local_endpoint().address().to_string()); #endif if (mac.empty()) mac = "00:00:00:00:00:00"; - LOG(INFO, LOG_TAG) << "My MAC: \"" << mac << "\", socket: " << ssl_ws_.next_layer().lowest_layer().native_handle() << "\n"; + LOG(INFO, LOG_TAG) << "My MAC: \"" << mac << "\", socket: " << getWs().next_layer().lowest_layer().native_handle() << "\n"; return mac; } void ClientConnectionWss::getNextMessage(const MessageHandler& handler) { - ssl_ws_.async_read(buffer_, [this, handler](beast::error_code ec, std::size_t bytes_transferred) mutable + getWs().async_read(buffer_, [this, handler](beast::error_code ec, std::size_t bytes_transferred) mutable { tv now; - LOG(DEBUG, LOG_TAG) << "on_read_ws, ec: " << ec << ", bytes_transferred: " << bytes_transferred << "\n"; + LOG(TRACE, LOG_TAG) << "on_read_ws, ec: " << ec << ", bytes_transferred: " << bytes_transferred << "\n"; // This indicates that the session was closed if (ec == websocket::error::closed) @@ -593,7 +629,7 @@ void ClientConnectionWss::getNextMessage(const MessageHandler& if (!response) LOG(WARNING, LOG_TAG) << "Failed to deserialize message of type: " << base_message_.type << "\n"; else - LOG(DEBUG, LOG_TAG) << "getNextMessage: " << response->type << ", size: " << response->size << ", id: " << response->id + LOG(TRACE, LOG_TAG) << "getNextMessage: " << response->type << ", size: " << response->size << ", id: " << response->id << ", refers: " << response->refersTo << "\n"; messageReceived(std::move(response), handler); @@ -604,32 +640,40 @@ void ClientConnectionWss::getNextMessage(const MessageHandler& boost::system::error_code ClientConnectionWss::doConnect(boost::asio::ip::basic_endpoint endpoint) { boost::system::error_code ec; - ssl_ws_.binary(true); - beast::get_lowest_layer(ssl_ws_).connect(endpoint, ec); + getWs().binary(true); + beast::get_lowest_layer(*ssl_ws_).connect(endpoint, ec); + if (ec.failed()) + return ec; // Set a timeout on the operation // beast::get_lowest_layer(ssl_ws_).expires_after(std::chrono::seconds(30)); // Set suggested timeout settings for the websocket - ssl_ws_.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); + getWs().set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); // Set SNI Hostname (many hosts need this to handshake successfully) - if (!SSL_set_tlsext_host_name(ssl_ws_.next_layer().native_handle(), server_.host.c_str())) - throw beast::system_error(beast::error_code(static_cast(::ERR_get_error()), boost::asio::error::get_ssl_category()), "Failed to set SNI Hostname"); + if (!SSL_set_tlsext_host_name(getWs().next_layer().native_handle(), server_.host.c_str())) + { + LOG(ERROR, LOG_TAG) << "Failed to set SNI Hostname\n"; + return boost::system::error_code(static_cast(::ERR_get_error()), boost::asio::error::get_ssl_category()); + } // Perform the SSL handshake - ssl_ws_.next_layer().handshake(boost::asio::ssl::stream_base::client); + getWs().next_layer().handshake(boost::asio::ssl::stream_base::client, ec); + if (ec.failed()) + return ec; // Set a decorator to change the User-Agent of the handshake - ssl_ws_.set_option(websocket::stream_base::decorator([](websocket::request_type& req) { req.set(http::field::user_agent, WS_CLIENT_NAME); })); + getWs().set_option(websocket::stream_base::decorator([](websocket::request_type& req) { req.set(http::field::user_agent, WS_CLIENT_NAME); })); // Perform the websocket handshake - ssl_ws_.handshake(server_.host + ":" + std::to_string(server_.port), "/stream", ec); + getWs().handshake(server_.host + ":" + std::to_string(server_.port), "/stream", ec); + return ec; } void ClientConnectionWss::write(boost::asio::streambuf& buffer, WriteHandler&& write_handler) { - ssl_ws_.async_write(boost::asio::buffer(buffer.data()), write_handler); + getWs().async_write(boost::asio::buffer(buffer.data()), write_handler); } diff --git a/client/client_connection.hpp b/client/client_connection.hpp index ef38777f..94393937 100644 --- a/client/client_connection.hpp +++ b/client/client_connection.hpp @@ -37,6 +37,8 @@ // standard headers #include #include +#include +#include #include @@ -235,10 +237,15 @@ private: boost::system::error_code doConnect(boost::asio::ip::basic_endpoint endpoint) override; void write(boost::asio::streambuf& buffer, WriteHandler&& write_handler) override; + /// @return the websocket + tcp_websocket& getWs(); + /// TCP web socket - tcp_websocket tcp_ws_; + std::optional tcp_ws_; /// Receive buffer boost::beast::flat_buffer buffer_; + /// protect ssl_ws_ + std::mutex ws_mutex_; }; @@ -260,8 +267,15 @@ private: boost::system::error_code doConnect(boost::asio::ip::basic_endpoint endpoint) override; void write(boost::asio::streambuf& buffer, WriteHandler&& write_handler) override; + /// @return the websocket + ssl_websocket& getWs(); + + /// SSL context + boost::asio::ssl::context& ssl_context_; /// SSL web socket - ssl_websocket ssl_ws_; + std::optional ssl_ws_; /// Receive buffer boost::beast::flat_buffer buffer_; + /// protect ssl_ws_ + std::mutex ws_mutex_; }; diff --git a/common/snap_exception.hpp b/common/snap_exception.hpp index b40fc97b..d0bdae03 100644 --- a/common/snap_exception.hpp +++ b/common/snap_exception.hpp @@ -1,6 +1,6 @@ /*** This file is part of snapcast - Copyright (C) 2014-2024 Johannes Pohl + Copyright (C) 2014-2025 Johannes Pohl This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -29,15 +29,15 @@ class SnapException : public std::exception int error_code_; public: - SnapException(const char* text, int error_code = 0) : text_(text), error_code_(error_code) + explicit SnapException(const char* text, int error_code = 0) : text_(text), error_code_(error_code) { } - SnapException(const std::string& text, int error_code = 0) : SnapException(text.c_str(), error_code) + explicit SnapException(const std::string& text, int error_code = 0) : SnapException(text.c_str(), error_code) { } - ~SnapException() throw() override = default; + ~SnapException() override = default; int code() const noexcept {