/***
This file is part of snapcast
Copyright (C) 2014-2025 Johannes Pohl
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see .
***/
#pragma once
// local headers
#include "client_settings.hpp"
#include "common/message/factory.hpp"
#include "common/message/message.hpp"
#include "common/time_defs.hpp"
// 3rd party headers
#include
#include
#include
#include
#include
#include
#include
#include
// standard headers
#include
#include
#include
#include
#include
// using boost::asio::ip::tcp;
// Short namespace aliases used throughout this header.
namespace beast = boost::beast; // alias for boost::beast
namespace websocket = beast::websocket; // alias for boost::beast::websocket
/// Plain TCP socket, used directly by ClientConnectionTcp
using tcp_socket = boost::asio::ip::tcp::socket;
/// SSL/TLS stream type (presumably wrapping tcp_socket — TODO confirm template argument)
using ssl_socket = boost::asio::ssl::stream;
/// Websocket stream type used by ClientConnectionWs
using tcp_websocket = websocket::stream;
#ifdef HAS_OPENSSL
/// Secure websocket stream type used by ClientConnectionWss (only when built with HAS_OPENSSL)
using ssl_websocket = websocket::stream;
#endif
// Forward declaration of the connection base class defined later in this header.
class ClientConnection;
/// Asynchronous message callback.
/// Invoked with an error code and the received message; on error the message pointer is nullptr
/// (see the typed ClientConnection::sendRequest wrapper, which calls it as handler(ec, nullptr) on failure).
template
using MessageHandler = std::function)>;
/// Used to synchronize server requests (wait for server response).
/// One instance tracks a single outstanding request, identified by its request id, until either
/// the response arrives (setValue) or the timer started by startTimer fires.
/// NOTE(review): derives from enable_shared_from_this, presumably so asynchronous timer callbacks
/// can keep the request alive — TODO confirm in the implementation.
class PendingRequest : public std::enable_shared_from_this
{
public:
/// c'tor
/// @param strand strand on which the timer and handler are presumably serialized (a copy is stored in strand_)
/// @param reqId id of the request this object waits for
/// @param handler callback that receives the response message or an error
PendingRequest(const boost::asio::strand& strand, uint16_t reqId, const MessageHandler& handler);
/// d'tor
virtual ~PendingRequest();
/// Sets the response for the pending request and passes it to the handler
/// @param value the response message
void setValue(std::unique_ptr value);
/// @return the id of the request
uint16_t id() const;
/// Start the timer for the request
/// (presumably the handler is notified with an error if the timer expires before setValue — TODO confirm)
/// @param timeout the timeout to wait for the reception of the response
void startTimer(const chronos::usec& timeout);
/// Ordering operator, needed to put the requests in a container
/// (presumably compares the request ids — TODO confirm)
bool operator<(const PendingRequest& other) const;
private:
/// Id of the request this object is waiting for
uint16_t id_;
/// Timer that bounds how long to wait for the response
boost::asio::steady_timer timer_;
/// Strand passed in by the owner of this request
boost::asio::strand strand_;
/// Callback that receives the response or an error
MessageHandler handler_;
};
/// Endpoint of the server connection
/**
* Abstract server connection endpoint; concrete transports derive from it.
* Messages are sent to the server asynchronously with the "send" method.
* Requests that expect an answer from the server are sent with the "sendRequest" method;
* the response (or an error, e.g. on timeout) is delivered asynchronously to the supplied handler.
*/
class ClientConnection
{
public:
/// Result callback that reports success or failure via a boost::system::error_code
using ResultHandler = std::function;
/// Result callback of a write operation
using WriteHandler = std::function;
/// c'tor
/// @param io_context the io_context that runs the asynchronous operations
/// @param server server settings (host and port), stored in server_
ClientConnection(boost::asio::io_context& io_context, ClientSettings::Server server);
/// d'tor
virtual ~ClientConnection() = default;
/// async connect
/// (presumably resolves server_ with resolver_ and connects via doConnect — see implementation)
/// @param handler async result handler
void connect(const ResultHandler& handler);
/// disconnect the socket
virtual void disconnect() = 0;
/// async send a message
/// The message is queued (see messages_) and sent in order (see sendNext).
/// @param message the message
/// @param handler the result handler
void send(const msg::message_ptr& message, const ResultHandler& handler);
/// Send request to the server and asynchronously wait for the answer
/// @param message the message
/// @param timeout how long to wait for the response (presumably forwarded to PendingRequest::startTimer — TODO confirm)
/// @param handler async result handler with the response message or error
void sendRequest(const msg::message_ptr& message, const chronos::usec& timeout, const MessageHandler& handler);
/// @sa sendRequest with templated response message
/// Wraps the untyped sendRequest and casts the response to the requested message type.
/// If the response cannot be cast, the handler receives errc::bad_message and a nullptr.
template
void sendRequest(const msg::message_ptr& message, const chronos::usec& timeout, const MessageHandler& handler)
{
sendRequest(message, timeout, [handler](const boost::system::error_code& ec, std::unique_ptr response)
{
// Forward transport/timeout errors unchanged
if (ec)
handler(ec, nullptr);
// Success: hand over the response as the requested message type
else if (auto casted_response = msg::message_cast(std::move(response)))
handler(ec, std::move(casted_response));
// The server answered with an unexpected message type
else
handler(boost::system::errc::make_error_code(boost::system::errc::bad_message), nullptr);
});
}
/// @return MAC address of the client
virtual std::string getMacAddress() = 0;
/// async get the next message
/// @param handler the next received message or error
virtual void getNextMessage(const MessageHandler& handler) = 0;
protected:
/// Write the content of @p buffer to the underlying transport
/// @param buffer the serialized data to send
/// @param write_handler called when the write operation has finished
virtual void write(boost::asio::streambuf& buffer, WriteHandler&& write_handler) = 0;
/// Connect to @p endpoint
/// @return the error code of the connect attempt
virtual boost::system::error_code doConnect(boost::asio::ip::basic_endpoint endpoint) = 0;
/// Handle received messages, check for response of pending requests
void messageReceived(std::unique_ptr message, const MessageHandler& handler);
/// Send next pending message from messages_
void sendNext();
/// Base message used while receiving (presumably holds the header of the incoming message — TODO confirm)
msg::BaseMessage base_message_;
/// Strand to serialize send/receive
boost::asio::strand strand_;
/// TCP resolver
boost::asio::ip::tcp::resolver resolver_;
/// List of pending requests, waiting for a response (Message::refersTo)
std::vector> pendingRequests_;
/// unique request id to match a response
uint16_t reqId_;
/// Server settings (host and port)
ClientSettings::Server server_;
/// Size of a base message (= message header)
const size_t base_msg_size_;
/// Send stream buffer
boost::asio::streambuf streambuf_;
/// An outgoing message that is queued to be sent
struct PendingMessage
{
/// c'tor
/// @param msg the message to send
/// @param handler the result handler passed to send()
PendingMessage(msg::message_ptr msg, ResultHandler handler) : msg(std::move(msg)), handler(std::move(handler))
{
}
/// Pointer to the message to send
msg::message_ptr msg;
/// Result handler supplied to send() (presumably called with the outcome of the write)
ResultHandler handler;
};
/// Pending messages to be sent
std::deque messages_;
};
/// Plain TCP connection
/// Implements the ClientConnection transport on top of a raw tcp_socket.
class ClientConnectionTcp : public ClientConnection
{
public:
/// c'tor
/// @param io_context the io_context that runs the asynchronous operations
/// @param server server settings (host and port)
ClientConnectionTcp(boost::asio::io_context& io_context, ClientSettings::Server server);
/// d'tor
virtual ~ClientConnectionTcp();
/// @copydoc ClientConnection::disconnect
void disconnect() override;
/// @copydoc ClientConnection::getMacAddress
std::string getMacAddress() override;
/// @copydoc ClientConnection::getNextMessage
void getNextMessage(const MessageHandler& handler) override;
private:
/// Connect socket_ to @p endpoint
boost::system::error_code doConnect(boost::asio::ip::basic_endpoint endpoint) override;
/// Write @p buffer to socket_
void write(boost::asio::streambuf& buffer, WriteHandler&& write_handler) override;
/// TCP socket
tcp_socket socket_;
/// Receive buffer
std::vector buffer_;
};
/// Websocket connection
/// Implements the ClientConnection transport on top of an unencrypted websocket (tcp_websocket).
class ClientConnectionWs : public ClientConnection
{
public:
/// c'tor
/// @param io_context the io_context that runs the asynchronous operations
/// @param server server settings (host and port)
ClientConnectionWs(boost::asio::io_context& io_context, ClientSettings::Server server);
/// d'tor
virtual ~ClientConnectionWs();
/// @copydoc ClientConnection::disconnect
void disconnect() override;
/// @copydoc ClientConnection::getMacAddress
std::string getMacAddress() override;
/// @copydoc ClientConnection::getNextMessage
void getNextMessage(const MessageHandler& handler) override;
private:
/// Connect the websocket to @p endpoint
boost::system::error_code doConnect(boost::asio::ip::basic_endpoint endpoint) override;
/// Write @p buffer to the websocket
void write(boost::asio::streambuf& buffer, WriteHandler&& write_handler) override;
/// @return the websocket
tcp_websocket& getWs();
/// TCP web socket (held in an optional, presumably so it can be created on demand — TODO confirm)
std::optional tcp_ws_;
/// Receive buffer
boost::beast::flat_buffer buffer_;
/// Guards access to tcp_ws_
std::mutex ws_mutex_;
};
#ifdef HAS_OPENSSL
/// Secure (SSL/TLS) websocket connection
/// Implements the ClientConnection transport on top of an ssl_websocket; only available with HAS_OPENSSL.
class ClientConnectionWss : public ClientConnection
{
public:
/// c'tor
/// @param io_context the io_context that runs the asynchronous operations
/// @param ssl_context SSL context; stored by reference, so it must outlive this connection
/// @param server server settings (host and port)
ClientConnectionWss(boost::asio::io_context& io_context, boost::asio::ssl::context& ssl_context, ClientSettings::Server server);
/// d'tor
virtual ~ClientConnectionWss();
/// @copydoc ClientConnection::disconnect
void disconnect() override;
/// @copydoc ClientConnection::getMacAddress
std::string getMacAddress() override;
/// @copydoc ClientConnection::getNextMessage
void getNextMessage(const MessageHandler& handler) override;
private:
/// Connect the secure websocket to @p endpoint
boost::system::error_code doConnect(boost::asio::ip::basic_endpoint endpoint) override;
/// Write @p buffer to the secure websocket
void write(boost::asio::streambuf& buffer, WriteHandler&& write_handler) override;
/// @return the websocket
ssl_websocket& getWs();
/// SSL context (non-owning reference, owned by the caller of the c'tor)
boost::asio::ssl::context& ssl_context_;
/// SSL web socket (held in an optional, presumably so it can be created on demand — TODO confirm)
std::optional ssl_ws_;
/// Receive buffer
boost::beast::flat_buffer buffer_;
/// Guards access to ssl_ws_
std::mutex ws_mutex_;
};
#endif // HAS_OPENSSL