Fix wss reconnect

This commit is contained in:
badaix 2025-01-27 10:34:13 +01:00
parent 054706e608
commit a407e68df6
3 changed files with 112 additions and 54 deletions

View file

@ -29,13 +29,15 @@
#include <boost/asio/read.hpp> #include <boost/asio/read.hpp>
#include <boost/asio/streambuf.hpp> #include <boost/asio/streambuf.hpp>
#include <boost/asio/write.hpp> #include <boost/asio/write.hpp>
// standard headers
#include <boost/beast/core/flat_buffer.hpp> #include <boost/beast/core/flat_buffer.hpp>
#include <boost/beast/core/stream_traits.hpp> #include <boost/beast/core/stream_traits.hpp>
#include <boost/system/detail/error_code.hpp> #include <boost/system/detail/error_code.hpp>
// standard headers
#include <cstdint> #include <cstdint>
#include <iostream> #include <iostream>
#include <mutex>
#include <optional>
#include <string> #include <string>
#include <utility> #include <utility>
@ -137,7 +139,10 @@ void ClientConnection::connect(const ResultHandler& handler)
} }
if (ec) if (ec)
{
LOG(ERROR, LOG_TAG) << "Failed to connect to host '" << server_.host << "', error: " << ec.message() << "\n"; LOG(ERROR, LOG_TAG) << "Failed to connect to host '" << server_.host << "', error: " << ec.message() << "\n";
disconnect();
}
else else
LOG(NOTICE, LOG_TAG) << "Connected to " << server_.host << "\n"; LOG(NOTICE, LOG_TAG) << "Connected to " << server_.host << "\n";
@ -392,19 +397,33 @@ ClientConnectionWs::~ClientConnectionWs()
} }
tcp_websocket& ClientConnectionWs::getWs()
{
std::lock_guard lock(ws_mutex_);
if (tcp_ws_.has_value())
return tcp_ws_.value();
tcp_ws_.emplace(strand_);
return tcp_ws_.value();
}
void ClientConnectionWs::disconnect() void ClientConnectionWs::disconnect()
{ {
LOG(DEBUG, LOG_TAG) << "Disconnecting\n"; LOG(DEBUG, LOG_TAG) << "Disconnecting\n";
if (!tcp_ws_.is_open())
{
LOG(DEBUG, LOG_TAG) << "Not connected\n";
return;
}
boost::system::error_code ec; boost::system::error_code ec;
tcp_ws_.close(websocket::close_code::normal, ec);
if (ec) if (getWs().is_open())
LOG(ERROR, LOG_TAG) << "Error in socket close: " << ec.message() << "\n"; getWs().close(websocket::close_code::normal, ec);
// if (ec)
// LOG(ERROR, LOG_TAG) << "Error in socket close: " << ec.message() << "\n";
if (getWs().next_layer().is_open())
{
getWs().next_layer().shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec);
getWs().next_layer().close(ec);
}
boost::asio::post(strand_, [this]() { pendingRequests_.clear(); }); boost::asio::post(strand_, [this]() { pendingRequests_.clear(); });
tcp_ws_ = std::nullopt;
LOG(DEBUG, LOG_TAG) << "Disconnected\n"; LOG(DEBUG, LOG_TAG) << "Disconnected\n";
} }
@ -413,23 +432,23 @@ std::string ClientConnectionWs::getMacAddress()
{ {
std::string mac = std::string mac =
#ifndef WINDOWS #ifndef WINDOWS
::getMacAddress(tcp_ws_.next_layer().native_handle()); ::getMacAddress(getWs().next_layer().native_handle());
#else #else
::getMacAddress(tcp_ws_.next_layer().local_endpoint().address().to_string()); ::getMacAddress(getWs().next_layer().local_endpoint().address().to_string());
#endif #endif
if (mac.empty()) if (mac.empty())
mac = "00:00:00:00:00:00"; mac = "00:00:00:00:00:00";
LOG(INFO, LOG_TAG) << "My MAC: \"" << mac << "\", socket: " << tcp_ws_.next_layer().native_handle() << "\n"; LOG(INFO, LOG_TAG) << "My MAC: \"" << mac << "\", socket: " << getWs().next_layer().native_handle() << "\n";
return mac; return mac;
} }
void ClientConnectionWs::getNextMessage(const MessageHandler<msg::BaseMessage>& handler) void ClientConnectionWs::getNextMessage(const MessageHandler<msg::BaseMessage>& handler)
{ {
tcp_ws_.async_read(buffer_, [this, handler](beast::error_code ec, std::size_t bytes_transferred) mutable getWs().async_read(buffer_, [this, handler](beast::error_code ec, std::size_t bytes_transferred) mutable
{ {
tv now; tv now;
LOG(DEBUG, LOG_TAG) << "on_read_ws, ec: " << ec << ", bytes_transferred: " << bytes_transferred << "\n"; LOG(TRACE, LOG_TAG) << "on_read_ws, ec: " << ec << ", bytes_transferred: " << bytes_transferred << "\n";
// This indicates that the session was closed // This indicates that the session was closed
if (ec == websocket::error::closed) if (ec == websocket::error::closed)
@ -458,7 +477,7 @@ void ClientConnectionWs::getNextMessage(const MessageHandler<msg::BaseMessage>&
if (!response) if (!response)
LOG(WARNING, LOG_TAG) << "Failed to deserialize message of type: " << base_message_.type << "\n"; LOG(WARNING, LOG_TAG) << "Failed to deserialize message of type: " << base_message_.type << "\n";
else else
LOG(DEBUG, LOG_TAG) << "getNextMessage: " << response->type << ", size: " << response->size << ", id: " << response->id LOG(TRACE, LOG_TAG) << "getNextMessage: " << response->type << ", size: " << response->size << ", id: " << response->id
<< ", refers: " << response->refersTo << "\n"; << ", refers: " << response->refersTo << "\n";
messageReceived(std::move(response), handler); messageReceived(std::move(response), handler);
@ -469,24 +488,26 @@ void ClientConnectionWs::getNextMessage(const MessageHandler<msg::BaseMessage>&
boost::system::error_code ClientConnectionWs::doConnect(boost::asio::ip::basic_endpoint<boost::asio::ip::tcp> endpoint) boost::system::error_code ClientConnectionWs::doConnect(boost::asio::ip::basic_endpoint<boost::asio::ip::tcp> endpoint)
{ {
boost::system::error_code ec; boost::system::error_code ec;
tcp_ws_.binary(true); getWs().binary(true);
tcp_ws_.next_layer().connect(endpoint, ec); getWs().next_layer().connect(endpoint, ec);
if (ec.failed())
return ec;
// Set suggested timeout settings for the websocket // Set suggested timeout settings for the websocket
tcp_ws_.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); getWs().set_option(websocket::stream_base::timeout::suggested(beast::role_type::client));
// Set a decorator to change the User-Agent of the handshake // Set a decorator to change the User-Agent of the handshake
tcp_ws_.set_option(websocket::stream_base::decorator([](websocket::request_type& req) { req.set(http::field::user_agent, WS_CLIENT_NAME); })); getWs().set_option(websocket::stream_base::decorator([](websocket::request_type& req) { req.set(http::field::user_agent, WS_CLIENT_NAME); }));
// Perform the websocket handshake // Perform the websocket handshake
tcp_ws_.handshake(server_.host + ":" + std::to_string(server_.port), "/stream", ec); getWs().handshake(server_.host + ":" + std::to_string(server_.port), "/stream", ec);
return ec; return ec;
} }
void ClientConnectionWs::write(boost::asio::streambuf& buffer, WriteHandler&& write_handler) void ClientConnectionWs::write(boost::asio::streambuf& buffer, WriteHandler&& write_handler)
{ {
tcp_ws_.async_write(boost::asio::buffer(buffer.data()), write_handler); getWs().async_write(boost::asio::buffer(buffer.data()), write_handler);
} }
@ -495,12 +516,23 @@ void ClientConnectionWs::write(boost::asio::streambuf& buffer, WriteHandler&& wr
ClientConnectionWss::ClientConnectionWss(boost::asio::io_context& io_context, boost::asio::ssl::context& ssl_context, ClientSettings::Server server) ClientConnectionWss::ClientConnectionWss(boost::asio::io_context& io_context, boost::asio::ssl::context& ssl_context, ClientSettings::Server server)
: ClientConnection(io_context, std::move(server)), ssl_ws_(strand_, ssl_context) : ClientConnection(io_context, std::move(server)), ssl_context_(ssl_context)
{ {
if (server.certificate.has_value()) getWs();
}
ssl_websocket& ClientConnectionWss::getWs()
{
std::lock_guard lock(ws_mutex_);
if (ssl_ws_.has_value())
return ssl_ws_.value();
ssl_ws_.emplace(strand_, ssl_context_);
if (server_.certificate.has_value())
{ {
ssl_ws_.next_layer().set_verify_mode(boost::asio::ssl::verify_peer); ssl_ws_->next_layer().set_verify_mode(boost::asio::ssl::verify_peer);
ssl_ws_.next_layer().set_verify_callback([](bool preverified, boost::asio::ssl::verify_context& ctx) ssl_ws_->next_layer().set_verify_callback([](bool preverified, boost::asio::ssl::verify_context& ctx)
{ {
// The verify callback can be used to check whether the certificate that is // The verify callback can be used to check whether the certificate that is
// being presented is valid for the peer. For example, RFC 2818 describes // being presented is valid for the peer. For example, RFC 2818 describes
@ -518,6 +550,7 @@ ClientConnectionWss::ClientConnectionWss(boost::asio::io_context& io_context, bo
return preverified; return preverified;
}); });
} }
return ssl_ws_.value();
} }
@ -530,16 +563,19 @@ ClientConnectionWss::~ClientConnectionWss()
void ClientConnectionWss::disconnect() void ClientConnectionWss::disconnect()
{ {
LOG(DEBUG, LOG_TAG) << "Disconnecting\n"; LOG(DEBUG, LOG_TAG) << "Disconnecting\n";
if (!ssl_ws_.is_open())
{
LOG(DEBUG, LOG_TAG) << "Not connected\n";
return;
}
boost::system::error_code ec; boost::system::error_code ec;
ssl_ws_.close(websocket::close_code::normal, ec);
if (ec) if (getWs().is_open())
LOG(ERROR, LOG_TAG) << "Error in socket close: " << ec.message() << "\n"; getWs().close(websocket::close_code::normal, ec);
// if (ec)
// LOG(ERROR, LOG_TAG) << "Error in socket close: " << ec.message() << "\n";
if (getWs().next_layer().lowest_layer().is_open())
{
getWs().next_layer().lowest_layer().shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec);
getWs().next_layer().lowest_layer().close(ec);
}
boost::asio::post(strand_, [this]() { pendingRequests_.clear(); }); boost::asio::post(strand_, [this]() { pendingRequests_.clear(); });
ssl_ws_ = std::nullopt;
LOG(DEBUG, LOG_TAG) << "Disconnected\n"; LOG(DEBUG, LOG_TAG) << "Disconnected\n";
} }
@ -548,23 +584,23 @@ std::string ClientConnectionWss::getMacAddress()
{ {
std::string mac = std::string mac =
#ifndef WINDOWS #ifndef WINDOWS
::getMacAddress(ssl_ws_.next_layer().lowest_layer().native_handle()); ::getMacAddress(getWs().next_layer().lowest_layer().native_handle());
#else #else
::getMacAddress(ssl_ws_.next_layer().lowest_layer().local_endpoint().address().to_string()); ::getMacAddress(ssl_ws_.next_layer().lowest_layer().local_endpoint().address().to_string());
#endif #endif
if (mac.empty()) if (mac.empty())
mac = "00:00:00:00:00:00"; mac = "00:00:00:00:00:00";
LOG(INFO, LOG_TAG) << "My MAC: \"" << mac << "\", socket: " << ssl_ws_.next_layer().lowest_layer().native_handle() << "\n"; LOG(INFO, LOG_TAG) << "My MAC: \"" << mac << "\", socket: " << getWs().next_layer().lowest_layer().native_handle() << "\n";
return mac; return mac;
} }
void ClientConnectionWss::getNextMessage(const MessageHandler<msg::BaseMessage>& handler) void ClientConnectionWss::getNextMessage(const MessageHandler<msg::BaseMessage>& handler)
{ {
ssl_ws_.async_read(buffer_, [this, handler](beast::error_code ec, std::size_t bytes_transferred) mutable getWs().async_read(buffer_, [this, handler](beast::error_code ec, std::size_t bytes_transferred) mutable
{ {
tv now; tv now;
LOG(DEBUG, LOG_TAG) << "on_read_ws, ec: " << ec << ", bytes_transferred: " << bytes_transferred << "\n"; LOG(TRACE, LOG_TAG) << "on_read_ws, ec: " << ec << ", bytes_transferred: " << bytes_transferred << "\n";
// This indicates that the session was closed // This indicates that the session was closed
if (ec == websocket::error::closed) if (ec == websocket::error::closed)
@ -593,7 +629,7 @@ void ClientConnectionWss::getNextMessage(const MessageHandler<msg::BaseMessage>&
if (!response) if (!response)
LOG(WARNING, LOG_TAG) << "Failed to deserialize message of type: " << base_message_.type << "\n"; LOG(WARNING, LOG_TAG) << "Failed to deserialize message of type: " << base_message_.type << "\n";
else else
LOG(DEBUG, LOG_TAG) << "getNextMessage: " << response->type << ", size: " << response->size << ", id: " << response->id LOG(TRACE, LOG_TAG) << "getNextMessage: " << response->type << ", size: " << response->size << ", id: " << response->id
<< ", refers: " << response->refersTo << "\n"; << ", refers: " << response->refersTo << "\n";
messageReceived(std::move(response), handler); messageReceived(std::move(response), handler);
@ -604,32 +640,40 @@ void ClientConnectionWss::getNextMessage(const MessageHandler<msg::BaseMessage>&
boost::system::error_code ClientConnectionWss::doConnect(boost::asio::ip::basic_endpoint<boost::asio::ip::tcp> endpoint) boost::system::error_code ClientConnectionWss::doConnect(boost::asio::ip::basic_endpoint<boost::asio::ip::tcp> endpoint)
{ {
boost::system::error_code ec; boost::system::error_code ec;
ssl_ws_.binary(true); getWs().binary(true);
beast::get_lowest_layer(ssl_ws_).connect(endpoint, ec); beast::get_lowest_layer(*ssl_ws_).connect(endpoint, ec);
if (ec.failed())
return ec;
// Set a timeout on the operation // Set a timeout on the operation
// beast::get_lowest_layer(ssl_ws_).expires_after(std::chrono::seconds(30)); // beast::get_lowest_layer(ssl_ws_).expires_after(std::chrono::seconds(30));
// Set suggested timeout settings for the websocket // Set suggested timeout settings for the websocket
ssl_ws_.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); getWs().set_option(websocket::stream_base::timeout::suggested(beast::role_type::client));
// Set SNI Hostname (many hosts need this to handshake successfully) // Set SNI Hostname (many hosts need this to handshake successfully)
if (!SSL_set_tlsext_host_name(ssl_ws_.next_layer().native_handle(), server_.host.c_str())) if (!SSL_set_tlsext_host_name(getWs().next_layer().native_handle(), server_.host.c_str()))
throw beast::system_error(beast::error_code(static_cast<int>(::ERR_get_error()), boost::asio::error::get_ssl_category()), "Failed to set SNI Hostname"); {
LOG(ERROR, LOG_TAG) << "Failed to set SNI Hostname\n";
return boost::system::error_code(static_cast<int>(::ERR_get_error()), boost::asio::error::get_ssl_category());
}
// Perform the SSL handshake // Perform the SSL handshake
ssl_ws_.next_layer().handshake(boost::asio::ssl::stream_base::client); getWs().next_layer().handshake(boost::asio::ssl::stream_base::client, ec);
if (ec.failed())
return ec;
// Set a decorator to change the User-Agent of the handshake // Set a decorator to change the User-Agent of the handshake
ssl_ws_.set_option(websocket::stream_base::decorator([](websocket::request_type& req) { req.set(http::field::user_agent, WS_CLIENT_NAME); })); getWs().set_option(websocket::stream_base::decorator([](websocket::request_type& req) { req.set(http::field::user_agent, WS_CLIENT_NAME); }));
// Perform the websocket handshake // Perform the websocket handshake
ssl_ws_.handshake(server_.host + ":" + std::to_string(server_.port), "/stream", ec); getWs().handshake(server_.host + ":" + std::to_string(server_.port), "/stream", ec);
return ec; return ec;
} }
void ClientConnectionWss::write(boost::asio::streambuf& buffer, WriteHandler&& write_handler) void ClientConnectionWss::write(boost::asio::streambuf& buffer, WriteHandler&& write_handler)
{ {
ssl_ws_.async_write(boost::asio::buffer(buffer.data()), write_handler); getWs().async_write(boost::asio::buffer(buffer.data()), write_handler);
} }

View file

@ -37,6 +37,8 @@
// standard headers // standard headers
#include <deque> #include <deque>
#include <memory> #include <memory>
#include <mutex>
#include <optional>
#include <string> #include <string>
@ -235,10 +237,15 @@ private:
boost::system::error_code doConnect(boost::asio::ip::basic_endpoint<boost::asio::ip::tcp> endpoint) override; boost::system::error_code doConnect(boost::asio::ip::basic_endpoint<boost::asio::ip::tcp> endpoint) override;
void write(boost::asio::streambuf& buffer, WriteHandler&& write_handler) override; void write(boost::asio::streambuf& buffer, WriteHandler&& write_handler) override;
/// @return the websocket
tcp_websocket& getWs();
/// TCP web socket /// TCP web socket
tcp_websocket tcp_ws_; std::optional<tcp_websocket> tcp_ws_;
/// Receive buffer /// Receive buffer
boost::beast::flat_buffer buffer_; boost::beast::flat_buffer buffer_;
/// protect ssl_ws_
std::mutex ws_mutex_;
}; };
@ -260,8 +267,15 @@ private:
boost::system::error_code doConnect(boost::asio::ip::basic_endpoint<boost::asio::ip::tcp> endpoint) override; boost::system::error_code doConnect(boost::asio::ip::basic_endpoint<boost::asio::ip::tcp> endpoint) override;
void write(boost::asio::streambuf& buffer, WriteHandler&& write_handler) override; void write(boost::asio::streambuf& buffer, WriteHandler&& write_handler) override;
/// @return the websocket
ssl_websocket& getWs();
/// SSL context
boost::asio::ssl::context& ssl_context_;
/// SSL web socket /// SSL web socket
ssl_websocket ssl_ws_; std::optional<ssl_websocket> ssl_ws_;
/// Receive buffer /// Receive buffer
boost::beast::flat_buffer buffer_; boost::beast::flat_buffer buffer_;
/// protect ssl_ws_
std::mutex ws_mutex_;
}; };

View file

@ -1,6 +1,6 @@
/*** /***
This file is part of snapcast This file is part of snapcast
Copyright (C) 2014-2024 Johannes Pohl Copyright (C) 2014-2025 Johannes Pohl
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by it under the terms of the GNU General Public License as published by
@ -29,15 +29,15 @@ class SnapException : public std::exception
int error_code_; int error_code_;
public: public:
SnapException(const char* text, int error_code = 0) : text_(text), error_code_(error_code) explicit SnapException(const char* text, int error_code = 0) : text_(text), error_code_(error_code)
{ {
} }
SnapException(const std::string& text, int error_code = 0) : SnapException(text.c_str(), error_code) explicit SnapException(const std::string& text, int error_code = 0) : SnapException(text.c_str(), error_code)
{ {
} }
~SnapException() throw() override = default; ~SnapException() override = default;
int code() const noexcept int code() const noexcept
{ {