diff --git a/client/client_settings.hpp b/client/client_settings.hpp index 72ec60bf..071de9d3 100644 --- a/client/client_settings.hpp +++ b/client/client_settings.hpp @@ -60,16 +60,23 @@ struct ClientSettings /// Server settings struct Server { + /// Auth info + struct Auth + { + /// the scheme (Basic, Plain, bearer, ...) + std::string scheme; + /// the param (base64 encoded ":", ":", token, ...) + std::string param; + }; + /// server host or IP address std::string host; /// protocol: "tcp", "ws" or "wss" std::string protocol{"tcp"}; /// server port size_t port{1704}; - /// username - std::optional username; - /// password - std::optional password; + /// auth info + std::optional auth; /// server certificate std::optional server_certificate; /// Certificate file diff --git a/client/controller.cpp b/client/controller.cpp index 12cdc81f..d470a2c2 100644 --- a/client/controller.cpp +++ b/client/controller.cpp @@ -303,7 +303,7 @@ void Controller::getNextMessage() else if (response->type == message_type::kError) { auto error = msg::message_cast(std::move(response)); - LOG(ERROR, LOG_TAG) << "Received error: " << error->message << ", code: " << error->code << "\n"; + LOG(ERROR, LOG_TAG) << "Received error: " << error->error << ", message: " << error->message << ", code: " << error->code << "\n"; } else { @@ -460,7 +460,10 @@ void Controller::worker() settings_.host_id = ::getHostId(macAddress); // Say hello to the server - auto hello = std::make_shared(macAddress, settings_.host_id, settings_.instance, settings_.server.username, settings_.server.password); + std::optional auth; + if (settings_.server.auth.has_value()) + auth = msg::Hello::Auth{settings_.server.auth->scheme, settings_.server.auth->param}; + auto hello = std::make_shared(macAddress, settings_.host_id, settings_.instance, auth); clientConnection_->sendRequest( hello, 2s, [this](const boost::system::error_code& ec, std::unique_ptr response) mutable { diff --git a/client/snapclient.cpp b/client/snapclient.cpp index 95f91d2d..bdb073be 100644 --- a/client/snapclient.cpp +++ b/client/snapclient.cpp @@ -17,6 +17,7 @@ ***/ // local headers +#include "common/base64.h" #include "common/popl.hpp" #include "common/utils/string_utils.hpp" #include "controller.hpp" @@ -368,10 +369,14 @@ int main(int argc, char** argv) throw SnapException("Snapclient is built without wss support"); #endif } - if (!uri.user.empty()) - settings.server.username = uri.user; - if (!uri.password.empty()) - settings.server.password = uri.password; + + if (!uri.user.empty() || !uri.password.empty()) + { + ClientSettings::Server::Auth auth; + auth.scheme = "Basic"; + auth.param = base64_encode(uri.user + ":" + uri.password); + settings.server.auth = auth; + } } if (server_cert_opt->is_set()) diff --git a/common/message/error.hpp b/common/message/error.hpp index b7919696..01032d7e 100644 --- a/common/message/error.hpp +++ b/common/message/error.hpp @@ -31,37 +31,45 @@ namespace msg class Error : public BaseMessage { public: - /// c'tor taking the @p code and @p message of error - explicit Error(uint32_t code, std::string message) : BaseMessage(message_type::kError), code(code), message(std::move(message)) + /// c'tor taking the @p code, @p error and @p message of error + explicit Error(uint32_t code, std::string error, std::string message) + : BaseMessage(message_type::kError), code(code), error(std::move(error)), message(std::move(message)) { } - Error() : Error(0, "") + /// c'tor + Error() : Error(0, "", "") { } void read(std::istream& stream) override { readVal(stream, code); + readVal(stream, error); readVal(stream, message); } uint32_t getSize() const override { return static_cast(sizeof(uint32_t) // code + + sizeof(uint32_t) // error string len + + error.size() // error string + sizeof(uint32_t) // message len + message.size()); // message; } /// error code uint32_t code; - /// error message + /// error string + std::string error; + /// detailed error message std::string message; protected: void doserialize(std::ostream& stream) const override { writeVal(stream, code); + writeVal(stream, error); writeVal(stream, message); } }; diff --git a/common/message/hello.hpp b/common/message/hello.hpp index 791e9974..f562f171 100644 --- a/common/message/hello.hpp +++ b/common/message/hello.hpp @@ -36,14 +36,48 @@ namespace msg class Hello : public JsonMessage { public: + /// Auth info + struct Auth + { + /// c'tor + Auth() = default; + + /// c'tor construct from json + explicit Auth(const json& j) + { + if (j.contains("scheme")) + scheme = j["scheme"]; + if (j.contains("param")) + param = j["param"]; + } + + /// c'tor construct from @p scheme and @p param + Auth(std::string scheme, std::string param) : scheme(std::move(scheme)), param(std::move(param)) + { + } + + /// @return serialized to json + json toJson() const + { + json j; + j["scheme"] = scheme; + j["param"] = param; + return j; + } + + /// the scheme (Basic, Plain, bearer, ...) + std::string scheme; + /// the param (base64 encoded ":", ":", token, ...) + std::string param; + }; + /// c'tor Hello() : JsonMessage(message_type::kHello) { } - /// c'tor taking @p macAddress, @p id and @p instance - Hello(const std::string& mac_address, const std::string& id, size_t instance, std::optional username, std::optional password) - : JsonMessage(message_type::kHello) + /// c'tor taking @p macAddress, @p id, @p instance and @p auth info + Hello(const std::string& mac_address, const std::string& id, size_t instance, std::optional auth) : JsonMessage(message_type::kHello) { msg["MAC"] = mac_address; msg["HostName"] = ::getHostName(); @@ -53,10 +87,8 @@ public: msg["Arch"] = ::getArch(); msg["Instance"] = instance; msg["ID"] = id; - if (username.has_value()) - msg["Username"] = username.value(); - if (password.has_value()) - msg["Password"] = password.value(); + if (auth.has_value()) + msg["Auth"] = auth->toJson(); msg["SnapStreamProtocolVersion"] = 2; } @@ -129,20 +161,12 @@ public: return id; } - /// @return the username - std::optional getUsername() const + /// @return the auth info + std::optional getAuth() const { - if (!msg.contains("Username")) + if (!msg.contains("Auth")) return std::nullopt; - return msg["Username"]; - } - - /// @return the password - std::optional getPassword() const - { - if (!msg.contains("Password")) - return std::nullopt; - return msg["Password"]; + return Auth{msg["Auth"]}; } }; diff --git a/doc/binary_protocol.md b/doc/binary_protocol.md index c9331155..44f23afe 100644 --- a/doc/binary_protocol.md +++ b/doc/binary_protocol.md @@ -113,20 +113,22 @@ Sample JSON payload (whitespace added for readability): ```json { "Arch": "x86_64", + "Auth": { + "param": "YmFkYWl4OnBhc3N3ZA==", + "scheme": "Basic" + }, "ClientName": "Snapclient", "HostName": "my_hostname", "ID": "00:11:22:33:44:55", "Instance": 1, "MAC": "00:11:22:33:44:55", "OS": "Arch Linux", - "Username": "Badaix", - "Password": "$ecret", "SnapStreamProtocolVersion": 2, - "Version": "0.17.1" + "Version": "0.32.0" } ``` -The fields `Username` and `Password` are optional and only used if authentication and authorization is enabled on the server. +The field `Auth` is optional and only used if authentication and authorization is enabled on the server. ### Client Info @@ -151,5 +153,7 @@ Sample JSON payload (whitespace added for readability): | Field | Type | Description | |---------|--------|----------------------------------------------------------| | code | uint32 | Error code | -| size | uint32 | Size of the following error message | +| size | uint32 | Size of the following error string | | error | char[] | string containing the error (not null terminated) | +| size | uint32 | Size of the following error message | +| error | char[] | string containing error details (not null terminated) | diff --git a/server/authinfo.cpp b/server/authinfo.cpp index 9e287cb1..0af81ef2 100644 --- a/server/authinfo.cpp +++ b/server/authinfo.cpp @@ -105,6 +105,9 @@ AuthInfo::AuthInfo(ServerSettings::Authorization auth_settings) : is_authenticat ErrorCode AuthInfo::validateUser(const std::string& username, const std::optional& password) const { + if (!auth_settings_.enabled) + return {}; + auto iter = std::find_if(auth_settings_.users.begin(), auth_settings_.users.end(), [&](const ServerSettings::Authorization::User& user) { return user.name == username; }); if (iter == auth_settings_.users.end()) @@ -119,12 +122,14 @@ ErrorCode AuthInfo::authenticate(const std::string& scheme, const std::string& p { std::string scheme_normed = utils::string::trim_copy(utils::string::tolower_copy(scheme)); std::string param_normed = utils::string::trim_copy(param); - // if (scheme_normed == "bearer") - // return authenticateBearer(param_normed); if (scheme_normed == "basic") return authenticateBasic(param_normed); + else if (scheme_normed == "plain") + return authenticatePlain(param_normed); + // else if (scheme_normed == "bearer") + // return authenticateBearer(param_normed); - return {AuthErrc::auth_scheme_not_supported, "Scheme must be 'Basic'"}; // or 'Bearer'"}; + return {AuthErrc::auth_scheme_not_supported, "Scheme must be 'Basic' or 'Plain'"}; // or 'Bearer'"}; } @@ -150,6 +155,21 @@ ErrorCode AuthInfo::authenticateBasic(const std::string& credentials) return ec; } + +ErrorCode AuthInfo::authenticatePlain(const std::string& user_password) +{ + is_authenticated_ = false; + std::string password; + std::string username = utils::string::split_left(user_password, ':', password); + auto ec = validateUser(username_, password); + + // TODO: don't log passwords + LOG(INFO, LOG_TAG) << "Authorization basic: " << user_password << ", user: " << username_ << ", password: " << password << "\n"; + is_authenticated_ = (ec.value() == 0); + return ec; +} + + #if 0 ErrorCode AuthInfo::authenticateBearer(const std::string& token) { diff --git a/server/authinfo.hpp b/server/authinfo.hpp index 84968869..276e45ee 100644 --- a/server/authinfo.hpp +++ b/server/authinfo.hpp @@ -79,6 +79,8 @@ public: /// Authenticate with basic scheme ErrorCode authenticateBasic(const std::string& credentials); + /// Authenticate with : + ErrorCode authenticatePlain(const std::string& user_password); /// Authenticate with bearer scheme // ErrorCode authenticateBearer(const std::string& token); /// Authenticate with basic or bearer scheme with an auth header diff --git a/server/control_requests.cpp b/server/control_requests.cpp index 37c1b716..27b25696 100644 --- a/server/control_requests.cpp +++ b/server/control_requests.cpp @@ -885,6 +885,8 @@ void ServerAuthenticateRequest::execute(const jsonrpcpp::request_ptr& request, A // Response: {"id":8,"jsonrpc":"2.0","result":"ok"} // Request: {"id":8,"jsonrpc":"2.0","method":"Server.Authenticate","params":{"scheme":"Bearer","param":"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3MTg1NjQ1MTYsImlhdCI6MTcxODUyODUxNiwic3ViIjoiQmFkYWl4In0.gHrMVp7jTAg8aCSg3cttcfIxswqmOPuqVNOb5p79Cn0NmAqRmLXtDLX4QjOoOqqb66ezBBeikpNjPi_aO18YPoNmX9fPxSwcObTHBupnm5eugEpneMPDASFUSE2hg8rrD_OEoAVxx6hCLln7Z3ILyWDmR6jcmy7z0bp0BiAqOywUrFoVIsnlDZRs3wOaap5oS9J2oaA_gNi_7OuvAhrydn26LDhm0KiIqEcyIholkpRHrDYODkz98h2PkZdZ2U429tTvVhzDBJ1cBq2Zq3cvuMZT6qhwaUc8eYA8fUJ7g65iP4o2OZtUzlfEUqX1TKyuWuSK6CUlsZooNE-MSCT7_w"}} // Response: {"id":8,"jsonrpc":"2.0","result":"ok"} + // Request: {"id":8,"jsonrpc":"2.0","method":"Server.Authenticate","params":{"scheme":"Plain","param":":"}} + // Response: {"id":8,"jsonrpc":"2.0","result":"ok"} // clang-format on checkParams(request, {"scheme", "param"}); diff --git a/server/control_session_http.cpp b/server/control_session_http.cpp index 056bbfc0..c78bf7c2 100644 --- a/server/control_session_http.cpp +++ b/server/control_session_http.cpp @@ -409,7 +409,7 @@ void ControlSessionHttp::on_read(beast::error_code ec, std::size_t bytes_transfe } else // if (req_.target() == "/stream") { - auto ws_session = make_shared(nullptr, std::move(*ws)); + auto ws_session = make_shared(nullptr, settings_, std::move(*ws)); message_receiver_->onNewSession(std::move(ws_session)); } } @@ -435,7 +435,7 @@ void ControlSessionHttp::on_read(beast::error_code ec, std::size_t bytes_transfe } else // if (req_.target() == "/stream") { - auto ws_session = make_shared(nullptr, std::move(*ws)); + auto ws_session = make_shared(nullptr, settings_, std::move(*ws)); message_receiver_->onNewSession(std::move(ws_session)); } } diff --git a/server/server.cpp b/server/server.cpp index 3c0bf963..f15d510d 100644 --- a/server/server.cpp +++ b/server/server.cpp @@ -22,6 +22,7 @@ // local headers #include "common/aixlog.hpp" #include "common/message/client_info.hpp" +#include "common/message/error.hpp" #include "common/message/hello.hpp" #include "common/message/server_settings.hpp" #include "common/message/time.hpp" @@ -261,7 +262,7 @@ void Server::onMessageReceived(std::shared_ptr controlSession, c -void Server::onMessageReceived(StreamSession* streamSession, const msg::BaseMessage& baseMessage, char* buffer) +void Server::onMessageReceived(const std::shared_ptr& streamSession, const msg::BaseMessage& baseMessage, char* buffer) { LOG(DEBUG, LOG_TAG) << "onMessageReceived: " << baseMessage.type << ", size: " << baseMessage.size << ", id: " << baseMessage.id << ", refers: " << baseMessage.refersTo << ", sent: " << baseMessage.sent.sec << "," << baseMessage.sent.usec @@ -308,12 +309,48 @@ void Server::onMessageReceived(StreamSession* streamSession, const msg::BaseMess msg::Hello helloMsg; helloMsg.deserialize(baseMessage, buffer); streamSession->clientId = helloMsg.getUniqueId(); + auto auth = helloMsg.getAuth(); + // TODO: don't log passwords LOG(INFO, LOG_TAG) << "Hello from " << streamSession->clientId << ", host: " << helloMsg.getHostName() << ", v" << helloMsg.getVersion() << ", ClientName: " << helloMsg.getClientName() << ", OS: " << helloMsg.getOS() << ", Arch: " << helloMsg.getArch() - << ", Protocol version: " << helloMsg.getProtocolVersion() << ", Userrname: " << helloMsg.getUsername().value_or("") - << ", Password: " << (helloMsg.getPassword().has_value() ? "" : "") << "\n"; - streamSession->stop(); - return; + << ", Protocol version: " << helloMsg.getProtocolVersion() << ", Auth: " << auth.value_or(msg::Hello::Auth{}).toJson().dump() + << "\n"; + + if (settings_.auth.enabled) + { + ErrorCode ec; + if (auth.has_value()) + ec = streamSession->authinfo.authenticate(auth->scheme, auth->param); + + if (!auth.has_value() || ec) + { + if (ec) + LOG(ERROR, LOG_TAG) << "Authentication failed: " << ec.detailed_message() << "\n"; + else + LOG(ERROR, LOG_TAG) << "Authentication required\n"; + auto error_msg = make_shared(401, "Unauthorized", ec ? ec.detailed_message() : "Authentication required"); + streamSession->send(error_msg, [streamSession](boost::system::error_code ec, std::size_t length) + { + LOG(DEBUG, LOG_TAG) << "Sent result: " << ec << ", len: " << length << "\n"; + streamSession->stop(); + }); + return; + } + + if (!streamSession->authinfo.hasPermission("Streaming")) + { + std::string error = "Permission 'Streaming' missing"; + LOG(ERROR, LOG_TAG) << error << "\n"; + auto error_msg = make_shared(403, "Forbidden", error); + streamSession->send(error_msg, [streamSession](boost::system::error_code ec, std::size_t length) + { + LOG(DEBUG, LOG_TAG) << "Sent result: " << ec << ", len: " << length << "\n"; + streamSession->stop(); + }); + return; + } + } + bool newGroup(false); GroupPtr group = Config::instance().getGroupFromClient(streamSession->clientId); if (group == nullptr) diff --git a/server/server.hpp b/server/server.hpp index 767c1718..0cb53497 100644 --- a/server/server.hpp +++ b/server/server.hpp @@ -1,6 +1,6 @@ /*** This file is part of snapcast - Copyright (C) 2014-2024 Johannes Pohl + Copyright (C) 2014-2025 Johannes Pohl This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -65,7 +65,7 @@ public: private: /// Implementation of StreamMessageReceiver - void onMessageReceived(StreamSession* streamSession, const msg::BaseMessage& baseMessage, char* buffer) override; + void onMessageReceived(const std::shared_ptr& streamSession, const msg::BaseMessage& baseMessage, char* buffer) override; void onDisconnect(StreamSession* streamSession) override; /// Implementation of ControllMessageReceiver diff --git a/server/stream_server.cpp b/server/stream_server.cpp index 11be2b66..485a38ff 100644 --- a/server/stream_server.cpp +++ b/server/stream_server.cpp @@ -47,7 +47,7 @@ StreamServer::~StreamServer() = default; void StreamServer::cleanup() { - auto new_end = std::remove_if(sessions_.begin(), sessions_.end(), [](std::weak_ptr session) { return session.expired(); }); + auto new_end = std::remove_if(sessions_.begin(), sessions_.end(), [](const std::weak_ptr& session) { return session.expired(); }); auto count = distance(new_end, sessions_.end()); if (count > 0) { @@ -69,7 +69,7 @@ void StreamServer::addSession(std::shared_ptr session) } -void StreamServer::onChunkEncoded(const PcmStream* pcmStream, bool isDefaultStream, std::shared_ptr chunk, double /*duration*/) +void StreamServer::onChunkEncoded(const PcmStream* pcmStream, bool isDefaultStream, const std::shared_ptr& chunk, double /*duration*/) { // LOG(TRACE, LOG_TAG) << "onChunkRead (" << pcmStream->getName() << "): " << duration << "ms\n"; shared_const_buffer buffer(*chunk); @@ -112,7 +112,7 @@ void StreamServer::onChunkEncoded(const PcmStream* pcmStream, bool isDefaultStre } -void StreamServer::onMessageReceived(StreamSession* streamSession, const msg::BaseMessage& baseMessage, char* buffer) +void StreamServer::onMessageReceived(const std::shared_ptr& streamSession, const msg::BaseMessage& baseMessage, char* buffer) { try { @@ -122,7 +122,7 @@ void StreamServer::onMessageReceived(StreamSession* streamSession, const msg::Ba catch (const std::exception& e) { LOG(ERROR, LOG_TAG) << "Server::onMessageReceived exception: " << e.what() << ", message type: " << baseMessage.type << "\n"; - auto session = getStreamSession(streamSession); + auto session = getStreamSession(streamSession.get()); session->stop(); } } @@ -139,7 +139,7 @@ void StreamServer::onDisconnect(StreamSession* streamSession) LOG(INFO, LOG_TAG) << "onDisconnect: " << session->clientId << "\n"; LOG(DEBUG, LOG_TAG) << "sessions: " << sessions_.size() << "\n"; sessions_.erase(std::remove_if(sessions_.begin(), sessions_.end(), - [streamSession](std::weak_ptr session) + [streamSession](const std::weak_ptr& session) { auto s = session.lock(); return s.get() == streamSession; @@ -209,7 +209,7 @@ void StreamServer::handleAccept(tcp::socket socket) socket.set_option(tcp::no_delay(true)); LOG(NOTICE, LOG_TAG) << "StreamServer::NewConnection: " << socket.remote_endpoint().address().to_string() << "\n"; - shared_ptr session = make_shared(this, std::move(socket)); + shared_ptr session = make_shared(this, settings_, std::move(socket)); addSession(session); } catch (const std::exception& e) diff --git a/server/stream_server.hpp b/server/stream_server.hpp index c55637fd..283fbdaa 100644 --- a/server/stream_server.hpp +++ b/server/stream_server.hpp @@ -1,6 +1,6 @@ /*** This file is part of snapcast - Copyright (C) 2014-2024 Johannes Pohl + Copyright (C) 2014-2025 Johannes Pohl This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -53,19 +53,27 @@ using session_ptr = std::shared_ptr; class StreamServer : public StreamMessageReceiver { public: + /// c'tor StreamServer(boost::asio::io_context& io_context, const ServerSettings& serverSettings, StreamMessageReceiver* messageReceiver = nullptr); + /// d'tor virtual ~StreamServer(); + /// Start accepting connections void start(); + /// Stop accepting connections and active sessions void stop(); /// Send a message to all connceted clients // void send(const msg::BaseMessage* message); + /// Add a new stream session void addSession(std::shared_ptr session); - void onChunkEncoded(const PcmStream* pcmStream, bool isDefaultStream, std::shared_ptr chunk, double duration); + /// Callback for chunks that are ready to be sent + void onChunkEncoded(const PcmStream* pcmStream, bool isDefaultStream, const std::shared_ptr& chunk, double duration); + /// @return stream session for @p clientId session_ptr getStreamSession(const std::string& clientId) const; + /// @return stream session for @p session session_ptr getStreamSession(StreamSession* session) const; private: @@ -74,7 +82,7 @@ private: void cleanup(); /// Implementation of StreamMessageReceiver - void onMessageReceived(StreamSession* streamSession, const msg::BaseMessage& baseMessage, char* buffer) override; + void onMessageReceived(const std::shared_ptr& streamSession, const msg::BaseMessage& baseMessage, char* buffer) override; void onDisconnect(StreamSession* streamSession) override; mutable std::recursive_mutex sessionsMutex_; diff --git a/server/stream_session.cpp b/server/stream_session.cpp index b1e1b9e4..31859c2d 100644 --- a/server/stream_session.cpp +++ b/server/stream_session.cpp @@ -34,8 +34,8 @@ using namespace streamreader; static constexpr auto LOG_TAG = "StreamSession"; -StreamSession::StreamSession(const boost::asio::any_io_executor& executor, StreamMessageReceiver* receiver) - : messageReceiver_(receiver), pcmStream_(nullptr), strand_(boost::asio::make_strand(executor)) +StreamSession::StreamSession(const boost::asio::any_io_executor& executor, const ServerSettings& server_settings, StreamMessageReceiver* receiver) + : authinfo(server_settings.auth), messageReceiver_(receiver), pcm_stream_(nullptr), strand_(boost::asio::make_strand(executor)) { base_msg_size_ = baseMessage_.getSize(); buffer_.resize(base_msg_size_); @@ -45,25 +45,29 @@ StreamSession::StreamSession(const boost::asio::any_io_executor& executor, Strea void StreamSession::setPcmStream(PcmStreamPtr pcmStream) { std::lock_guard lock(mutex_); - pcmStream_ = std::move(pcmStream); + pcm_stream_ = std::move(pcmStream); } const PcmStreamPtr StreamSession::pcmStream() const { std::lock_guard lock(mutex_); - return pcmStream_; + return pcm_stream_; } -void StreamSession::send_next() +void StreamSession::sendNext() { auto& buffer = messages_.front(); buffer.on_air = true; boost::asio::post(strand_, [this, self = shared_from_this(), buffer]() { - sendAsync(buffer, [this](boost::system::error_code ec, std::size_t length) + sendAsync(buffer, [this, buffer](boost::system::error_code ec, std::size_t length) { + auto write_handler = buffer.getWriteHandler(); + if (write_handler) + write_handler(ec, length); + messages_.pop_front(); if (ec) { @@ -72,15 +76,15 @@ void StreamSession::send_next() return; } if (!messages_.empty()) - send_next(); + sendNext(); }); }); } -void StreamSession::send(shared_const_buffer const_buf) +void StreamSession::send(shared_const_buffer const_buf, WriteHandler&& handler) { - boost::asio::post(strand_, [this, self = shared_from_this(), const_buf]() + boost::asio::post(strand_, [this, self = shared_from_this(), const_buf = std::move(const_buf), handler = std::move(handler)]() mutable { // delete PCM chunks that are older than the overall buffer duration messages_.erase(std::remove_if(messages_.begin(), messages_.end(), @@ -94,25 +98,26 @@ void StreamSession::send(shared_const_buffer const_buf) }), messages_.end()); - messages_.push_back(const_buf); + const_buf.setWriteHandler(std::move(handler)); + messages_.push_back(std::move(const_buf)); if (messages_.size() > 1) { LOG(TRACE, LOG_TAG) << "outstanding async_write\n"; return; } - send_next(); + sendNext(); }); } -void StreamSession::send(msg::message_ptr message) +void StreamSession::send(const msg::message_ptr& message, WriteHandler&& handler) { if (!message) return; // TODO: better set the timestamp in send_next for more accurate time sync - send(shared_const_buffer(*message)); + send(shared_const_buffer(*message), std::move(handler)); } diff --git a/server/stream_session.hpp b/server/stream_session.hpp index 33dbc137..5a54b987 100644 --- a/server/stream_session.hpp +++ b/server/stream_session.hpp @@ -1,6 +1,6 @@ /*** This file is part of snapcast - Copyright (C) 2014-2024 Johannes Pohl + Copyright (C) 2014-2025 Johannes Pohl This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -20,6 +20,7 @@ // local headers +#include "authinfo.hpp" #include "common/message/message.hpp" #include "streamreader/stream_manager.hpp" @@ -40,29 +41,35 @@ class StreamSession; +/// Write result callback function type +using WriteHandler = std::function; /// Interface: callback for a received message. class StreamMessageReceiver { public: - virtual void onMessageReceived(StreamSession* connection, const msg::BaseMessage& baseMessage, char* buffer) = 0; + /// message received callback + virtual void onMessageReceived(const std::shared_ptr& connection, const msg::BaseMessage& baseMessage, char* buffer) = 0; + /// disonnect callback virtual void onDisconnect(StreamSession* connection) = 0; }; -// A reference-counted non-modifiable buffer class. +/// A reference-counted non-modifiable buffer class. class shared_const_buffer { + /// the message struct Message { - std::vector data; - bool is_pcm_chunk; - message_type type; - chronos::time_point_clk rec_time; + std::vector data; ///< data + bool is_pcm_chunk; ///< is it a PCM chunk + message_type type; ///< message type + chronos::time_point_clk rec_time; ///< recording time }; public: - shared_const_buffer(msg::BaseMessage& message) : on_air(false) + /// c'tor + explicit shared_const_buffer(msg::BaseMessage& message) { tv t; message.sent = t; @@ -83,32 +90,47 @@ public: // Implement the ConstBufferSequence requirements. using value_type = boost::asio::const_buffer; using const_iterator = const boost::asio::const_buffer*; + + /// begin iterator const boost::asio::const_buffer* begin() const { return &buffer_; } + /// end iterator const boost::asio::const_buffer* end() const { return &buffer_ + 1; } + /// the payload const Message& message() const { return *message_; } - bool on_air; + /// set write callback + void setWriteHandler(WriteHandler&& handler) + { + handler_ = std::move(handler); + } + + /// get write callback + const WriteHandler& getWriteHandler() const + { + return handler_; + } + + /// is the buffer sent? + bool on_air{false}; private: std::shared_ptr message_; boost::asio::const_buffer buffer_; + WriteHandler handler_; }; -/// Write result callback function type -using WriteHandler = std::function; - /// Endpoint for a connected client. /** * Endpoint for a connected client. @@ -119,7 +141,7 @@ class StreamSession : public std::enable_shared_from_this { public: /// c'tor. Received message from the client are passed to StreamMessageReceiver - StreamSession(const boost::asio::any_io_executor& executor, StreamMessageReceiver* receiver); + StreamSession(const boost::asio::any_io_executor& executor, const ServerSettings& server_settings, StreamMessageReceiver* receiver); /// d'tor virtual ~StreamSession() = default; @@ -139,33 +161,40 @@ public: protected: /// Send data @p buffer to the streaming client, result is returned in the callback @p handler - virtual void sendAsync(const shared_const_buffer& buffer, const WriteHandler& handler) = 0; + virtual void sendAsync(const shared_const_buffer& buffer, WriteHandler&& handler) = 0; public: /// Sends a message to the client (asynchronous) - void send(msg::message_ptr message); + void send(const msg::message_ptr& message, WriteHandler&& handler = nullptr); /// Sends a message to the client (asynchronous) - void send(shared_const_buffer const_buf); + void send(shared_const_buffer const_buf, WriteHandler&& handler = nullptr); /// Max playout latency. No need to send PCM data that is older than bufferMs void setBufferMs(size_t bufferMs); + /// Client id of the session std::string clientId; + /// Set the sessions PCM stream void setPcmStream(streamreader::PcmStreamPtr pcmStream); + /// Get the sessions PCM stream const streamreader::PcmStreamPtr pcmStream() const; -protected: - void send_next(); + /// Authentication info attached to this session + AuthInfo authinfo; - msg::BaseMessage baseMessage_; - std::vector buffer_; - size_t base_msg_size_; - StreamMessageReceiver* messageReceiver_; - size_t bufferMs_; - streamreader::PcmStreamPtr pcmStream_; - boost::asio::strand strand_; - std::deque messages_; - mutable std::mutex mutex_; +protected: + /// Send next message from "messages_" + void sendNext(); + + msg::BaseMessage baseMessage_; ///< base message buffer + std::vector buffer_; ///< buffer + size_t base_msg_size_; ///< size of a base message + StreamMessageReceiver* messageReceiver_; ///< message receiver + size_t bufferMs_; ///< buffer size in [ms] + streamreader::PcmStreamPtr pcm_stream_; ///< the sessions PCM stream + boost::asio::strand strand_; ///< strand to sync IO on + std::deque messages_; ///< messages to be sent + mutable std::mutex mutex_; ///< protect pcm_stream_ }; diff --git a/server/stream_session_tcp.cpp b/server/stream_session_tcp.cpp index 37751b06..892cf516 100644 --- a/server/stream_session_tcp.cpp +++ b/server/stream_session_tcp.cpp @@ -34,8 +34,8 @@ using namespace streamreader; static constexpr auto LOG_TAG = "StreamSessionTCP"; -StreamSessionTcp::StreamSessionTcp(StreamMessageReceiver* receiver, tcp::socket&& socket) - : StreamSession(socket.get_executor(), receiver), socket_(std::move(socket)) +StreamSessionTcp::StreamSessionTcp(StreamMessageReceiver* receiver, const ServerSettings& server_settings, tcp::socket&& socket) + : StreamSession(socket.get_executor(), server_settings, receiver), socket_(std::move(socket)) { } @@ -49,7 +49,7 @@ StreamSessionTcp::~StreamSessionTcp() void StreamSessionTcp::start() { - read_next(); + readNext(); } @@ -83,7 +83,7 @@ std::string StreamSessionTcp::getIP() } -void StreamSessionTcp::read_next() +void StreamSessionTcp::readNext() { boost::asio::async_read(socket_, boost::asio::buffer(buffer_, base_msg_size_), [this, self = shared_from_this()](boost::system::error_code ec, std::size_t length) mutable @@ -126,15 +126,19 @@ void StreamSessionTcp::read_next() tv now; baseMessage_.received = now; if (messageReceiver_ != nullptr) - messageReceiver_->onMessageReceived(this, baseMessage_, buffer_.data()); - read_next(); + messageReceiver_->onMessageReceived(shared_from_this(), baseMessage_, buffer_.data()); + readNext(); }); }); } -void StreamSessionTcp::sendAsync(const shared_const_buffer& buffer, const WriteHandler& handler) +void StreamSessionTcp::sendAsync(const shared_const_buffer& buffer, WriteHandler&& handler) { boost::asio::async_write(socket_, buffer, - [self = shared_from_this(), buffer, handler](boost::system::error_code ec, std::size_t length) { handler(ec, length); }); + [self = shared_from_this(), buffer, handler = std::move(handler)](boost::system::error_code ec, std::size_t length) + { + if (handler) + handler(ec, length); + }); } diff --git a/server/stream_session_tcp.hpp b/server/stream_session_tcp.hpp index 798bf301..a8bd5e32 100644 --- a/server/stream_session_tcp.hpp +++ b/server/stream_session_tcp.hpp @@ -1,6 +1,6 @@ /*** This file is part of snapcast - Copyright (C) 2014-2024 Johannes Pohl + Copyright (C) 2014-2025 Johannes Pohl This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -40,15 +40,17 @@ class StreamSessionTcp : public StreamSession { public: /// ctor. Received message from the client are passed to StreamMessageReceiver - StreamSessionTcp(StreamMessageReceiver* receiver, tcp::socket&& socket); + StreamSessionTcp(StreamMessageReceiver* receiver, const ServerSettings& server_settings, tcp::socket&& socket); ~StreamSessionTcp() override; void start() override; void stop() override; std::string getIP() override; protected: - void read_next(); - void sendAsync(const shared_const_buffer& buffer, const WriteHandler& handler) override; + /// Read next message + void readNext(); + /// Send message @p buffer and pass result to @p handler + void sendAsync(const shared_const_buffer& buffer, WriteHandler&& handler) override; private: tcp::socket socket_; diff --git a/server/stream_session_ws.cpp b/server/stream_session_ws.cpp index a723aaad..c01b3c7d 100644 --- a/server/stream_session_ws.cpp +++ b/server/stream_session_ws.cpp @@ -1,6 +1,6 @@ /*** This file is part of snapcast - Copyright (C) 2014-2023 Johannes Pohl + Copyright (C) 2014-2025 Johannes Pohl This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -32,14 +32,14 @@ using namespace std; static constexpr auto LOG_TAG = "StreamSessionWS"; -StreamSessionWebsocket::StreamSessionWebsocket(StreamMessageReceiver* receiver, ssl_websocket&& ssl_ws) - : StreamSession(ssl_ws.get_executor(), receiver), ssl_ws_(std::move(ssl_ws)), is_ssl_(true) +StreamSessionWebsocket::StreamSessionWebsocket(StreamMessageReceiver* receiver, const ServerSettings& server_settings, ssl_websocket&& ssl_ws) + : StreamSession(ssl_ws.get_executor(), server_settings, receiver), ssl_ws_(std::move(ssl_ws)), is_ssl_(true) { LOG(DEBUG, LOG_TAG) << "StreamSessionWS, mode: ssl\n"; } -StreamSessionWebsocket::StreamSessionWebsocket(StreamMessageReceiver* receiver, tcp_websocket&& tcp_ws) - : StreamSession(tcp_ws.get_executor(), receiver), tcp_ws_(std::move(tcp_ws)), is_ssl_(false) +StreamSessionWebsocket::StreamSessionWebsocket(StreamMessageReceiver* receiver, const ServerSettings& server_settings, tcp_websocket&& tcp_ws) + : StreamSession(tcp_ws.get_executor(), server_settings, receiver), tcp_ws_(std::move(tcp_ws)), is_ssl_(false) { LOG(DEBUG, LOG_TAG) << "StreamSessionWS, mode: tcp\n"; } @@ -98,13 +98,21 @@ std::string StreamSessionWebsocket::getIP() } -void StreamSessionWebsocket::sendAsync(const shared_const_buffer& buffer, const WriteHandler& handler) +void StreamSessionWebsocket::sendAsync(const shared_const_buffer& buffer, WriteHandler&& handler) { LOG(TRACE, LOG_TAG) << "sendAsync: " << buffer.message().type << "\n"; if (is_ssl_) - ssl_ws_->async_write(buffer, [self = shared_from_this(), buffer, handler](boost::system::error_code ec, std::size_t length) { handler(ec, length); }); + ssl_ws_->async_write(buffer, [self = shared_from_this(), buffer, handler = std::move(handler)](boost::system::error_code ec, std::size_t length) + { + if (handler) + handler(ec, length); + }); else - tcp_ws_->async_write(buffer, [self = shared_from_this(), buffer, handler](boost::system::error_code ec, std::size_t length) { handler(ec, length); }); + tcp_ws_->async_write(buffer, [self = shared_from_this(), buffer, handler = std::move(handler)](boost::system::error_code ec, std::size_t length) + { + if (handler) + handler(ec, length); + }); } @@ -146,7 +154,7 @@ void StreamSessionWebsocket::on_read_ws(beast::error_code ec, std::size_t bytes_ baseMessage_.received = now; if (messageReceiver_ != nullptr) - messageReceiver_->onMessageReceived(this, baseMessage_, data + base_msg_size_); + messageReceiver_->onMessageReceived(shared_from_this(), baseMessage_, data + base_msg_size_); buffer_.consume(bytes_transferred); do_read_ws(); diff --git a/server/stream_session_ws.hpp b/server/stream_session_ws.hpp index 685df47c..49faa204 100644 --- a/server/stream_session_ws.hpp +++ b/server/stream_session_ws.hpp @@ -1,6 +1,6 @@ /*** This file is part of snapcast - Copyright (C) 2014-2024 Johannes Pohl + Copyright (C) 2014-2025 Johannes Pohl This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -48,24 +48,26 @@ using ssl_websocket = websocket::stream; class StreamSessionWebsocket : public StreamSession { public: - /// ctor. Received message from the client are passed to StreamMessageReceiver - StreamSessionWebsocket(StreamMessageReceiver* receiver, ssl_websocket&& ssl_ws); - StreamSessionWebsocket(StreamMessageReceiver* receiver, tcp_websocket&& tcp_ws); + /// c'tor for SSL. Received message from the client are passed to StreamMessageReceiver + StreamSessionWebsocket(StreamMessageReceiver* receiver, const ServerSettings& server_settings, ssl_websocket&& ssl_ws); + /// c'tor for TCP + StreamSessionWebsocket(StreamMessageReceiver* receiver, const ServerSettings& server_settings, tcp_websocket&& tcp_ws); ~StreamSessionWebsocket() override; void start() override; void stop() override; std::string getIP() override; -protected: - // Websocket methods - void sendAsync(const shared_const_buffer& buffer, const WriteHandler& handler) override; +private: + /// Send message @p buffer and pass result to @p handler + void sendAsync(const shared_const_buffer& buffer, WriteHandler&& handler) override; + /// Read callback void on_read_ws(beast::error_code ec, std::size_t bytes_transferred); + /// Read loop void do_read_ws(); - std::optional ssl_ws_; - std::optional tcp_ws_; + std::optional ssl_ws_; ///< SSL websocket + std::optional tcp_ws_; ///< TCP websocket -protected: - beast::flat_buffer buffer_; - bool is_ssl_; + beast::flat_buffer buffer_; ///< read buffer + bool is_ssl_; ///< are we in SSL mode? };