From c1120589989c8e57dcff6a540432c564aa408517 Mon Sep 17 00:00:00 2001 From: badaix Date: Mon, 1 Jul 2024 21:57:44 +0200 Subject: [PATCH] Pass complete Settings struct around --- server/authinfo.cpp | 228 ++++++++++++++++++++++---- server/authinfo.hpp | 70 +++++++- server/control_server.cpp | 45 +++-- server/control_server.hpp | 9 +- server/control_session.hpp | 17 +- server/control_session_http.cpp | 24 +-- server/control_session_http.hpp | 17 +- server/control_session_tcp.cpp | 5 +- server/control_session_tcp.hpp | 4 +- server/control_session_ws.cpp | 8 +- server/control_session_ws.hpp | 10 +- server/server.cpp | 104 +++++++----- server/server.hpp | 9 +- server/server_settings.hpp | 18 ++ server/snapserver.cpp | 12 ++ server/streamreader/control_error.cpp | 2 - test/CMakeLists.txt | 4 + test/test_main.cpp | 138 ++++++++++++++++ 18 files changed, 584 insertions(+), 140 deletions(-) diff --git a/server/authinfo.cpp b/server/authinfo.cpp index 27ae9d75..57e0441d 100644 --- a/server/authinfo.cpp +++ b/server/authinfo.cpp @@ -28,61 +28,227 @@ // 3rd party headers // standard headers +#include #include #include #include -#include +#include using namespace std; static constexpr auto LOG_TAG = "AuthInfo"; -AuthInfo::AuthInfo(std::string authheader) + +namespace snapcast::error::auth { - LOG(INFO, LOG_TAG) << "Authorization: " << authheader << "\n"; - std::string token(std::move(authheader)); - static constexpr auto bearer = "bearer"sv; - auto pos = utils::string::tolower_copy(token).find(bearer); - if (pos != string::npos) + +namespace detail +{ + +/// Error category for auth errors +struct category : public std::error_category +{ +public: + /// @return category name + const char* name() const noexcept override; + /// @return error message for @p value + std::string message(int value) const override; +}; + + +const char* category::name() const noexcept +{ + return "auth"; +} + +std::string category::message(int value) const +{ + switch (static_cast(value)) { - token = token.erase(0, pos + bearer.length()); - utils::string::trim(token); - std::ifstream ifs("certs/snapserver.crt"); - std::string certificate((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); - Jwt jwt; - jwt.parse(token, certificate); - if (jwt.getExp().has_value()) - expires_ = jwt.getExp().value(); - username_ = jwt.getSub().value_or(""); - LOG(INFO, LOG_TAG) << "Authorization token: " << token << ", user: " << username_ << ", claims: " << jwt.claims.dump() << "\n"; - } - static constexpr auto basic = "basic"sv; - pos = utils::string::tolower_copy(token).find(basic); - if (pos != string::npos) - { - token = token.erase(0, pos + basic.length()); - utils::string::trim(token); - username_ = base64_decode(token); - std::string password; - username_ = utils::string::split_left(username_, ':', password); - LOG(INFO, LOG_TAG) << "Authorization basic: " << token << ", user: " << username_ << ", password: " << password << "\n"; + case AuthErrc::auth_scheme_not_supported: + return "Authentication scheme not supported"; + case AuthErrc::failed_to_create_token: + return "Failed to create token"; + case AuthErrc::unknown_user: + return "Unknown user"; + case AuthErrc::wrong_password: + return "Wrong password"; + case AuthErrc::expired: + return "Expired"; + case AuthErrc::token_validation_failed: + return "Token validation failed"; + default: + return "Unknown"; } } +} // namespace detail -bool AuthInfo::valid() const +const std::error_category& category() +{ + // The category singleton + static detail::category instance; + return instance; +} + +} // namespace snapcast::error::auth + +std::error_code make_error_code(AuthErrc errc) +{ + return std::error_code(static_cast(errc), snapcast::error::auth::category()); +} + + +AuthInfo::AuthInfo(const ServerSettings& settings) : has_auth_info_(false), settings_(settings) +{ +} + + +ErrorCode AuthInfo::validateUser(const std::string& username, const std::optional& password) const +{ + auto iter = std::find_if(settings_.users.begin(), settings_.users.end(), [&](const ServerSettings::User& user) { return user.name == username; }); + if (iter == settings_.users.end()) + return ErrorCode{AuthErrc::unknown_user}; + if (password.has_value() && (iter->password != password.value())) + return ErrorCode{AuthErrc::wrong_password}; + return {}; +} + + +ErrorCode AuthInfo::authenticate(const std::string& scheme, const std::string& param) +{ + std::string scheme_normed = utils::string::trim_copy(utils::string::tolower_copy(scheme)); + std::string param_normed = utils::string::trim_copy(param); + if (scheme_normed == "bearer") + return authenticateBearer(param_normed); + else if (scheme_normed == "basic") + return authenticateBasic(param_normed); + + return {AuthErrc::auth_scheme_not_supported, "Scheme must be 'Basic' or 'Bearer'"}; +} + + +ErrorCode AuthInfo::authenticate(const std::string& auth) +{ + LOG(INFO, LOG_TAG) << "authenticate: " << auth << "\n"; + std::string param; + std::string scheme = utils::string::split_left(utils::string::trim_copy(auth), ' ', param); + return authenticate(scheme, param); +} + + +ErrorCode AuthInfo::authenticateBasic(const std::string& credentials) +{ + has_auth_info_ = false; + std::string username = base64_decode(credentials); + std::string password; + username_ = utils::string::split_left(username, ':', password); + auto ec = validateUser(username_, password); + + LOG(INFO, LOG_TAG) << "Authorization basic: " << credentials << ", user: " << username_ << ", password: " << password << "\n"; + has_auth_info_ = (ec.value() == 0); + return ec; +} + + +ErrorCode AuthInfo::authenticateBearer(const std::string& token) +{ + has_auth_info_ = false; + std::ifstream ifs(settings_.ssl.certificate); + std::string certificate((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); + Jwt jwt; + if (!jwt.parse(token, certificate)) + return {AuthErrc::token_validation_failed}; + if (jwt.getExp().has_value()) + expires_ = jwt.getExp().value(); + username_ = jwt.getSub().value_or(""); + + LOG(INFO, LOG_TAG) << "Authorization token: " << token << ", user: " << username_ << ", claims: " << jwt.claims.dump() << "\n"; + + if (auto ec = validateUser(username_); ec) + return ec; + + if (isExpired()) + return {AuthErrc::expired}; + + has_auth_info_ = true; + return {}; +} + + +ErrorOr AuthInfo::getToken(const std::string& username, const std::string& password) const +{ + ErrorCode ec = validateUser(username, password); + if (ec) + return ec; + + Jwt jwt; + auto now = std::chrono::system_clock::now(); + jwt.setIat(now); + jwt.setExp(now + 10h); + jwt.setSub(username); + std::ifstream ifs(settings_.ssl.private_key); + std::string private_key((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); + if (!ifs.good()) + return ErrorCode{std::make_error_code(std::errc::io_error), "Failed to read private key file"}; + // TODO tls: eroor handling + std::optional token = jwt.getToken(private_key); + if (!token.has_value()) + return ErrorCode{AuthErrc::failed_to_create_token}; + return token.value(); +} + + +bool AuthInfo::isExpired() const { if (expires_.has_value()) { LOG(INFO, LOG_TAG) << "Expires in " << std::chrono::duration_cast(expires_.value() - std::chrono::system_clock::now()).count() << " sec\n"; - return expires_ > std::chrono::system_clock::now(); + if (std::chrono::system_clock::now() > expires_.value()) + return true; } - return true; + return false; } + +bool AuthInfo::hasAuthInfo() const +{ + return has_auth_info_; +} + + +// ErrorCode AuthInfo::isValid(const std::string& command) const +// { +// std::ignore = command; +// if (isExpired()) +// return {AuthErrc::expired}; + +// return {}; +// } + const std::string& AuthInfo::username() const { return username_; } + + +bool AuthInfo::hasPermission(const std::string& resource) const +{ + if (!hasAuthInfo()) + return false; + + auto iter = std::find_if(settings_.users.begin(), settings_.users.end(), [&](const ServerSettings::User& user) { return user.name == username_; }); + if (iter == settings_.users.end()) + return false; + + auto perm_iter = std::find_if(iter->permissions.begin(), iter->permissions.end(), + [&](const std::string& permission) { return utils::string::wildcardMatch(permission, resource); }); + if (perm_iter != iter->permissions.end()) + { + LOG(DEBUG, LOG_TAG) << "Found permission for ressource '" << resource << "': '" << *perm_iter << "'\n"; + return true; + } + return false; +} diff --git a/server/authinfo.hpp b/server/authinfo.hpp index 7cc4f5f3..ab9870f5 100644 --- a/server/authinfo.hpp +++ b/server/authinfo.hpp @@ -19,7 +19,8 @@ #pragma once // local headers -#include "common/jwt.hpp" +#include "common/error_code.hpp" +#include "server_settings.hpp" // 3rd party headers @@ -27,19 +28,82 @@ #include #include #include +#include + +/// Authentication error codes +enum class AuthErrc +{ + auth_scheme_not_supported = 1, + failed_to_create_token = 2, + unknown_user = 3, + wrong_password = 4, + expired = 5, + token_validation_failed = 6, +}; + +namespace snapcast::error::auth +{ +const std::error_category& category(); +} +namespace std +{ +template <> +struct is_error_code_enum : public std::true_type +{ +}; +} // namespace std + +std::error_code make_error_code(AuthErrc); + +using snapcast::ErrorCode; +using snapcast::ErrorOr; + +/// Authentication Info class class AuthInfo { public: - AuthInfo(std::string authheader); + /// c'tor + explicit AuthInfo(const ServerSettings& settings); + // explicit AuthInfo(std::string authheader); + /// d'tor virtual ~AuthInfo() = default; - bool valid() const; + /// @return if authentication info is available + bool hasAuthInfo() const; + // ErrorCode isValid(const std::string& command) const; + /// @return the username const std::string& username() const; + /// Authenticate with basic scheme + ErrorCode authenticateBasic(const std::string& credentials); + /// Authenticate with bearer scheme + ErrorCode authenticateBearer(const std::string& token); + /// Authenticate with basic or bearer scheme with an auth header + ErrorCode authenticate(const std::string& auth); + /// Authenticate with scheme ("basic" or "bearer") and auth param + ErrorCode authenticate(const std::string& scheme, const std::string& param); + + /// @return JWS token for @p username and @p password + ErrorOr getToken(const std::string& username, const std::string& password) const; + /// @return if the authenticated user has permission to access @p ressource + bool hasPermission(const std::string& resource) const; + private: + /// has auth info + bool has_auth_info_; + /// auth user name std::string username_; + /// optional token expiration std::optional expires_; + /// server configuration + ServerSettings settings_; + + /// Validate @p username and @p password + /// @return true if username and password are correct + ErrorCode validateUser(const std::string& username, const std::optional& password = std::nullopt) const; + /// @return if the authentication is expired + bool isExpired() const; }; diff --git a/server/control_server.cpp b/server/control_server.cpp index 9f58a84f..9d051e68 100644 --- a/server/control_server.cpp +++ b/server/control_server.cpp @@ -39,11 +39,10 @@ static constexpr auto LOG_TAG = "ControlServer"; ControlServer::ControlServer(boost::asio::io_context& io_context, const ServerSettings& settings, ControlMessageReceiver* controlMessageReceiver) - : io_context_(io_context), ssl_context_(boost::asio::ssl::context::sslv23), tcp_settings_(settings.tcp), http_settings_(settings.http), - controlMessageReceiver_(controlMessageReceiver) + : io_context_(io_context), ssl_context_(boost::asio::ssl::context::sslv23), settings_(settings), controlMessageReceiver_(controlMessageReceiver) { const ServerSettings::Ssl& ssl = settings.ssl; - if (http_settings_.ssl_enabled) + if (settings_.http.ssl_enabled) { ssl_context_.set_options(boost::asio::ssl::context::default_workarounds | boost::asio::ssl::context::no_sslv2 | boost::asio::ssl::context::single_dh_use); @@ -99,7 +98,7 @@ void ControlServer::send(const std::string& message, const ControlSession* exclu } -void ControlServer::onMessageReceived(std::shared_ptr session, const std::string& message, const ResponseHander& response_handler) +void ControlServer::onMessageReceived(std::shared_ptr session, const std::string& message, const ResponseHandler& response_handler) { // LOG(DEBUG, LOG_TAG) << "received: \"" << message << "\"\n"; if (controlMessageReceiver_ != nullptr) @@ -138,19 +137,19 @@ void ControlServer::startAccept() auto port = socket.local_endpoint().port(); LOG(NOTICE, LOG_TAG) << "New connection from: " << socket.remote_endpoint().address().to_string() << ", port: " << port << endl; - if (port == http_settings_.ssl_port) + if (port == settings_.http.ssl_port) { - auto session = make_shared(this, ssl_socket(std::move(socket), ssl_context_), http_settings_); + auto session = make_shared(this, ssl_socket(std::move(socket), ssl_context_), settings_); onNewSession(std::move(session)); } - else if (port == http_settings_.port) + else if (port == settings_.http.port) { - auto session = make_shared(this, std::move(socket), http_settings_); + auto session = make_shared(this, std::move(socket), settings_); onNewSession(std::move(session)); } - else if (port == tcp_settings_.port) + else if (port == settings_.tcp.port) { - auto session = make_shared(this, std::move(socket)); + auto session = make_shared(this, std::move(socket), settings_); onNewSession(std::move(session)); } else @@ -171,15 +170,15 @@ void ControlServer::startAccept() void ControlServer::start() { - if (tcp_settings_.enabled) + if (settings_.tcp.enabled) { - for (const auto& address : tcp_settings_.bind_to_address) + for (const auto& address : settings_.tcp.bind_to_address) { try { - LOG(INFO, LOG_TAG) << "Creating TCP acceptor for address: " << address << ", port: " << tcp_settings_.port << "\n"; + LOG(INFO, LOG_TAG) << "Creating TCP acceptor for address: " << address << ", port: " << settings_.tcp.port << "\n"; acceptor_.emplace_back(make_unique(boost::asio::make_strand(io_context_.get_executor()), - tcp::endpoint(boost::asio::ip::address::from_string(address), tcp_settings_.port))); + tcp::endpoint(boost::asio::ip::address::from_string(address), settings_.tcp.port))); } catch (const boost::system::system_error& e) { @@ -187,17 +186,17 @@ void ControlServer::start() } } } - if (http_settings_.enabled || http_settings_.ssl_enabled) + if (settings_.http.enabled || settings_.http.ssl_enabled) { - if (http_settings_.enabled) + if (settings_.http.enabled) { - for (const auto& address : http_settings_.bind_to_address) + for (const auto& address : settings_.http.bind_to_address) { try { - LOG(INFO, LOG_TAG) << "Creating HTTP acceptor for address: " << address << ", port: " << http_settings_.port << "\n"; + LOG(INFO, LOG_TAG) << "Creating HTTP acceptor for address: " << address << ", port: " << settings_.http.port << "\n"; acceptor_.emplace_back(make_unique(boost::asio::make_strand(io_context_.get_executor()), - tcp::endpoint(boost::asio::ip::address::from_string(address), http_settings_.port))); + tcp::endpoint(boost::asio::ip::address::from_string(address), settings_.http.port))); } catch (const boost::system::system_error& e) { @@ -206,15 +205,15 @@ void ControlServer::start() } } - if (http_settings_.ssl_enabled) + if (settings_.http.ssl_enabled) { - for (const auto& address : http_settings_.ssl_bind_to_address) + for (const auto& address : settings_.http.ssl_bind_to_address) { try { - LOG(INFO, LOG_TAG) << "Creating HTTPS acceptor for address: " << address << ", port: " << http_settings_.ssl_port << "\n"; + LOG(INFO, LOG_TAG) << "Creating HTTPS acceptor for address: " << address << ", port: " << settings_.http.ssl_port << "\n"; acceptor_.emplace_back(make_unique(boost::asio::make_strand(io_context_.get_executor()), - tcp::endpoint(boost::asio::ip::address::from_string(address), http_settings_.ssl_port))); + tcp::endpoint(boost::asio::ip::address::from_string(address), settings_.http.ssl_port))); } catch (const boost::system::system_error& e) { diff --git a/server/control_server.hpp b/server/control_server.hpp index 7e770f39..45e35595 100644 --- a/server/control_server.hpp +++ b/server/control_server.hpp @@ -43,10 +43,14 @@ using acceptor_ptr = std::unique_ptr; class ControlServer : public ControlMessageReceiver { public: + /// c'tor ControlServer(boost::asio::io_context& io_context, const ServerSettings& settings, ControlMessageReceiver* controlMessageReceiver = nullptr); + /// d'tor virtual ~ControlServer(); + /// Start accepting control connections void start(); + /// Stop accepting connections and stop all running sessions void stop(); /// Send a message to all connected clients @@ -58,7 +62,7 @@ private: void cleanup(); /// Implementation of ControlMessageReceiver - void onMessageReceived(std::shared_ptr session, const std::string& message, const ResponseHander& response_handler) override; + void onMessageReceived(std::shared_ptr session, const std::string& message, const ResponseHandler& response_handler) override; void onNewSession(std::shared_ptr session) override; void onNewSession(std::shared_ptr session) override; @@ -69,7 +73,6 @@ private: boost::asio::io_context& io_context_; boost::asio::ssl::context ssl_context_; - ServerSettings::Tcp tcp_settings_; - ServerSettings::Http http_settings_; + ServerSettings settings_; ControlMessageReceiver* controlMessageReceiver_; }; diff --git a/server/control_session.hpp b/server/control_session.hpp index f5a52004..5d02957c 100644 --- a/server/control_session.hpp +++ b/server/control_session.hpp @@ -20,6 +20,7 @@ // local headers #include "authinfo.hpp" +#include "server_settings.hpp" // 3rd party headers @@ -37,10 +38,14 @@ class StreamSession; class ControlMessageReceiver { public: - using ResponseHander = std::function; + /// Response callback function for requests + using ResponseHandler = std::function; // TODO: rename, error handling - virtual void onMessageReceived(std::shared_ptr session, const std::string& message, const ResponseHander& response_handler) = 0; + /// Called when a comtrol message @p message is received by @p session, response is written to @p response_handler + virtual void onMessageReceived(std::shared_ptr session, const std::string& message, const ResponseHandler& response_handler) = 0; + /// Called when a comtrol session is created virtual void onNewSession(std::shared_ptr session) = 0; + /// Called when a stream session is created virtual void onNewSession(std::shared_ptr session) = 0; }; @@ -55,18 +60,22 @@ class ControlSession : public std::enable_shared_from_this { public: /// ctor. Received message from the client are passed to ControlMessageReceiver - ControlSession(ControlMessageReceiver* receiver) : message_receiver_(receiver) + ControlSession(ControlMessageReceiver* receiver, const ServerSettings& settings) : authinfo(settings), message_receiver_(receiver) { } virtual ~ControlSession() = default; + /// Start the control session virtual void start() = 0; + /// Stop the control session virtual void stop() = 0; /// Sends a message to the client (asynchronous) virtual void sendAsync(const std::string& message) = 0; - std::optional authinfo; + /// Authentication info attached to this session + AuthInfo authinfo; protected: + /// The control message receiver ControlMessageReceiver* message_receiver_; }; diff --git a/server/control_session_http.cpp b/server/control_session_http.cpp index d5b4c3e5..1d7ba5bd 100644 --- a/server/control_session_http.cpp +++ b/server/control_session_http.cpp @@ -149,14 +149,14 @@ std::string path_cat(boost::beast::string_view base, boost::beast::string_view p } } // namespace -ControlSessionHttp::ControlSessionHttp(ControlMessageReceiver* receiver, ssl_socket&& socket, const ServerSettings::Http& settings) - : ControlSession(receiver), ssl_socket_(std::move(socket)), settings_(settings), is_ssl_(true) +ControlSessionHttp::ControlSessionHttp(ControlMessageReceiver* receiver, ssl_socket&& socket, const ServerSettings& settings) + : ControlSession(receiver, settings), ssl_socket_(std::move(socket)), settings_(settings), is_ssl_(true) { LOG(DEBUG, LOG_TAG) << "ControlSessionHttp, mode: ssl, Local IP: " << ssl_socket_->next_layer().local_endpoint().address().to_string() << "\n"; } -ControlSessionHttp::ControlSessionHttp(ControlMessageReceiver* receiver, tcp_socket&& socket, const ServerSettings::Http& settings) - : ControlSession(receiver), tcp_socket_(std::move(socket)), settings_(settings), is_ssl_(false) +ControlSessionHttp::ControlSessionHttp(ControlMessageReceiver* receiver, tcp_socket&& socket, const ServerSettings& settings) + : ControlSession(receiver, settings), tcp_socket_(std::move(socket)), settings_(settings), is_ssl_(false) { LOG(DEBUG, LOG_TAG) << "ControlSessionHttp, mode: tcp, Local IP: " << tcp_socket_->local_endpoint().address().to_string() << "\n"; } @@ -288,7 +288,7 @@ void ControlSessionHttp::handle_request(http::request(message_receiver_, std::move(*ws)); + auto ws_session = make_shared(message_receiver_, std::move(*ws), settings_); message_receiver_->onNewSession(std::move(ws_session)); } else // if (req_.target() == "/stream") @@ -433,7 +433,7 @@ void ControlSessionHttp::on_read(beast::error_code ec, std::size_t bytes_transfe { if (req_.target() == "/jsonrpc") { - auto ws_session = make_shared(message_receiver_, std::move(*ws)); + auto ws_session = make_shared(message_receiver_, std::move(*ws), settings_); message_receiver_->onNewSession(std::move(ws_session)); } else // if (req_.target() == "/stream") @@ -452,7 +452,11 @@ void ControlSessionHttp::on_read(beast::error_code ec, std::size_t bytes_transfe std::string_view authheader = req_[beast::http::field::authorization]; if (!authheader.empty()) { - authinfo = AuthInfo(std::string(authheader)); + auto ec = authinfo.authenticate(std::string(authheader)); + if (ec) + { + LOG(ERROR, LOG_TAG) << "Authentication failed: " << ec.detailed_message() << "\n"; + } } // Send the response diff --git a/server/control_session_http.hpp b/server/control_session_http.hpp index 811d0e0c..a1b2900a 100644 --- a/server/control_session_http.hpp +++ b/server/control_session_http.hpp @@ -55,9 +55,10 @@ using ssl_socket = boost::asio::ssl::stream; class ControlSessionHttp : public ControlSession { public: - /// ctor. Received message from the client are passed to ControlMessageReceiver - ControlSessionHttp(ControlMessageReceiver* receiver, ssl_socket&& socket, const ServerSettings::Http& settings); - ControlSessionHttp(ControlMessageReceiver* receiver, tcp_socket&& socket, const ServerSettings::Http& settings); + /// c'tor for ssl sockets. Received message from the client are passed to ControlMessageReceiver + ControlSessionHttp(ControlMessageReceiver* receiver, ssl_socket&& socket, const ServerSettings& settings); + /// c'tor for tcp sockets + ControlSessionHttp(ControlMessageReceiver* receiver, tcp_socket&& socket, const ServerSettings& settings); ~ControlSessionHttp() override; void start() override; void stop() override; @@ -65,21 +66,21 @@ public: /// Sends a message to the client (asynchronous) void sendAsync(const std::string& message) override; -protected: - // HTTP methods +private: + /// HTTP on read callback void on_read(beast::error_code ec, std::size_t bytes_transferred); + /// HTTP on write callback void on_write(beast::error_code ec, std::size_t bytes, bool close); + /// Handle an incoming HTTP request template void handle_request(http::request>&& req, Send&& send); http::request req_; - -protected: std::optional tcp_socket_; std::optional ssl_socket_; beast::flat_buffer buffer_; - ServerSettings::Http settings_; + ServerSettings settings_; std::deque messages_; bool is_ssl_; }; diff --git a/server/control_session_tcp.cpp b/server/control_session_tcp.cpp index 4311ff06..6e653f04 100644 --- a/server/control_session_tcp.cpp +++ b/server/control_session_tcp.cpp @@ -26,6 +26,7 @@ // local headers #include "common/aixlog.hpp" +#include "server_settings.hpp" using namespace std; @@ -35,8 +36,8 @@ static constexpr auto LOG_TAG = "ControlSessionTCP"; // https://stackoverflow.com/questions/7754695/boost-asio-async-write-how-to-not-interleaving-async-write-calls/7756894 -ControlSessionTcp::ControlSessionTcp(ControlMessageReceiver* receiver, tcp::socket&& socket) - : ControlSession(receiver), socket_(std::move(socket)), strand_(boost::asio::make_strand(socket_.get_executor())) +ControlSessionTcp::ControlSessionTcp(ControlMessageReceiver* receiver, tcp::socket&& socket, const ServerSettings& settings) + : ControlSession(receiver, settings), socket_(std::move(socket)), strand_(boost::asio::make_strand(socket_.get_executor())) { } diff --git a/server/control_session_tcp.hpp b/server/control_session_tcp.hpp index b27979b1..59aa4b41 100644 --- a/server/control_session_tcp.hpp +++ b/server/control_session_tcp.hpp @@ -43,7 +43,7 @@ class ControlSessionTcp : public ControlSession { public: /// ctor. Received message from the client are passed to ControlMessageReceiver - ControlSessionTcp(ControlMessageReceiver* receiver, tcp::socket&& socket); + ControlSessionTcp(ControlMessageReceiver* receiver, tcp::socket&& socket, const ServerSettings& settings); ~ControlSessionTcp() override; void start() override; void stop() override; @@ -51,7 +51,7 @@ public: /// Sends a message to the client (asynchronous) void sendAsync(const std::string& message) override; -protected: +private: void do_read(); void send_next(); diff --git a/server/control_session_ws.cpp b/server/control_session_ws.cpp index cded57e6..256564fc 100644 --- a/server/control_session_ws.cpp +++ b/server/control_session_ws.cpp @@ -32,14 +32,14 @@ using namespace std; static constexpr auto LOG_TAG = "ControlSessionWS"; -ControlSessionWebsocket::ControlSessionWebsocket(ControlMessageReceiver* receiver, ssl_websocket&& ssl_ws) - : ControlSession(receiver), ssl_ws_(std::move(ssl_ws)), strand_(boost::asio::make_strand(ssl_ws_->get_executor())), is_ssl_(true) +ControlSessionWebsocket::ControlSessionWebsocket(ControlMessageReceiver* receiver, ssl_websocket&& ssl_ws, const ServerSettings& settings) + : ControlSession(receiver, settings), ssl_ws_(std::move(ssl_ws)), strand_(boost::asio::make_strand(ssl_ws_->get_executor())), is_ssl_(true) { LOG(DEBUG, LOG_TAG) << "ControlSessionWebsocket, mode: ssl\n"; } -ControlSessionWebsocket::ControlSessionWebsocket(ControlMessageReceiver* receiver, tcp_websocket&& tcp_ws) - : ControlSession(receiver), tcp_ws_(std::move(tcp_ws)), strand_(boost::asio::make_strand(tcp_ws_->get_executor())), is_ssl_(false) +ControlSessionWebsocket::ControlSessionWebsocket(ControlMessageReceiver* receiver, tcp_websocket&& tcp_ws, const ServerSettings& settings) + : ControlSession(receiver, settings), tcp_ws_(std::move(tcp_ws)), strand_(boost::asio::make_strand(tcp_ws_->get_executor())), is_ssl_(false) { LOG(DEBUG, LOG_TAG) << "ControlSessionWebsocket, mode: tcp\n"; } diff --git a/server/control_session_ws.hpp b/server/control_session_ws.hpp index e3e78075..a04f12f0 100644 --- a/server/control_session_ws.hpp +++ b/server/control_session_ws.hpp @@ -65,9 +65,10 @@ using ssl_websocket = websocket::stream; class ControlSessionWebsocket : public ControlSession { public: - /// ctor. Received message from the client are passed to ControlMessageReceiver - ControlSessionWebsocket(ControlMessageReceiver* receiver, ssl_websocket&& ssl_ws); - ControlSessionWebsocket(ControlMessageReceiver* receiver, tcp_websocket&& tcp_ws); + /// c'tor for ssl websockets. Received message from the client are passed to ControlMessageReceiver + ControlSessionWebsocket(ControlMessageReceiver* receiver, ssl_websocket&& ssl_ws, const ServerSettings& settings); + /// c'tor for TCP websockets. Received message from the client are passed to ControlMessageReceiver + ControlSessionWebsocket(ControlMessageReceiver* receiver, tcp_websocket&& tcp_ws, const ServerSettings& settings); ~ControlSessionWebsocket() override; void start() override; void stop() override; @@ -75,7 +76,7 @@ public: /// Sends a message to the client (asynchronous) void sendAsync(const std::string& message) override; -protected: +private: // Websocket methods void on_read_ws(beast::error_code ec, std::size_t bytes_transferred); void do_read_ws(); @@ -84,7 +85,6 @@ protected: std::optional ssl_ws_; std::optional tcp_ws_; -protected: beast::flat_buffer buffer_; boost::asio::strand strand_; std::deque messages_; diff --git a/server/server.cpp b/server/server.cpp index 3a88aafe..dcbc54a0 100644 --- a/server/server.cpp +++ b/server/server.cpp @@ -21,12 +21,15 @@ // local headers #include "common/aixlog.hpp" +#include "common/base64.h" #include "common/jwt.hpp" #include "common/message/client_info.hpp" #include "common/message/hello.hpp" #include "common/message/server_settings.hpp" #include "common/message/time.hpp" +#include "common/utils/string_utils.hpp" #include "config.hpp" +#include "jsonrpcpp.hpp" // 3rd party headers @@ -131,9 +134,9 @@ void Server::onDisconnect(StreamSession* streamSession) } -void Server::processRequest(const jsonrpcpp::request_ptr request, const OnResponse& on_response) const +void Server::processRequest(const jsonrpcpp::request_ptr request, AuthInfo& authinfo, const OnResponse& on_response) const { - jsonrpcpp::entity_ptr response; + jsonrpcpp::entity_ptr response = nullptr; jsonrpcpp::notification_ptr notification; try { @@ -407,30 +410,48 @@ void Server::processRequest(const jsonrpcpp::request_ptr request, const OnRespon else if (request->method() == "Server.Authenticate") { // clang-format off - // Request: {"id":8,"jsonrpc":"2.0","method":"Server.Authenticate","params":{"user":"badaix","password":"secret"}} + // Request: {"id":8,"jsonrpc":"2.0","method":"Server.Authenticate","params":{"scheme":"Basic","param":"YmFkYWl4OnNlY3JldA=="}} + // Response: {"id":8,"jsonrpc":"2.0","result":"ok"} + // Request: {"id":8,"jsonrpc":"2.0","method":"Server.Authenticate","params":{"scheme":"Bearer","param":"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3MTg1NjQ1MTYsImlhdCI6MTcxODUyODUxNiwic3ViIjoiQmFkYWl4In0.gHrMVp7jTAg8aCSg3cttcfIxswqmOPuqVNOb5p79Cn0NmAqRmLXtDLX4QjOoOqqb66ezBBeikpNjPi_aO18YPoNmX9fPxSwcObTHBupnm5eugEpneMPDASFUSE2hg8rrD_OEoAVxx6hCLln7Z3ILyWDmR6jcmy7z0bp0BiAqOywUrFoVIsnlDZRs3wOaap5oS9J2oaA_gNi_7OuvAhrydn26LDhm0KiIqEcyIholkpRHrDYODkz98h2PkZdZ2U429tTvVhzDBJ1cBq2Zq3cvuMZT6qhwaUc8eYA8fUJ7g65iP4o2OZtUzlfEUqX1TKyuWuSK6CUlsZooNE-MSCT7_w"}} + // Response: {"id":8,"jsonrpc":"2.0","result":"ok"} + // clang-format on + if (!request->params().has("scheme")) + throw jsonrpcpp::InvalidParamsException("Parameter 'scheme' is missing", request->id()); + if (!request->params().has("param")) + throw jsonrpcpp::InvalidParamsException("Parameter 'param' is missing", request->id()); + + auto scheme = request->params().get("scheme"); + auto param = request->params().get("param"); + LOG(INFO, LOG_TAG) << "Authorization scheme: " << scheme << ", param: " << param << "\n"; + auto ec = authinfo.authenticate(scheme, param); + + if (ec) + response = make_shared(request->id(), jsonrpcpp::Error(ec.detailed_message(), ec.value())); + else + response = make_shared(request->id(), "ok"); + // LOG(DEBUG, LOG_TAG) << response->to_json().dump() << "\n"; + } + else if (request->method() == "Server.GetToken") + { + // clang-format off + // Request: {"id":8,"jsonrpc":"2.0","method":"Server.GetToken","params":{"username":"Badaix","password":"secret"}} // Response: {"id":8,"jsonrpc":"2.0","result":{"token":""}} // clang-format on - if (request->params().has("token")) - { - auto token = request->params().get("token"); - LOG(INFO, LOG_TAG) << "Server.Authenticate, token: " << token << "\n"; - result["token"] = token; - } - else if (request->params().has("user")) - { - auto user = request->params().get("user"); - Jwt jwt; - auto now = std::chrono::system_clock::now(); - jwt.setIat(now); - jwt.setExp(now + 10h); - jwt.setSub(user); - std::ifstream ifs(settings_.ssl.private_key.c_str()); - std::string private_key((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); - std::optional token = jwt.getToken(private_key); - result["token"] = token.value(); - LOG(INFO, LOG_TAG) << "Server.Authenticate, user: " << user << ", password: " << request->params().get("password") - << ", jwt claims: " << jwt.claims.dump() << ", token: '" << token.value_or("") << "'\n"; - } + if (!request->params().has("username")) + throw jsonrpcpp::InvalidParamsException("Parameter 'username' is missing", request->id()); + if (!request->params().has("password")) + throw jsonrpcpp::InvalidParamsException("Parameter 'password' is missing", request->id()); + + auto username = request->params().get("username"); + auto password = request->params().get("password"); + LOG(INFO, LOG_TAG) << "GetToken username: " << username << ", password: " << password << "\n"; + auto token = authinfo.getToken(username, password); + + if (token.hasError()) + response = make_shared(request->id(), jsonrpcpp::Error(token.getError().detailed_message(), token.getError().value())); + else + result["token"] = token.takeValue(); + // LOG(DEBUG, LOG_TAG) << response->to_json().dump() << "\n"; } else throw jsonrpcpp::MethodNotFoundException(request->id()); @@ -473,7 +494,7 @@ void Server::processRequest(const jsonrpcpp::request_ptr request, const OnRespon << ", params: " << (request->params().has("params") ? request->params().get("params") : "") << "\n"; // Find stream - string streamId = request->params().get("id"); + auto streamId = request->params().get("id"); PcmStreamPtr stream = streamManager_->getStream(streamId); if (stream == nullptr) throw jsonrpcpp::InternalErrorException("Stream not found", request->id()); @@ -565,9 +586,9 @@ void Server::processRequest(const jsonrpcpp::request_ptr request, const OnRespon auto value = request->params().get("value"); LOG(INFO, LOG_TAG) << "Stream '" << streamId << "' set property: " << name << " = " << value << "\n"; - auto handle_response = [request, on_response](const snapcast::ErrorCode& ec) + auto handle_response = [request, on_response](const std::string& command, const snapcast::ErrorCode& ec) { - LOG(ERROR, LOG_TAG) << "SetShuffle: " << ec << ", message: " << ec.detailed_message() << ", msg: " << ec.message() + LOG(ERROR, LOG_TAG) << "Result for '" << command << "': " << ec << ", message: " << ec.detailed_message() << ", msg: " << ec.message() << ", category: " << ec.category().name() << "\n"; std::shared_ptr response; if (ec) @@ -583,31 +604,31 @@ void Server::processRequest(const jsonrpcpp::request_ptr request, const OnRespon LoopStatus loop_status = loop_status_from_string(val); if (loop_status == LoopStatus::kUnknown) throw jsonrpcpp::InvalidParamsException("Value for loopStatus must be one of 'none', 'track', 'playlist'", request->id()); - stream->setLoopStatus(loop_status, [handle_response](const snapcast::ErrorCode& ec) { handle_response(ec); }); + stream->setLoopStatus(loop_status, [handle_response, name](const snapcast::ErrorCode& ec) { handle_response(name, ec); }); } else if (name == "shuffle") { if (!value.is_boolean()) throw jsonrpcpp::InvalidParamsException("Value for shuffle must be bool", request->id()); - stream->setShuffle(value.get(), [handle_response](const snapcast::ErrorCode& ec) { handle_response(ec); }); + stream->setShuffle(value.get(), [handle_response, name](const snapcast::ErrorCode& ec) { handle_response(name, ec); }); } else if (name == "volume") { if (!value.is_number_integer()) throw jsonrpcpp::InvalidParamsException("Value for volume must be an int", request->id()); - stream->setVolume(value.get(), [handle_response](const snapcast::ErrorCode& ec) { handle_response(ec); }); + stream->setVolume(value.get(), [handle_response, name](const snapcast::ErrorCode& ec) { handle_response(name, ec); }); } else if (name == "mute") { if (!value.is_boolean()) throw jsonrpcpp::InvalidParamsException("Value for mute must be bool", request->id()); - stream->setMute(value.get(), [handle_response](const snapcast::ErrorCode& ec) { handle_response(ec); }); + stream->setMute(value.get(), [handle_response, name](const snapcast::ErrorCode& ec) { handle_response(name, ec); }); } else if (name == "rate") { if (!value.is_number_float()) throw jsonrpcpp::InvalidParamsException("Value for rate must be float", request->id()); - stream->setRate(value.get(), [handle_response](const snapcast::ErrorCode& ec) { handle_response(ec); }); + stream->setRate(value.get(), [handle_response, name](const snapcast::ErrorCode& ec) { handle_response(name, ec); }); } else throw jsonrpcpp::InvalidParamsException("Property '" + name + "' not supported", request->id()); @@ -655,7 +676,8 @@ void Server::processRequest(const jsonrpcpp::request_ptr request, const OnRespon else throw jsonrpcpp::MethodNotFoundException(request->id()); - response = std::make_shared(*request, result); + if (!response) + response = std::make_shared(*request, result); } catch (const jsonrpcpp::RequestException& e) { @@ -671,7 +693,7 @@ void Server::processRequest(const jsonrpcpp::request_ptr request, const OnRespon } -void Server::onMessageReceived(std::shared_ptr controlSession, const std::string& message, const ResponseHander& response_handler) +void Server::onMessageReceived(std::shared_ptr controlSession, const std::string& message, const ResponseHandler& response_handler) { // LOG(DEBUG, LOG_TAG) << "onMessageReceived: " << message << "\n"; std::lock_guard lock(Config::instance().getMutex()); @@ -696,14 +718,14 @@ void Server::onMessageReceived(std::shared_ptr controlSession, c if (entity->is_request()) { jsonrpcpp::request_ptr request = dynamic_pointer_cast(entity); - processRequest(request, + processRequest(request, controlSession->authinfo, [this, controlSession, response_handler](jsonrpcpp::entity_ptr response, jsonrpcpp::notification_ptr notification) { - if (controlSession->authinfo.has_value()) - { - LOG(INFO, LOG_TAG) << "Request auth info - username: " << controlSession->authinfo->username() - << ", valid: " << controlSession->authinfo->valid() << "\n"; - } + // if (controlSession->authinfo.hasAuthInfo()) + // { + // LOG(INFO, LOG_TAG) << "Request auth info - username: " << controlSession->authinfo->username() + // << ", valid: " << controlSession->authinfo->valid() << "\n"; + // } saveConfig(); ////cout << "Request: " << request->to_json().dump() << "\n"; if (notification) @@ -733,7 +755,7 @@ void Server::onMessageReceived(std::shared_ptr controlSession, c if (batch_entity->is_request()) { jsonrpcpp::request_ptr request = dynamic_pointer_cast(batch_entity); - processRequest(request, + processRequest(request, controlSession->authinfo, [controlSession, response_handler, &responseBatch, ¬ificationBatch](jsonrpcpp::entity_ptr response, jsonrpcpp::notification_ptr notification) { diff --git a/server/server.hpp b/server/server.hpp index b7b007cd..9e23a417 100644 --- a/server/server.hpp +++ b/server/server.hpp @@ -20,6 +20,7 @@ // local headers +#include "authinfo.hpp" #include "common/message/message.hpp" #include "common/queue.hpp" #include "control_server.hpp" @@ -52,12 +53,16 @@ class Server : public StreamMessageReceiver, public ControlMessageReceiver, publ { public: // TODO: revise handler names + /// Response handler for json control requests, returning a @p response and/or a @p notification broadcast using OnResponse = std::function; + /// c'tor Server(boost::asio::io_context& io_context, const ServerSettings& serverSettings); virtual ~Server(); + /// Start the server (control server, stream server and stream manager) void start(); + /// Stop the server (control server, stream server and stream manager) void stop(); private: @@ -66,7 +71,7 @@ private: void onDisconnect(StreamSession* streamSession) override; /// Implementation of ControllMessageReceiver - void onMessageReceived(std::shared_ptr controlSession, const std::string& message, const ResponseHander& response_handler) override; + void onMessageReceived(std::shared_ptr controlSession, const std::string& message, const ResponseHandler& response_handler) override; void onNewSession(std::shared_ptr session) override { std::ignore = session; @@ -81,7 +86,7 @@ private: void onResync(const PcmStream* pcmStream, double ms) override; private: - void processRequest(const jsonrpcpp::request_ptr request, const OnResponse& on_response) const; + void processRequest(const jsonrpcpp::request_ptr request, AuthInfo& authinfo, const OnResponse& on_response) const; /// Save the server state deferred to prevent blocking and lower disk io /// @param deferred the delay after the last call to saveConfig void saveConfig(const std::chrono::milliseconds& deferred = std::chrono::seconds(2)); diff --git a/server/server_settings.hpp b/server/server_settings.hpp index f92d68a0..33d88ade 100644 --- a/server/server_settings.hpp +++ b/server/server_settings.hpp @@ -20,6 +20,7 @@ // local headers +#include "common/utils/string_utils.hpp" #include "image_cache.hpp" // standard headers @@ -45,6 +46,23 @@ struct ServerSettings std::string key_password{""}; }; + struct User + { + User(const std::string& user_permissions_password) + { + std::string perm; + name = utils::string::split_left(user_permissions_password, ':', perm); + perm = utils::string::split_left(perm, ':', password); + permissions = utils::string::split(perm, ','); + } + + std::string name; + std::vector permissions; + std::string password; + }; + + std::vector users; + struct Http { bool enabled{true}; diff --git a/server/snapserver.cpp b/server/snapserver.cpp index 9770e128..50702086 100644 --- a/server/snapserver.cpp +++ b/server/snapserver.cpp @@ -85,6 +85,9 @@ int main(int argc, char* argv[]) conf.add>("", "ssl.private_key", "private key file (PEM format)", settings.ssl.private_key, &settings.ssl.private_key); conf.add>("", "ssl.key_password", "key password (for encrypted private key)", settings.ssl.key_password, &settings.ssl.key_password); + // Users setting + auto users_value = conf.add>("", "users.user", "::"); + // HTTP RPC settings conf.add>("", "http.enabled", "enable HTTP Json RPC (HTTP POST and websockets)", settings.http.enabled, &settings.http.enabled); conf.add>("", "http.port", "which port the server should listen on", settings.http.port, &settings.http.port); @@ -265,6 +268,15 @@ int main(int argc, char* argv[]) settings.stream.sources.push_back(sourceValue->value(n)); } + for (size_t n = 0; n < users_value->count(); ++n) + { + settings.users.emplace_back(users_value->value(n)); + LOG(DEBUG, LOG_TAG) << "User: " << settings.users.back().name + << ", permissions: " << utils::string::container_to_string(settings.users.back().permissions) + << ", pw: " << settings.users.back().password << "\n"; + } + + #ifdef HAS_DAEMON std::unique_ptr daemon; if (daemonOption->is_set()) diff --git a/server/streamreader/control_error.cpp b/server/streamreader/control_error.cpp index df101765..aa0071d0 100644 --- a/server/streamreader/control_error.cpp +++ b/server/streamreader/control_error.cpp @@ -87,7 +87,5 @@ const std::error_category& category() std::error_code make_error_code(ControlErrc errc) { - // Create an error_code with the original mpg123 error value - // and the mpg123 error category. return std::error_code(static_cast(errc), snapcast::error::control::category()); } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index fe7f6873..45738f99 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -13,11 +13,15 @@ set(TEST_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/test_main.cpp ${CMAKE_SOURCE_DIR}/common/jwt.cpp ${CMAKE_SOURCE_DIR}/common/base64.cpp + ${CMAKE_SOURCE_DIR}/common/utils/string_utils.cpp + ${CMAKE_SOURCE_DIR}/server/authinfo.cpp ${CMAKE_SOURCE_DIR}/server/streamreader/control_error.cpp ${CMAKE_SOURCE_DIR}/server/streamreader/properties.cpp ${CMAKE_SOURCE_DIR}/server/streamreader/metadata.cpp ${CMAKE_SOURCE_DIR}/server/streamreader/stream_uri.cpp) +include_directories(${Boost_INCLUDE_DIR}) + add_executable(snapcast_test ${TEST_SOURCES}) if(ANDROID) diff --git a/test/test_main.cpp b/test/test_main.cpp index 9f5adebc..e0e2499d 100644 --- a/test/test_main.cpp +++ b/test/test_main.cpp @@ -20,8 +20,12 @@ // local headers #include "common/aixlog.hpp" +#include "common/base64.h" +#include "common/error_code.hpp" #include "common/jwt.hpp" #include "common/utils/string_utils.hpp" +#include "server/authinfo.hpp" +#include "server/server_settings.hpp" #include "server/streamreader/control_error.hpp" #include "server/streamreader/properties.hpp" #include "server/streamreader/stream_uri.hpp" @@ -32,6 +36,8 @@ // standard headers #include #include +#include +#include using namespace std; @@ -41,6 +47,32 @@ TEST_CASE("String utils") { using namespace utils::string; REQUIRE(ltrim_copy(" test") == "test"); + + auto strings = split("1*2", '*'); + REQUIRE(strings.size() == 2); + REQUIRE(strings[0] == "1"); + REQUIRE(strings[1] == "2"); + + strings = split("1**2", '*'); + REQUIRE(strings.size() == 3); + REQUIRE(strings[0] == "1"); + REQUIRE(strings[1] == ""); + REQUIRE(strings[2] == "2"); + + strings = split("*1*2", '*'); + REQUIRE(strings.size() == 3); + REQUIRE(strings[0] == ""); + REQUIRE(strings[1] == "1"); + REQUIRE(strings[2] == "2"); + + strings = split("*1*2*", '*'); + REQUIRE(strings.size() == 3); + REQUIRE(strings[0] == ""); + REQUIRE(strings[1] == "1"); + REQUIRE(strings[2] == "2"); + + std::vector vec{"1", "2", "3"}; + REQUIRE(container_to_string(vec) == "1, 2, 3"); } @@ -574,4 +606,110 @@ TEST_CASE("Error") ec = make_error_code(ControlErrc::can_not_control); REQUIRE(ec.category() == snapcast::error::control::category()); std::cout << "Category: " << ec.category().name() << ", " << ec.message() << std::endl; + + snapcast::ErrorCode error_code{}; + REQUIRE(!error_code); +} + + + +TEST_CASE("ErrorOr") +{ + { + snapcast::ErrorOr error_or("test"); + REQUIRE(error_or.hasValue()); + REQUIRE(!error_or.hasError()); + // Get value by reference + REQUIRE(error_or.getValue() == "test"); + // Move value out + REQUIRE(error_or.takeValue() == "test"); + // Value has been moved out, get will return an empty string + REQUIRE(error_or.getValue() == ""); + } + + { + snapcast::ErrorOr error_or(make_error_code(ControlErrc::can_not_control)); + REQUIRE(error_or.hasError()); + REQUIRE(!error_or.hasValue()); + // Get error by reference + REQUIRE(error_or.getError() == make_error_code(ControlErrc::can_not_control)); + // Get error by reference + REQUIRE(error_or.getError() == ControlErrc::can_not_control); + // Get error by reference + REQUIRE(error_or.getError() != ControlErrc::parse_error); + // Get error by reference + REQUIRE(error_or.getError() == snapcast::ErrorCode(ControlErrc::can_not_control)); + // Move error out + REQUIRE(error_or.takeError() == snapcast::ErrorCode(ControlErrc::can_not_control)); + // Error is moved out, will return something else + // REQUIRE(error_or.getError() != snapcast::ErrorCode(ControlErrc::can_not_control)); + } +} + + +TEST_CASE("WildcardMatch") +{ + using namespace utils::string; + REQUIRE(wildcardMatch("*", "Server.getToken")); + REQUIRE(wildcardMatch("Server.*", "Server.getToken")); + REQUIRE(wildcardMatch("Server.getToken", "Server.getToken")); + REQUIRE(wildcardMatch("*.getToken", "Server.getToken")); + REQUIRE(wildcardMatch("*.get*", "Server.getToken")); + REQUIRE(wildcardMatch("**.get*", "Server.getToken")); + REQUIRE(wildcardMatch("*.get**", "Server.getToken")); + REQUIRE(wildcardMatch("*.ge**t*", "Server.getToken")); + + REQUIRE(!wildcardMatch("*.set*", "Server.getToken")); + REQUIRE(!wildcardMatch(".*", "Server.getToken")); + REQUIRE(!wildcardMatch("*.get", "Server.getToken")); + REQUIRE(wildcardMatch("*erver*get*", "Server.getToken")); + REQUIRE(!wildcardMatch("*get*erver*", "Server.getToken")); +} + + +TEST_CASE("Auth") +{ + { + ServerSettings settings; + ServerSettings::User user("badaix:*:secret"); + REQUIRE(user.permissions.size() == 1); + REQUIRE(user.permissions[0] == "*"); + settings.users.push_back(user); + + AuthInfo auth(settings); + auto ec = auth.authenticateBasic(base64_encode("badaix:secret")); + REQUIRE(!ec); + REQUIRE(auth.hasAuthInfo()); + REQUIRE(auth.hasPermission("stream")); + } + + { + ServerSettings settings; + ServerSettings::User user("badaix::secret"); + REQUIRE(user.permissions.empty()); + settings.users.push_back(user); + + AuthInfo auth(settings); + auto ec = auth.authenticateBasic(base64_encode("badaix:secret")); + REQUIRE(!ec); + REQUIRE(auth.hasAuthInfo()); + REQUIRE(!auth.hasPermission("stream")); + } + + { + ServerSettings settings; + ServerSettings::User user("badaix:*:secret"); + settings.users.push_back(user); + + AuthInfo auth(settings); + auto ec = auth.authenticateBasic(base64_encode("badaix:wrong_password")); + REQUIRE(ec == AuthErrc::wrong_password); + REQUIRE(!auth.hasAuthInfo()); + REQUIRE(!auth.hasPermission("stream")); + + ec = auth.authenticateBasic(base64_encode("unknown_user:secret")); + REQUIRE(ec == AuthErrc::unknown_user); + REQUIRE(!auth.hasAuthInfo()); + REQUIRE(!auth.hasPermission("stream")); + } }