Authentication for streaming clients

This commit is contained in:
badaix 2025-02-08 23:06:56 +01:00
parent 67fd20619d
commit 02b8033728
20 changed files with 302 additions and 132 deletions

View file

@ -60,16 +60,23 @@ struct ClientSettings
/// Server settings /// Server settings
struct Server struct Server
{ {
/// Auth info
struct Auth
{
/// the scheme (Basic, Plain, bearer, ...)
std::string scheme;
/// the param (base64 encoded "<user>:<password>", "<user>:<password>", token, ...)
std::string param;
};
/// server host or IP address /// server host or IP address
std::string host; std::string host;
/// protocol: "tcp", "ws" or "wss" /// protocol: "tcp", "ws" or "wss"
std::string protocol{"tcp"}; std::string protocol{"tcp"};
/// server port /// server port
size_t port{1704}; size_t port{1704};
/// username /// auth info
std::optional<std::string> username; std::optional<Auth> auth;
/// password
std::optional<std::string> password;
/// server certificate /// server certificate
std::optional<std::filesystem::path> server_certificate; std::optional<std::filesystem::path> server_certificate;
/// Certificate file /// Certificate file

View file

@ -303,7 +303,7 @@ void Controller::getNextMessage()
else if (response->type == message_type::kError) else if (response->type == message_type::kError)
{ {
auto error = msg::message_cast<msg::Error>(std::move(response)); auto error = msg::message_cast<msg::Error>(std::move(response));
LOG(ERROR, LOG_TAG) << "Received error: " << error->message << ", code: " << error->code << "\n"; LOG(ERROR, LOG_TAG) << "Received error: " << error->error << ", message: " << error->message << ", code: " << error->code << "\n";
} }
else else
{ {
@ -460,7 +460,10 @@ void Controller::worker()
settings_.host_id = ::getHostId(macAddress); settings_.host_id = ::getHostId(macAddress);
// Say hello to the server // Say hello to the server
auto hello = std::make_shared<msg::Hello>(macAddress, settings_.host_id, settings_.instance, settings_.server.username, settings_.server.password); std::optional<msg::Hello::Auth> auth;
if (settings_.server.auth.has_value())
auth = msg::Hello::Auth{settings_.server.auth->scheme, settings_.server.auth->param};
auto hello = std::make_shared<msg::Hello>(macAddress, settings_.host_id, settings_.instance, auth);
clientConnection_->sendRequest<msg::ServerSettings>( clientConnection_->sendRequest<msg::ServerSettings>(
hello, 2s, [this](const boost::system::error_code& ec, std::unique_ptr<msg::ServerSettings> response) mutable hello, 2s, [this](const boost::system::error_code& ec, std::unique_ptr<msg::ServerSettings> response) mutable
{ {

View file

@ -17,6 +17,7 @@
***/ ***/
// local headers // local headers
#include "common/base64.h"
#include "common/popl.hpp" #include "common/popl.hpp"
#include "common/utils/string_utils.hpp" #include "common/utils/string_utils.hpp"
#include "controller.hpp" #include "controller.hpp"
@ -368,10 +369,14 @@ int main(int argc, char** argv)
throw SnapException("Snapclient is built without wss support"); throw SnapException("Snapclient is built without wss support");
#endif #endif
} }
if (!uri.user.empty())
settings.server.username = uri.user; if (!uri.user.empty() || !uri.password.empty())
if (!uri.password.empty()) {
settings.server.password = uri.password; ClientSettings::Server::Auth auth;
auth.scheme = "Basic";
auth.param = base64_encode(uri.user + ":" + uri.password);
settings.server.auth = auth;
}
} }
if (server_cert_opt->is_set()) if (server_cert_opt->is_set())

View file

@ -31,37 +31,45 @@ namespace msg
class Error : public BaseMessage class Error : public BaseMessage
{ {
public: public:
/// c'tor taking the @p code and @p message of error /// c'tor taking the @p code, @p error and @p message of error
explicit Error(uint32_t code, std::string message) : BaseMessage(message_type::kError), code(code), message(std::move(message)) explicit Error(uint32_t code, std::string error, std::string message)
: BaseMessage(message_type::kError), code(code), error(std::move(error)), message(std::move(message))
{ {
} }
Error() : Error(0, "") /// c'tor
Error() : Error(0, "", "")
{ {
} }
void read(std::istream& stream) override void read(std::istream& stream) override
{ {
readVal(stream, code); readVal(stream, code);
readVal(stream, error);
readVal(stream, message); readVal(stream, message);
} }
uint32_t getSize() const override uint32_t getSize() const override
{ {
return static_cast<uint32_t>(sizeof(uint32_t) // code return static_cast<uint32_t>(sizeof(uint32_t) // code
+ sizeof(uint32_t) // error string len
+ error.size() // error string
+ sizeof(uint32_t) // message len + sizeof(uint32_t) // message len
+ message.size()); // message; + message.size()); // message;
} }
/// error code /// error code
uint32_t code; uint32_t code;
/// error message /// error string
std::string error;
/// detailed error message
std::string message; std::string message;
protected: protected:
void doserialize(std::ostream& stream) const override void doserialize(std::ostream& stream) const override
{ {
writeVal(stream, code); writeVal(stream, code);
writeVal(stream, error);
writeVal(stream, message); writeVal(stream, message);
} }
}; };

View file

@ -36,14 +36,48 @@ namespace msg
class Hello : public JsonMessage class Hello : public JsonMessage
{ {
public: public:
/// Auth info
struct Auth
{
/// c'tor
Auth() = default;
/// c'tor construct from json
explicit Auth(const json& j)
{
if (j.contains("scheme"))
scheme = j["scheme"];
if (j.contains("param"))
param = j["param"];
}
/// c'tor construct from @p scheme and @p param
Auth(std::string scheme, std::string param) : scheme(std::move(scheme)), param(std::move(param))
{
}
/// @return serialized to json
json toJson() const
{
json j;
j["scheme"] = scheme;
j["param"] = param;
return j;
}
/// the scheme (Basic, Plain, bearer, ...)
std::string scheme;
/// the param (base64 encoded "<user>:<password>", "<user>:<password>", token, ...)
std::string param;
};
/// c'tor /// c'tor
Hello() : JsonMessage(message_type::kHello) Hello() : JsonMessage(message_type::kHello)
{ {
} }
/// c'tor taking @p macAddress, @p id and @p instance /// c'tor taking @p macAddress, @p id, @p instance and @p auth info
Hello(const std::string& mac_address, const std::string& id, size_t instance, std::optional<std::string> username, std::optional<std::string> password) Hello(const std::string& mac_address, const std::string& id, size_t instance, std::optional<Auth> auth) : JsonMessage(message_type::kHello)
: JsonMessage(message_type::kHello)
{ {
msg["MAC"] = mac_address; msg["MAC"] = mac_address;
msg["HostName"] = ::getHostName(); msg["HostName"] = ::getHostName();
@ -53,10 +87,8 @@ public:
msg["Arch"] = ::getArch(); msg["Arch"] = ::getArch();
msg["Instance"] = instance; msg["Instance"] = instance;
msg["ID"] = id; msg["ID"] = id;
if (username.has_value()) if (auth.has_value())
msg["Username"] = username.value(); msg["Auth"] = auth->toJson();
if (password.has_value())
msg["Password"] = password.value();
msg["SnapStreamProtocolVersion"] = 2; msg["SnapStreamProtocolVersion"] = 2;
} }
@ -129,20 +161,12 @@ public:
return id; return id;
} }
/// @return the username /// @return the auth info
std::optional<std::string> getUsername() const std::optional<Auth> getAuth() const
{ {
if (!msg.contains("Username")) if (!msg.contains("Auth"))
return std::nullopt; return std::nullopt;
return msg["Username"]; return Auth{msg["Auth"]};
}
/// @return the password
std::optional<std::string> getPassword() const
{
if (!msg.contains("Password"))
return std::nullopt;
return msg["Password"];
} }
}; };

View file

@ -113,20 +113,22 @@ Sample JSON payload (whitespace added for readability):
```json ```json
{ {
"Arch": "x86_64", "Arch": "x86_64",
"Auth": {
"param": "YmFkYWl4OnBhc3N3ZA==",
"scheme": "Basic"
},
"ClientName": "Snapclient", "ClientName": "Snapclient",
"HostName": "my_hostname", "HostName": "my_hostname",
"ID": "00:11:22:33:44:55", "ID": "00:11:22:33:44:55",
"Instance": 1, "Instance": 1,
"MAC": "00:11:22:33:44:55", "MAC": "00:11:22:33:44:55",
"OS": "Arch Linux", "OS": "Arch Linux",
"Username": "Badaix",
"Password": "$ecret",
"SnapStreamProtocolVersion": 2, "SnapStreamProtocolVersion": 2,
"Version": "0.17.1" "Version": "0.32.0"
} }
``` ```
The fields `Username` and `Password` are optional and only used if authentication and authorization is enabled on the server. The field `Auth` is optional and only used if authentication and authorization is enabled on the server.
### Client Info ### Client Info
@ -151,5 +153,7 @@ Sample JSON payload (whitespace added for readability):
| Field | Type | Description | | Field | Type | Description |
|---------|--------|----------------------------------------------------------| |---------|--------|----------------------------------------------------------|
| code | uint32 | Error code | | code | uint32 | Error code |
| size | uint32 | Size of the following error message | | size | uint32 | Size of the following error string |
| error | char[] | string containing the error (not null terminated) | | error | char[] | string containing the error (not null terminated) |
| size | uint32 | Size of the following error message |
| error | char[] | string containing error details (not null terminated) |

View file

@ -105,6 +105,9 @@ AuthInfo::AuthInfo(ServerSettings::Authorization auth_settings) : is_authenticat
ErrorCode AuthInfo::validateUser(const std::string& username, const std::optional<std::string>& password) const ErrorCode AuthInfo::validateUser(const std::string& username, const std::optional<std::string>& password) const
{ {
if (!auth_settings_.enabled)
return {};
auto iter = std::find_if(auth_settings_.users.begin(), auth_settings_.users.end(), auto iter = std::find_if(auth_settings_.users.begin(), auth_settings_.users.end(),
[&](const ServerSettings::Authorization::User& user) { return user.name == username; }); [&](const ServerSettings::Authorization::User& user) { return user.name == username; });
if (iter == auth_settings_.users.end()) if (iter == auth_settings_.users.end())
@ -119,12 +122,14 @@ ErrorCode AuthInfo::authenticate(const std::string& scheme, const std::string& p
{ {
std::string scheme_normed = utils::string::trim_copy(utils::string::tolower_copy(scheme)); std::string scheme_normed = utils::string::trim_copy(utils::string::tolower_copy(scheme));
std::string param_normed = utils::string::trim_copy(param); std::string param_normed = utils::string::trim_copy(param);
// if (scheme_normed == "bearer")
// return authenticateBearer(param_normed);
if (scheme_normed == "basic") if (scheme_normed == "basic")
return authenticateBasic(param_normed); return authenticateBasic(param_normed);
else if (scheme_normed == "plain")
return authenticatePlain(param_normed);
// else if (scheme_normed == "bearer")
// return authenticateBearer(param_normed);
return {AuthErrc::auth_scheme_not_supported, "Scheme must be 'Basic'"}; // or 'Bearer'"}; return {AuthErrc::auth_scheme_not_supported, "Scheme must be 'Basic' or 'Plain'"}; // or 'Bearer'"};
} }
@ -150,6 +155,21 @@ ErrorCode AuthInfo::authenticateBasic(const std::string& credentials)
return ec; return ec;
} }
ErrorCode AuthInfo::authenticatePlain(const std::string& user_password)
{
is_authenticated_ = false;
std::string password;
std::string username = utils::string::split_left(user_password, ':', password);
auto ec = validateUser(username_, password);
// TODO: don't log passwords
LOG(INFO, LOG_TAG) << "Authorization basic: " << user_password << ", user: " << username_ << ", password: " << password << "\n";
is_authenticated_ = (ec.value() == 0);
return ec;
}
#if 0 #if 0
ErrorCode AuthInfo::authenticateBearer(const std::string& token) ErrorCode AuthInfo::authenticateBearer(const std::string& token)
{ {

View file

@ -79,6 +79,8 @@ public:
/// Authenticate with basic scheme /// Authenticate with basic scheme
ErrorCode authenticateBasic(const std::string& credentials); ErrorCode authenticateBasic(const std::string& credentials);
/// Authenticate with <user>:<password>
ErrorCode authenticatePlain(const std::string& user_password);
/// Authenticate with bearer scheme /// Authenticate with bearer scheme
// ErrorCode authenticateBearer(const std::string& token); // ErrorCode authenticateBearer(const std::string& token);
/// Authenticate with basic or bearer scheme with an auth header /// Authenticate with basic or bearer scheme with an auth header

View file

@ -885,6 +885,8 @@ void ServerAuthenticateRequest::execute(const jsonrpcpp::request_ptr& request, A
// Response: {"id":8,"jsonrpc":"2.0","result":"ok"} // Response: {"id":8,"jsonrpc":"2.0","result":"ok"}
// Request: {"id":8,"jsonrpc":"2.0","method":"Server.Authenticate","params":{"scheme":"Bearer","param":"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3MTg1NjQ1MTYsImlhdCI6MTcxODUyODUxNiwic3ViIjoiQmFkYWl4In0.gHrMVp7jTAg8aCSg3cttcfIxswqmOPuqVNOb5p79Cn0NmAqRmLXtDLX4QjOoOqqb66ezBBeikpNjPi_aO18YPoNmX9fPxSwcObTHBupnm5eugEpneMPDASFUSE2hg8rrD_OEoAVxx6hCLln7Z3ILyWDmR6jcmy7z0bp0BiAqOywUrFoVIsnlDZRs3wOaap5oS9J2oaA_gNi_7OuvAhrydn26LDhm0KiIqEcyIholkpRHrDYODkz98h2PkZdZ2U429tTvVhzDBJ1cBq2Zq3cvuMZT6qhwaUc8eYA8fUJ7g65iP4o2OZtUzlfEUqX1TKyuWuSK6CUlsZooNE-MSCT7_w"}} // Request: {"id":8,"jsonrpc":"2.0","method":"Server.Authenticate","params":{"scheme":"Bearer","param":"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3MTg1NjQ1MTYsImlhdCI6MTcxODUyODUxNiwic3ViIjoiQmFkYWl4In0.gHrMVp7jTAg8aCSg3cttcfIxswqmOPuqVNOb5p79Cn0NmAqRmLXtDLX4QjOoOqqb66ezBBeikpNjPi_aO18YPoNmX9fPxSwcObTHBupnm5eugEpneMPDASFUSE2hg8rrD_OEoAVxx6hCLln7Z3ILyWDmR6jcmy7z0bp0BiAqOywUrFoVIsnlDZRs3wOaap5oS9J2oaA_gNi_7OuvAhrydn26LDhm0KiIqEcyIholkpRHrDYODkz98h2PkZdZ2U429tTvVhzDBJ1cBq2Zq3cvuMZT6qhwaUc8eYA8fUJ7g65iP4o2OZtUzlfEUqX1TKyuWuSK6CUlsZooNE-MSCT7_w"}}
// Response: {"id":8,"jsonrpc":"2.0","result":"ok"} // Response: {"id":8,"jsonrpc":"2.0","result":"ok"}
// Request: {"id":8,"jsonrpc":"2.0","method":"Server.Authenticate","params":{"scheme":"Plain","param":"<user>:<password>"}}
// Response: {"id":8,"jsonrpc":"2.0","result":"ok"}
// clang-format on // clang-format on
checkParams(request, {"scheme", "param"}); checkParams(request, {"scheme", "param"});

View file

@ -409,7 +409,7 @@ void ControlSessionHttp::on_read(beast::error_code ec, std::size_t bytes_transfe
} }
else // if (req_.target() == "/stream") else // if (req_.target() == "/stream")
{ {
auto ws_session = make_shared<StreamSessionWebsocket>(nullptr, std::move(*ws)); auto ws_session = make_shared<StreamSessionWebsocket>(nullptr, settings_, std::move(*ws));
message_receiver_->onNewSession(std::move(ws_session)); message_receiver_->onNewSession(std::move(ws_session));
} }
} }
@ -435,7 +435,7 @@ void ControlSessionHttp::on_read(beast::error_code ec, std::size_t bytes_transfe
} }
else // if (req_.target() == "/stream") else // if (req_.target() == "/stream")
{ {
auto ws_session = make_shared<StreamSessionWebsocket>(nullptr, std::move(*ws)); auto ws_session = make_shared<StreamSessionWebsocket>(nullptr, settings_, std::move(*ws));
message_receiver_->onNewSession(std::move(ws_session)); message_receiver_->onNewSession(std::move(ws_session));
} }
} }

View file

@ -22,6 +22,7 @@
// local headers // local headers
#include "common/aixlog.hpp" #include "common/aixlog.hpp"
#include "common/message/client_info.hpp" #include "common/message/client_info.hpp"
#include "common/message/error.hpp"
#include "common/message/hello.hpp" #include "common/message/hello.hpp"
#include "common/message/server_settings.hpp" #include "common/message/server_settings.hpp"
#include "common/message/time.hpp" #include "common/message/time.hpp"
@ -261,7 +262,7 @@ void Server::onMessageReceived(std::shared_ptr<ControlSession> controlSession, c
void Server::onMessageReceived(StreamSession* streamSession, const msg::BaseMessage& baseMessage, char* buffer) void Server::onMessageReceived(const std::shared_ptr<StreamSession>& streamSession, const msg::BaseMessage& baseMessage, char* buffer)
{ {
LOG(DEBUG, LOG_TAG) << "onMessageReceived: " << baseMessage.type << ", size: " << baseMessage.size << ", id: " << baseMessage.id LOG(DEBUG, LOG_TAG) << "onMessageReceived: " << baseMessage.type << ", size: " << baseMessage.size << ", id: " << baseMessage.id
<< ", refers: " << baseMessage.refersTo << ", sent: " << baseMessage.sent.sec << "," << baseMessage.sent.usec << ", refers: " << baseMessage.refersTo << ", sent: " << baseMessage.sent.sec << "," << baseMessage.sent.usec
@ -308,12 +309,48 @@ void Server::onMessageReceived(StreamSession* streamSession, const msg::BaseMess
msg::Hello helloMsg; msg::Hello helloMsg;
helloMsg.deserialize(baseMessage, buffer); helloMsg.deserialize(baseMessage, buffer);
streamSession->clientId = helloMsg.getUniqueId(); streamSession->clientId = helloMsg.getUniqueId();
auto auth = helloMsg.getAuth();
// TODO: don't log passwords
LOG(INFO, LOG_TAG) << "Hello from " << streamSession->clientId << ", host: " << helloMsg.getHostName() << ", v" << helloMsg.getVersion() LOG(INFO, LOG_TAG) << "Hello from " << streamSession->clientId << ", host: " << helloMsg.getHostName() << ", v" << helloMsg.getVersion()
<< ", ClientName: " << helloMsg.getClientName() << ", OS: " << helloMsg.getOS() << ", Arch: " << helloMsg.getArch() << ", ClientName: " << helloMsg.getClientName() << ", OS: " << helloMsg.getOS() << ", Arch: " << helloMsg.getArch()
<< ", Protocol version: " << helloMsg.getProtocolVersion() << ", Userrname: " << helloMsg.getUsername().value_or("<not set>") << ", Protocol version: " << helloMsg.getProtocolVersion() << ", Auth: " << auth.value_or(msg::Hello::Auth{}).toJson().dump()
<< ", Password: " << (helloMsg.getPassword().has_value() ? "<password is set>" : "<not set>") << "\n"; << "\n";
if (settings_.auth.enabled)
{
ErrorCode ec;
if (auth.has_value())
ec = streamSession->authinfo.authenticate(auth->scheme, auth->param);
if (!auth.has_value() || ec)
{
if (ec)
LOG(ERROR, LOG_TAG) << "Authentication failed: " << ec.detailed_message() << "\n";
else
LOG(ERROR, LOG_TAG) << "Authentication required\n";
auto error_msg = make_shared<msg::Error>(401, "Unauthorized", ec ? ec.detailed_message() : "Authentication required");
streamSession->send(error_msg, [streamSession](boost::system::error_code ec, std::size_t length)
{
LOG(DEBUG, LOG_TAG) << "Sent result: " << ec << ", len: " << length << "\n";
streamSession->stop(); streamSession->stop();
});
return; return;
}
if (!streamSession->authinfo.hasPermission("Streaming"))
{
std::string error = "Permission 'Streaming' missing";
LOG(ERROR, LOG_TAG) << error << "\n";
auto error_msg = make_shared<msg::Error>(403, "Forbidden", error);
streamSession->send(error_msg, [streamSession](boost::system::error_code ec, std::size_t length)
{
LOG(DEBUG, LOG_TAG) << "Sent result: " << ec << ", len: " << length << "\n";
streamSession->stop();
});
return;
}
}
bool newGroup(false); bool newGroup(false);
GroupPtr group = Config::instance().getGroupFromClient(streamSession->clientId); GroupPtr group = Config::instance().getGroupFromClient(streamSession->clientId);
if (group == nullptr) if (group == nullptr)

View file

@ -1,6 +1,6 @@
/*** /***
This file is part of snapcast This file is part of snapcast
Copyright (C) 2014-2024 Johannes Pohl Copyright (C) 2014-2025 Johannes Pohl
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by it under the terms of the GNU General Public License as published by
@ -65,7 +65,7 @@ public:
private: private:
/// Implementation of StreamMessageReceiver /// Implementation of StreamMessageReceiver
void onMessageReceived(StreamSession* streamSession, const msg::BaseMessage& baseMessage, char* buffer) override; void onMessageReceived(const std::shared_ptr<StreamSession>& streamSession, const msg::BaseMessage& baseMessage, char* buffer) override;
void onDisconnect(StreamSession* streamSession) override; void onDisconnect(StreamSession* streamSession) override;
/// Implementation of ControllMessageReceiver /// Implementation of ControllMessageReceiver

View file

@ -47,7 +47,7 @@ StreamServer::~StreamServer() = default;
void StreamServer::cleanup() void StreamServer::cleanup()
{ {
auto new_end = std::remove_if(sessions_.begin(), sessions_.end(), [](std::weak_ptr<StreamSession> session) { return session.expired(); }); auto new_end = std::remove_if(sessions_.begin(), sessions_.end(), [](const std::weak_ptr<StreamSession>& session) { return session.expired(); });
auto count = distance(new_end, sessions_.end()); auto count = distance(new_end, sessions_.end());
if (count > 0) if (count > 0)
{ {
@ -69,7 +69,7 @@ void StreamServer::addSession(std::shared_ptr<StreamSession> session)
} }
void StreamServer::onChunkEncoded(const PcmStream* pcmStream, bool isDefaultStream, std::shared_ptr<msg::PcmChunk> chunk, double /*duration*/) void StreamServer::onChunkEncoded(const PcmStream* pcmStream, bool isDefaultStream, const std::shared_ptr<msg::PcmChunk>& chunk, double /*duration*/)
{ {
// LOG(TRACE, LOG_TAG) << "onChunkRead (" << pcmStream->getName() << "): " << duration << "ms\n"; // LOG(TRACE, LOG_TAG) << "onChunkRead (" << pcmStream->getName() << "): " << duration << "ms\n";
shared_const_buffer buffer(*chunk); shared_const_buffer buffer(*chunk);
@ -112,7 +112,7 @@ void StreamServer::onChunkEncoded(const PcmStream* pcmStream, bool isDefaultStre
} }
void StreamServer::onMessageReceived(StreamSession* streamSession, const msg::BaseMessage& baseMessage, char* buffer) void StreamServer::onMessageReceived(const std::shared_ptr<StreamSession>& streamSession, const msg::BaseMessage& baseMessage, char* buffer)
{ {
try try
{ {
@ -122,7 +122,7 @@ void StreamServer::onMessageReceived(StreamSession* streamSession, const msg::Ba
catch (const std::exception& e) catch (const std::exception& e)
{ {
LOG(ERROR, LOG_TAG) << "Server::onMessageReceived exception: " << e.what() << ", message type: " << baseMessage.type << "\n"; LOG(ERROR, LOG_TAG) << "Server::onMessageReceived exception: " << e.what() << ", message type: " << baseMessage.type << "\n";
auto session = getStreamSession(streamSession); auto session = getStreamSession(streamSession.get());
session->stop(); session->stop();
} }
} }
@ -139,7 +139,7 @@ void StreamServer::onDisconnect(StreamSession* streamSession)
LOG(INFO, LOG_TAG) << "onDisconnect: " << session->clientId << "\n"; LOG(INFO, LOG_TAG) << "onDisconnect: " << session->clientId << "\n";
LOG(DEBUG, LOG_TAG) << "sessions: " << sessions_.size() << "\n"; LOG(DEBUG, LOG_TAG) << "sessions: " << sessions_.size() << "\n";
sessions_.erase(std::remove_if(sessions_.begin(), sessions_.end(), sessions_.erase(std::remove_if(sessions_.begin(), sessions_.end(),
[streamSession](std::weak_ptr<StreamSession> session) [streamSession](const std::weak_ptr<StreamSession>& session)
{ {
auto s = session.lock(); auto s = session.lock();
return s.get() == streamSession; return s.get() == streamSession;
@ -209,7 +209,7 @@ void StreamServer::handleAccept(tcp::socket socket)
socket.set_option(tcp::no_delay(true)); socket.set_option(tcp::no_delay(true));
LOG(NOTICE, LOG_TAG) << "StreamServer::NewConnection: " << socket.remote_endpoint().address().to_string() << "\n"; LOG(NOTICE, LOG_TAG) << "StreamServer::NewConnection: " << socket.remote_endpoint().address().to_string() << "\n";
shared_ptr<StreamSession> session = make_shared<StreamSessionTcp>(this, std::move(socket)); shared_ptr<StreamSession> session = make_shared<StreamSessionTcp>(this, settings_, std::move(socket));
addSession(session); addSession(session);
} }
catch (const std::exception& e) catch (const std::exception& e)

View file

@ -1,6 +1,6 @@
/*** /***
This file is part of snapcast This file is part of snapcast
Copyright (C) 2014-2024 Johannes Pohl Copyright (C) 2014-2025 Johannes Pohl
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by it under the terms of the GNU General Public License as published by
@ -53,19 +53,27 @@ using session_ptr = std::shared_ptr<StreamSession>;
class StreamServer : public StreamMessageReceiver class StreamServer : public StreamMessageReceiver
{ {
public: public:
/// c'tor
StreamServer(boost::asio::io_context& io_context, const ServerSettings& serverSettings, StreamMessageReceiver* messageReceiver = nullptr); StreamServer(boost::asio::io_context& io_context, const ServerSettings& serverSettings, StreamMessageReceiver* messageReceiver = nullptr);
/// d'tor
virtual ~StreamServer(); virtual ~StreamServer();
/// Start accepting connections
void start(); void start();
/// Stop accepting connections and active sessions
void stop(); void stop();
/// Send a message to all connceted clients /// Send a message to all connceted clients
// void send(const msg::BaseMessage* message); // void send(const msg::BaseMessage* message);
/// Add a new stream session
void addSession(std::shared_ptr<StreamSession> session); void addSession(std::shared_ptr<StreamSession> session);
void onChunkEncoded(const PcmStream* pcmStream, bool isDefaultStream, std::shared_ptr<msg::PcmChunk> chunk, double duration); /// Callback for chunks that are ready to be sent
void onChunkEncoded(const PcmStream* pcmStream, bool isDefaultStream, const std::shared_ptr<msg::PcmChunk>& chunk, double duration);
/// @return stream session for @p clientId
session_ptr getStreamSession(const std::string& clientId) const; session_ptr getStreamSession(const std::string& clientId) const;
/// @return stream session for @p session
session_ptr getStreamSession(StreamSession* session) const; session_ptr getStreamSession(StreamSession* session) const;
private: private:
@ -74,7 +82,7 @@ private:
void cleanup(); void cleanup();
/// Implementation of StreamMessageReceiver /// Implementation of StreamMessageReceiver
void onMessageReceived(StreamSession* streamSession, const msg::BaseMessage& baseMessage, char* buffer) override; void onMessageReceived(const std::shared_ptr<StreamSession>& streamSession, const msg::BaseMessage& baseMessage, char* buffer) override;
void onDisconnect(StreamSession* streamSession) override; void onDisconnect(StreamSession* streamSession) override;
mutable std::recursive_mutex sessionsMutex_; mutable std::recursive_mutex sessionsMutex_;

View file

@ -34,8 +34,8 @@ using namespace streamreader;
static constexpr auto LOG_TAG = "StreamSession"; static constexpr auto LOG_TAG = "StreamSession";
StreamSession::StreamSession(const boost::asio::any_io_executor& executor, StreamMessageReceiver* receiver) StreamSession::StreamSession(const boost::asio::any_io_executor& executor, const ServerSettings& server_settings, StreamMessageReceiver* receiver)
: messageReceiver_(receiver), pcmStream_(nullptr), strand_(boost::asio::make_strand(executor)) : authinfo(server_settings.auth), messageReceiver_(receiver), pcm_stream_(nullptr), strand_(boost::asio::make_strand(executor))
{ {
base_msg_size_ = baseMessage_.getSize(); base_msg_size_ = baseMessage_.getSize();
buffer_.resize(base_msg_size_); buffer_.resize(base_msg_size_);
@ -45,25 +45,29 @@ StreamSession::StreamSession(const boost::asio::any_io_executor& executor, Strea
void StreamSession::setPcmStream(PcmStreamPtr pcmStream) void StreamSession::setPcmStream(PcmStreamPtr pcmStream)
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
pcmStream_ = std::move(pcmStream); pcm_stream_ = std::move(pcmStream);
} }
const PcmStreamPtr StreamSession::pcmStream() const const PcmStreamPtr StreamSession::pcmStream() const
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
return pcmStream_; return pcm_stream_;
} }
void StreamSession::send_next() void StreamSession::sendNext()
{ {
auto& buffer = messages_.front(); auto& buffer = messages_.front();
buffer.on_air = true; buffer.on_air = true;
boost::asio::post(strand_, [this, self = shared_from_this(), buffer]() boost::asio::post(strand_, [this, self = shared_from_this(), buffer]()
{ {
sendAsync(buffer, [this](boost::system::error_code ec, std::size_t length) sendAsync(buffer, [this, buffer](boost::system::error_code ec, std::size_t length)
{ {
auto write_handler = buffer.getWriteHandler();
if (write_handler)
write_handler(ec, length);
messages_.pop_front(); messages_.pop_front();
if (ec) if (ec)
{ {
@ -72,15 +76,15 @@ void StreamSession::send_next()
return; return;
} }
if (!messages_.empty()) if (!messages_.empty())
send_next(); sendNext();
}); });
}); });
} }
void StreamSession::send(shared_const_buffer const_buf) void StreamSession::send(shared_const_buffer const_buf, WriteHandler&& handler)
{ {
boost::asio::post(strand_, [this, self = shared_from_this(), const_buf]() boost::asio::post(strand_, [this, self = shared_from_this(), const_buf = std::move(const_buf), handler = std::move(handler)]() mutable
{ {
// delete PCM chunks that are older than the overall buffer duration // delete PCM chunks that are older than the overall buffer duration
messages_.erase(std::remove_if(messages_.begin(), messages_.end(), messages_.erase(std::remove_if(messages_.begin(), messages_.end(),
@ -94,25 +98,26 @@ void StreamSession::send(shared_const_buffer const_buf)
}), }),
messages_.end()); messages_.end());
messages_.push_back(const_buf); const_buf.setWriteHandler(std::move(handler));
messages_.push_back(std::move(const_buf));
if (messages_.size() > 1) if (messages_.size() > 1)
{ {
LOG(TRACE, LOG_TAG) << "outstanding async_write\n"; LOG(TRACE, LOG_TAG) << "outstanding async_write\n";
return; return;
} }
send_next(); sendNext();
}); });
} }
void StreamSession::send(msg::message_ptr message) void StreamSession::send(const msg::message_ptr& message, WriteHandler&& handler)
{ {
if (!message) if (!message)
return; return;
// TODO: better set the timestamp in send_next for more accurate time sync // TODO: better set the timestamp in send_next for more accurate time sync
send(shared_const_buffer(*message)); send(shared_const_buffer(*message), std::move(handler));
} }

View file

@ -1,6 +1,6 @@
/*** /***
This file is part of snapcast This file is part of snapcast
Copyright (C) 2014-2024 Johannes Pohl Copyright (C) 2014-2025 Johannes Pohl
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by it under the terms of the GNU General Public License as published by
@ -20,6 +20,7 @@
// local headers // local headers
#include "authinfo.hpp"
#include "common/message/message.hpp" #include "common/message/message.hpp"
#include "streamreader/stream_manager.hpp" #include "streamreader/stream_manager.hpp"
@ -40,29 +41,35 @@
class StreamSession; class StreamSession;
/// Write result callback function type
using WriteHandler = std::function<void(boost::system::error_code ec, std::size_t length)>;
/// Interface: callback for a received message. /// Interface: callback for a received message.
class StreamMessageReceiver class StreamMessageReceiver
{ {
public: public:
virtual void onMessageReceived(StreamSession* connection, const msg::BaseMessage& baseMessage, char* buffer) = 0; /// message received callback
virtual void onMessageReceived(const std::shared_ptr<StreamSession>& connection, const msg::BaseMessage& baseMessage, char* buffer) = 0;
/// disonnect callback
virtual void onDisconnect(StreamSession* connection) = 0; virtual void onDisconnect(StreamSession* connection) = 0;
}; };
// A reference-counted non-modifiable buffer class. /// A reference-counted non-modifiable buffer class.
class shared_const_buffer class shared_const_buffer
{ {
/// the message
struct Message struct Message
{ {
std::vector<char> data; std::vector<char> data; ///< data
bool is_pcm_chunk; bool is_pcm_chunk; ///< is it a PCM chunk
message_type type; message_type type; ///< message type
chronos::time_point_clk rec_time; chronos::time_point_clk rec_time; ///< recording time
}; };
public: public:
shared_const_buffer(msg::BaseMessage& message) : on_air(false) /// c'tor
explicit shared_const_buffer(msg::BaseMessage& message)
{ {
tv t; tv t;
message.sent = t; message.sent = t;
@ -83,32 +90,47 @@ public:
// Implement the ConstBufferSequence requirements. // Implement the ConstBufferSequence requirements.
using value_type = boost::asio::const_buffer; using value_type = boost::asio::const_buffer;
using const_iterator = const boost::asio::const_buffer*; using const_iterator = const boost::asio::const_buffer*;
/// begin iterator
const boost::asio::const_buffer* begin() const const boost::asio::const_buffer* begin() const
{ {
return &buffer_; return &buffer_;
} }
/// end iterator
const boost::asio::const_buffer* end() const const boost::asio::const_buffer* end() const
{ {
return &buffer_ + 1; return &buffer_ + 1;
} }
/// the payload
const Message& message() const const Message& message() const
{ {
return *message_; return *message_;
} }
bool on_air; /// set write callback
void setWriteHandler(WriteHandler&& handler)
{
handler_ = std::move(handler);
}
/// get write callback
const WriteHandler& getWriteHandler() const
{
return handler_;
}
/// is the buffer sent?
bool on_air{false};
private: private:
std::shared_ptr<Message> message_; std::shared_ptr<Message> message_;
boost::asio::const_buffer buffer_; boost::asio::const_buffer buffer_;
WriteHandler handler_;
}; };
/// Write result callback function type
using WriteHandler = std::function<void(boost::system::error_code ec, std::size_t length)>;
/// Endpoint for a connected client. /// Endpoint for a connected client.
/** /**
* Endpoint for a connected client. * Endpoint for a connected client.
@ -119,7 +141,7 @@ class StreamSession : public std::enable_shared_from_this<StreamSession>
{ {
public: public:
/// c'tor. Received message from the client are passed to StreamMessageReceiver /// c'tor. Received message from the client are passed to StreamMessageReceiver
StreamSession(const boost::asio::any_io_executor& executor, StreamMessageReceiver* receiver); StreamSession(const boost::asio::any_io_executor& executor, const ServerSettings& server_settings, StreamMessageReceiver* receiver);
/// d'tor /// d'tor
virtual ~StreamSession() = default; virtual ~StreamSession() = default;
@ -139,33 +161,40 @@ public:
protected: protected:
/// Send data @p buffer to the streaming client, result is returned in the callback @p handler /// Send data @p buffer to the streaming client, result is returned in the callback @p handler
virtual void sendAsync(const shared_const_buffer& buffer, const WriteHandler& handler) = 0; virtual void sendAsync(const shared_const_buffer& buffer, WriteHandler&& handler) = 0;
public: public:
/// Sends a message to the client (asynchronous) /// Sends a message to the client (asynchronous)
void send(msg::message_ptr message); void send(const msg::message_ptr& message, WriteHandler&& handler = nullptr);
/// Sends a message to the client (asynchronous) /// Sends a message to the client (asynchronous)
void send(shared_const_buffer const_buf); void send(shared_const_buffer const_buf, WriteHandler&& handler = nullptr);
/// Max playout latency. No need to send PCM data that is older than bufferMs /// Max playout latency. No need to send PCM data that is older than bufferMs
void setBufferMs(size_t bufferMs); void setBufferMs(size_t bufferMs);
/// Client id of the session
std::string clientId; std::string clientId;
/// Set the sessions PCM stream
void setPcmStream(streamreader::PcmStreamPtr pcmStream); void setPcmStream(streamreader::PcmStreamPtr pcmStream);
/// Get the sessions PCM stream
const streamreader::PcmStreamPtr pcmStream() const; const streamreader::PcmStreamPtr pcmStream() const;
protected: /// Authentication info attached to this session
void send_next(); AuthInfo authinfo;
msg::BaseMessage baseMessage_; protected:
std::vector<char> buffer_; /// Send next message from "messages_"
size_t base_msg_size_; void sendNext();
StreamMessageReceiver* messageReceiver_;
size_t bufferMs_; msg::BaseMessage baseMessage_; ///< base message buffer
streamreader::PcmStreamPtr pcmStream_; std::vector<char> buffer_; ///< buffer
boost::asio::strand<boost::asio::any_io_executor> strand_; size_t base_msg_size_; ///< size of a base message
std::deque<shared_const_buffer> messages_; StreamMessageReceiver* messageReceiver_; ///< message receiver
mutable std::mutex mutex_; size_t bufferMs_; ///< buffer size in [ms]
streamreader::PcmStreamPtr pcm_stream_; ///< the sessions PCM stream
boost::asio::strand<boost::asio::any_io_executor> strand_; ///< strand to sync IO on
std::deque<shared_const_buffer> messages_; ///< messages to be sent
mutable std::mutex mutex_; ///< protect pcm_stream_
}; };

View file

@ -34,8 +34,8 @@ using namespace streamreader;
static constexpr auto LOG_TAG = "StreamSessionTCP"; static constexpr auto LOG_TAG = "StreamSessionTCP";
StreamSessionTcp::StreamSessionTcp(StreamMessageReceiver* receiver, tcp::socket&& socket) StreamSessionTcp::StreamSessionTcp(StreamMessageReceiver* receiver, const ServerSettings& server_settings, tcp::socket&& socket)
: StreamSession(socket.get_executor(), receiver), socket_(std::move(socket)) : StreamSession(socket.get_executor(), server_settings, receiver), socket_(std::move(socket))
{ {
} }
@ -49,7 +49,7 @@ StreamSessionTcp::~StreamSessionTcp()
void StreamSessionTcp::start() void StreamSessionTcp::start()
{ {
read_next(); readNext();
} }
@ -83,7 +83,7 @@ std::string StreamSessionTcp::getIP()
} }
void StreamSessionTcp::read_next() void StreamSessionTcp::readNext()
{ {
boost::asio::async_read(socket_, boost::asio::buffer(buffer_, base_msg_size_), boost::asio::async_read(socket_, boost::asio::buffer(buffer_, base_msg_size_),
[this, self = shared_from_this()](boost::system::error_code ec, std::size_t length) mutable [this, self = shared_from_this()](boost::system::error_code ec, std::size_t length) mutable
@ -126,15 +126,19 @@ void StreamSessionTcp::read_next()
tv now; tv now;
baseMessage_.received = now; baseMessage_.received = now;
if (messageReceiver_ != nullptr) if (messageReceiver_ != nullptr)
messageReceiver_->onMessageReceived(this, baseMessage_, buffer_.data()); messageReceiver_->onMessageReceived(shared_from_this(), baseMessage_, buffer_.data());
read_next(); readNext();
}); });
}); });
} }
void StreamSessionTcp::sendAsync(const shared_const_buffer& buffer, const WriteHandler& handler) void StreamSessionTcp::sendAsync(const shared_const_buffer& buffer, WriteHandler&& handler)
{ {
boost::asio::async_write(socket_, buffer, boost::asio::async_write(socket_, buffer,
[self = shared_from_this(), buffer, handler](boost::system::error_code ec, std::size_t length) { handler(ec, length); }); [self = shared_from_this(), buffer, handler = std::move(handler)](boost::system::error_code ec, std::size_t length)
{
if (handler)
handler(ec, length);
});
} }

View file

@ -1,6 +1,6 @@
/*** /***
This file is part of snapcast This file is part of snapcast
Copyright (C) 2014-2024 Johannes Pohl Copyright (C) 2014-2025 Johannes Pohl
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by it under the terms of the GNU General Public License as published by
@ -40,15 +40,17 @@ class StreamSessionTcp : public StreamSession
{ {
public: public:
/// ctor. Received message from the client are passed to StreamMessageReceiver /// ctor. Received message from the client are passed to StreamMessageReceiver
StreamSessionTcp(StreamMessageReceiver* receiver, tcp::socket&& socket); StreamSessionTcp(StreamMessageReceiver* receiver, const ServerSettings& server_settings, tcp::socket&& socket);
~StreamSessionTcp() override; ~StreamSessionTcp() override;
void start() override; void start() override;
void stop() override; void stop() override;
std::string getIP() override; std::string getIP() override;
protected: protected:
void read_next(); /// Read next message
void sendAsync(const shared_const_buffer& buffer, const WriteHandler& handler) override; void readNext();
/// Send message @p buffer and pass result to @p handler
void sendAsync(const shared_const_buffer& buffer, WriteHandler&& handler) override;
private: private:
tcp::socket socket_; tcp::socket socket_;

View file

@ -1,6 +1,6 @@
/*** /***
This file is part of snapcast This file is part of snapcast
Copyright (C) 2014-2023 Johannes Pohl Copyright (C) 2014-2025 Johannes Pohl
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by it under the terms of the GNU General Public License as published by
@ -32,14 +32,14 @@ using namespace std;
static constexpr auto LOG_TAG = "StreamSessionWS"; static constexpr auto LOG_TAG = "StreamSessionWS";
StreamSessionWebsocket::StreamSessionWebsocket(StreamMessageReceiver* receiver, ssl_websocket&& ssl_ws) StreamSessionWebsocket::StreamSessionWebsocket(StreamMessageReceiver* receiver, const ServerSettings& server_settings, ssl_websocket&& ssl_ws)
: StreamSession(ssl_ws.get_executor(), receiver), ssl_ws_(std::move(ssl_ws)), is_ssl_(true) : StreamSession(ssl_ws.get_executor(), server_settings, receiver), ssl_ws_(std::move(ssl_ws)), is_ssl_(true)
{ {
LOG(DEBUG, LOG_TAG) << "StreamSessionWS, mode: ssl\n"; LOG(DEBUG, LOG_TAG) << "StreamSessionWS, mode: ssl\n";
} }
StreamSessionWebsocket::StreamSessionWebsocket(StreamMessageReceiver* receiver, tcp_websocket&& tcp_ws) StreamSessionWebsocket::StreamSessionWebsocket(StreamMessageReceiver* receiver, const ServerSettings& server_settings, tcp_websocket&& tcp_ws)
: StreamSession(tcp_ws.get_executor(), receiver), tcp_ws_(std::move(tcp_ws)), is_ssl_(false) : StreamSession(tcp_ws.get_executor(), server_settings, receiver), tcp_ws_(std::move(tcp_ws)), is_ssl_(false)
{ {
LOG(DEBUG, LOG_TAG) << "StreamSessionWS, mode: tcp\n"; LOG(DEBUG, LOG_TAG) << "StreamSessionWS, mode: tcp\n";
} }
@ -98,13 +98,21 @@ std::string StreamSessionWebsocket::getIP()
} }
void StreamSessionWebsocket::sendAsync(const shared_const_buffer& buffer, const WriteHandler& handler) void StreamSessionWebsocket::sendAsync(const shared_const_buffer& buffer, WriteHandler&& handler)
{ {
LOG(TRACE, LOG_TAG) << "sendAsync: " << buffer.message().type << "\n"; LOG(TRACE, LOG_TAG) << "sendAsync: " << buffer.message().type << "\n";
if (is_ssl_) if (is_ssl_)
ssl_ws_->async_write(buffer, [self = shared_from_this(), buffer, handler](boost::system::error_code ec, std::size_t length) { handler(ec, length); }); ssl_ws_->async_write(buffer, [self = shared_from_this(), buffer, handler = std::move(handler)](boost::system::error_code ec, std::size_t length)
{
if (handler)
handler(ec, length);
});
else else
tcp_ws_->async_write(buffer, [self = shared_from_this(), buffer, handler](boost::system::error_code ec, std::size_t length) { handler(ec, length); }); tcp_ws_->async_write(buffer, [self = shared_from_this(), buffer, handler = std::move(handler)](boost::system::error_code ec, std::size_t length)
{
if (handler)
handler(ec, length);
});
} }
@ -146,7 +154,7 @@ void StreamSessionWebsocket::on_read_ws(beast::error_code ec, std::size_t bytes_
baseMessage_.received = now; baseMessage_.received = now;
if (messageReceiver_ != nullptr) if (messageReceiver_ != nullptr)
messageReceiver_->onMessageReceived(this, baseMessage_, data + base_msg_size_); messageReceiver_->onMessageReceived(shared_from_this(), baseMessage_, data + base_msg_size_);
buffer_.consume(bytes_transferred); buffer_.consume(bytes_transferred);
do_read_ws(); do_read_ws();

View file

@ -1,6 +1,6 @@
/*** /***
This file is part of snapcast This file is part of snapcast
Copyright (C) 2014-2024 Johannes Pohl Copyright (C) 2014-2025 Johannes Pohl
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by it under the terms of the GNU General Public License as published by
@ -48,24 +48,26 @@ using ssl_websocket = websocket::stream<ssl_socket>;
class StreamSessionWebsocket : public StreamSession class StreamSessionWebsocket : public StreamSession
{ {
public: public:
/// ctor. Received message from the client are passed to StreamMessageReceiver /// c'tor for SSL. Received message from the client are passed to StreamMessageReceiver
StreamSessionWebsocket(StreamMessageReceiver* receiver, ssl_websocket&& ssl_ws); StreamSessionWebsocket(StreamMessageReceiver* receiver, const ServerSettings& server_settings, ssl_websocket&& ssl_ws);
StreamSessionWebsocket(StreamMessageReceiver* receiver, tcp_websocket&& tcp_ws); /// c'tor for TCP
StreamSessionWebsocket(StreamMessageReceiver* receiver, const ServerSettings& server_settings, tcp_websocket&& tcp_ws);
~StreamSessionWebsocket() override; ~StreamSessionWebsocket() override;
void start() override; void start() override;
void stop() override; void stop() override;
std::string getIP() override; std::string getIP() override;
protected: private:
// Websocket methods /// Send message @p buffer and pass result to @p handler
void sendAsync(const shared_const_buffer& buffer, const WriteHandler& handler) override; void sendAsync(const shared_const_buffer& buffer, WriteHandler&& handler) override;
/// Read callback
void on_read_ws(beast::error_code ec, std::size_t bytes_transferred); void on_read_ws(beast::error_code ec, std::size_t bytes_transferred);
/// Read loop
void do_read_ws(); void do_read_ws();
std::optional<ssl_websocket> ssl_ws_; std::optional<ssl_websocket> ssl_ws_; ///< SSL websocket
std::optional<tcp_websocket> tcp_ws_; std::optional<tcp_websocket> tcp_ws_; ///< TCP websocket
protected: beast::flat_buffer buffer_; ///< read buffer
beast::flat_buffer buffer_; bool is_ssl_; ///< are we in SSL mode?
bool is_ssl_;
}; };