Add support for SSL Websockets

This commit is contained in:
badaix 2025-01-24 23:31:27 +01:00
parent 0a8b737f9f
commit 442b154fbf
5 changed files with 205 additions and 53 deletions

View file

@ -25,16 +25,18 @@
// 3rd party headers
#include <boost/asio/buffer.hpp>
#include <boost/asio/connect.hpp>
#include <boost/asio/read.hpp>
#include <boost/asio/streambuf.hpp>
#include <boost/asio/write.hpp>
// standard headers
#include <boost/beast/core/flat_buffer.hpp>
#include <boost/beast/core/stream_traits.hpp>
#include <boost/system/detail/error_code.hpp>
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <utility>
@ -43,6 +45,9 @@ namespace http = beast::http; // from <boost/beast/http.hpp>
static constexpr auto LOG_TAG = "Connection";
static constexpr const char* WS_CLIENT_NAME = "Snapcast";
PendingRequest::PendingRequest(const boost::asio::strand<boost::asio::any_io_executor>& strand, uint16_t reqId, const MessageHandler<msg::BaseMessage>& handler)
: id_(reqId), timer_(strand), strand_(strand), handler_(handler)
{
@ -231,6 +236,28 @@ void ClientConnection::sendRequest(const msg::message_ptr& message, const chrono
}
void ClientConnection::messageReceived(std::unique_ptr<msg::BaseMessage> message, const MessageHandler<msg::BaseMessage>& handler)
{
for (auto iter = pendingRequests_.begin(); iter != pendingRequests_.end(); ++iter)
{
auto request = *iter;
if (auto req = request.lock())
{
if (req->id() == base_message_.refersTo)
{
req->setValue(std::move(message));
pendingRequests_.erase(iter);
getNextMessage(handler);
return;
}
}
}
if (handler)
handler({}, std::move(message));
}
///////////////////////////////////// TCP /////////////////////////////////////
ClientConnectionTcp::ClientConnectionTcp(boost::asio::io_context& io_context, ClientSettings::Server server)
@ -329,23 +356,8 @@ void ClientConnectionTcp::getNextMessage(const MessageHandler<msg::BaseMessage>&
auto response = msg::factory::createMessage(base_message_, buffer_.data());
if (!response)
LOG(WARNING, LOG_TAG) << "Failed to deserialize message of type: " << base_message_.type << "\n";
for (auto iter = pendingRequests_.begin(); iter != pendingRequests_.end(); ++iter)
{
auto request = *iter;
if (auto req = request.lock())
{
if (req->id() == base_message_.refersTo)
{
req->setValue(std::move(response));
pendingRequests_.erase(iter);
getNextMessage(handler);
return;
}
}
}
if (handler)
handler(ec, std::move(response));
messageReceived(std::move(response), handler);
});
});
}
@ -383,13 +395,13 @@ ClientConnectionWs::~ClientConnectionWs()
void ClientConnectionWs::disconnect()
{
LOG(DEBUG, LOG_TAG) << "Disconnecting\n";
if (!tcp_ws_->is_open())
if (!tcp_ws_.is_open())
{
LOG(DEBUG, LOG_TAG) << "Not connected\n";
return;
}
boost::system::error_code ec;
tcp_ws_->close(websocket::close_code::normal, ec);
tcp_ws_.close(websocket::close_code::normal, ec);
if (ec)
LOG(ERROR, LOG_TAG) << "Error in socket close: " << ec.message() << "\n";
boost::asio::post(strand_, [this]() { pendingRequests_.clear(); });
@ -401,20 +413,20 @@ std::string ClientConnectionWs::getMacAddress()
{
std::string mac =
#ifndef WINDOWS
::getMacAddress(tcp_ws_->next_layer().native_handle());
::getMacAddress(tcp_ws_.next_layer().native_handle());
#else
::getMacAddress(tcp_ws_->next_layer().local_endpoint().address().to_string());
::getMacAddress(tcp_ws_.next_layer().local_endpoint().address().to_string());
#endif
if (mac.empty())
mac = "00:00:00:00:00:00";
LOG(INFO, LOG_TAG) << "My MAC: \"" << mac << "\", socket: " << tcp_ws_->next_layer().native_handle() << "\n";
LOG(INFO, LOG_TAG) << "My MAC: \"" << mac << "\", socket: " << tcp_ws_.next_layer().native_handle() << "\n";
return mac;
}
void ClientConnectionWs::getNextMessage(const MessageHandler<msg::BaseMessage>& handler)
{
tcp_ws_->async_read(buffer_, [this, handler](beast::error_code ec, std::size_t bytes_transferred) mutable
tcp_ws_.async_read(buffer_, [this, handler](beast::error_code ec, std::size_t bytes_transferred) mutable
{
tv now;
LOG(DEBUG, LOG_TAG) << "on_read_ws, ec: " << ec << ", bytes_transferred: " << bytes_transferred << "\n";
@ -449,23 +461,7 @@ void ClientConnectionWs::getNextMessage(const MessageHandler<msg::BaseMessage>&
LOG(DEBUG, LOG_TAG) << "getNextMessage: " << response->type << ", size: " << response->size << ", id: " << response->id
<< ", refers: " << response->refersTo << "\n";
for (auto iter = pendingRequests_.begin(); iter != pendingRequests_.end(); ++iter)
{
auto request = *iter;
if (auto req = request.lock())
{
if (req->id() == base_message_.refersTo)
{
req->setValue(std::move(response));
pendingRequests_.erase(iter);
getNextMessage(handler);
return;
}
}
}
if (handler)
handler(ec, std::move(response));
messageReceived(std::move(response), handler);
});
}
@ -473,23 +469,146 @@ void ClientConnectionWs::getNextMessage(const MessageHandler<msg::BaseMessage>&
boost::system::error_code ClientConnectionWs::doConnect(boost::asio::ip::basic_endpoint<boost::asio::ip::tcp> endpoint)
{
boost::system::error_code ec;
tcp_ws_->binary(true);
tcp_ws_->next_layer().connect(endpoint, ec);
tcp_ws_.binary(true);
tcp_ws_.next_layer().connect(endpoint, ec);
// Set suggested timeout settings for the websocket
tcp_ws_->set_option(websocket::stream_base::timeout::suggested(beast::role_type::client));
tcp_ws_.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client));
// Set a decorator to change the User-Agent of the handshake
tcp_ws_->set_option(websocket::stream_base::decorator([](websocket::request_type& req)
{ req.set(http::field::user_agent, std::string(BOOST_BEAST_VERSION_STRING) + " websocket-client-async"); }));
tcp_ws_.set_option(websocket::stream_base::decorator([](websocket::request_type& req) { req.set(http::field::user_agent, WS_CLIENT_NAME); }));
// Perform the websocket handshake
tcp_ws_->handshake("127.0.0.1", "/stream", ec);
tcp_ws_.handshake(server_.host + ":" + std::to_string(server_.port), "/stream", ec);
return ec;
}
void ClientConnectionWs::write(boost::asio::streambuf& buffer, WriteHandler&& write_handler)
{
tcp_ws_->async_write(boost::asio::buffer(buffer.data()), write_handler);
tcp_ws_.async_write(boost::asio::buffer(buffer.data()), write_handler);
}
/////////////////////////////// SSL Websockets ////////////////////////////////
ClientConnectionWss::ClientConnectionWss(boost::asio::io_context& io_context, boost::asio::ssl::context& ssl_context, ClientSettings::Server server)
: ClientConnection(io_context, std::move(server)), ssl_ws_(strand_, ssl_context)
{
}
ClientConnectionWss::~ClientConnectionWss()
{
disconnect();
}
void ClientConnectionWss::disconnect()
{
LOG(DEBUG, LOG_TAG) << "Disconnecting\n";
if (!ssl_ws_.is_open())
{
LOG(DEBUG, LOG_TAG) << "Not connected\n";
return;
}
boost::system::error_code ec;
ssl_ws_.close(websocket::close_code::normal, ec);
if (ec)
LOG(ERROR, LOG_TAG) << "Error in socket close: " << ec.message() << "\n";
boost::asio::post(strand_, [this]() { pendingRequests_.clear(); });
LOG(DEBUG, LOG_TAG) << "Disconnected\n";
}
std::string ClientConnectionWss::getMacAddress()
{
std::string mac =
#ifndef WINDOWS
::getMacAddress(ssl_ws_.next_layer().lowest_layer().native_handle());
#else
::getMacAddress(tcp_ws_.next_layer().lowest_layer().local_endpoint().address().to_string());
#endif
if (mac.empty())
mac = "00:00:00:00:00:00";
LOG(INFO, LOG_TAG) << "My MAC: \"" << mac << "\", socket: " << ssl_ws_.next_layer().lowest_layer().native_handle() << "\n";
return mac;
}
void ClientConnectionWss::getNextMessage(const MessageHandler<msg::BaseMessage>& handler)
{
ssl_ws_.async_read(buffer_, [this, handler](beast::error_code ec, std::size_t bytes_transferred) mutable
{
tv now;
LOG(DEBUG, LOG_TAG) << "on_read_ws, ec: " << ec << ", bytes_transferred: " << bytes_transferred << "\n";
// This indicates that the session was closed
if (ec == websocket::error::closed)
{
if (handler)
handler(ec, nullptr);
return;
}
if (ec)
{
LOG(ERROR, LOG_TAG) << "ControlSessionWebsocket::on_read_ws error: " << ec.message() << "\n";
if (handler)
handler(ec, nullptr);
return;
}
buffer_.consume(bytes_transferred);
auto* data = static_cast<char*>(buffer_.data().data());
base_message_.deserialize(data);
base_message_.received = now;
auto response = msg::factory::createMessage(base_message_, data + base_msg_size_);
if (!response)
LOG(WARNING, LOG_TAG) << "Failed to deserialize message of type: " << base_message_.type << "\n";
else
LOG(DEBUG, LOG_TAG) << "getNextMessage: " << response->type << ", size: " << response->size << ", id: " << response->id
<< ", refers: " << response->refersTo << "\n";
messageReceived(std::move(response), handler);
});
}
boost::system::error_code ClientConnectionWss::doConnect(boost::asio::ip::basic_endpoint<boost::asio::ip::tcp> endpoint)
{
boost::system::error_code ec;
ssl_ws_.binary(true);
beast::get_lowest_layer(ssl_ws_).connect(endpoint, ec);
// Set a timeout on the operation
// beast::get_lowest_layer(ssl_ws_).expires_after(std::chrono::seconds(30));
// Set suggested timeout settings for the websocket
// ssl_ws_.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client));
// Set SNI Hostname (many hosts need this to handshake successfully)
if (!SSL_set_tlsext_host_name(ssl_ws_.next_layer().native_handle(), server_.host.c_str()))
throw beast::system_error(beast::error_code(static_cast<int>(::ERR_get_error()), boost::asio::error::get_ssl_category()), "Failed to set SNI Hostname");
// Perform the SSL handshake
ssl_ws_.next_layer().handshake(boost::asio::ssl::stream_base::client);
// Set a decorator to change the User-Agent of the handshake
ssl_ws_.set_option(websocket::stream_base::decorator([](websocket::request_type& req) { req.set(http::field::user_agent, WS_CLIENT_NAME); }));
// Perform the websocket handshake
ssl_ws_.handshake(server_.host + ":" + std::to_string(server_.port), "/stream", ec);
return ec;
}
void ClientConnectionWss::write(boost::asio::streambuf& buffer, WriteHandler&& write_handler)
{
ssl_ws_.async_write(boost::asio::buffer(buffer.data()), write_handler);
}