Add AuthInfo class

This commit is contained in:
badaix 2024-06-12 23:00:37 +02:00
parent c784e2526f
commit 878fecdc35
9 changed files with 215 additions and 29 deletions

View file

@ -37,6 +37,7 @@
// standard headers
#include <chrono>
#include <cstdint>
#include <ctime>
#include <exception>
#include <memory>
#include <optional>
@ -196,32 +197,32 @@ Jwt::Jwt() : claims({})
{
}
std::optional<std::chrono::seconds> Jwt::getIat() const
std::optional<std::chrono::system_clock::time_point> Jwt::getIat() const
{
if (!claims.contains("iat"))
return std::nullopt;
return std::chrono::seconds(claims.at("iat").get<int64_t>());
return std::chrono::system_clock::from_time_t(claims.at("iat").get<int64_t>());
}
void Jwt::setIat(const std::optional<std::chrono::seconds>& iat)
void Jwt::setIat(const std::optional<std::chrono::system_clock::time_point>& iat)
{
if (iat.has_value())
claims["iat"] = iat->count();
claims["iat"] = std::chrono::system_clock::to_time_t(iat.value());
else if (claims.contains("iat"))
claims.erase("iat");
}
std::optional<std::chrono::seconds> Jwt::getExp() const
std::optional<std::chrono::system_clock::time_point> Jwt::getExp() const
{
if (!claims.contains("exp"))
return std::nullopt;
return std::chrono::seconds(claims.at("exp").get<int64_t>());
return std::chrono::system_clock::from_time_t(claims.at("exp").get<int64_t>());
}
void Jwt::setExp(const std::optional<std::chrono::seconds>& exp)
void Jwt::setExp(const std::optional<std::chrono::system_clock::time_point>& exp)
{
if (exp.has_value())
claims["exp"] = exp->count();
claims["exp"] = std::chrono::system_clock::to_time_t(exp.value());
else if (claims.contains("exp"))
claims.erase("exp");
}

View file

@ -96,15 +96,15 @@ public:
/// Get the iat "Issued at time" claim
/// @return the claim or nullopt, if not present
std::optional<std::chrono::seconds> getIat() const;
std::optional<std::chrono::system_clock::time_point> getIat() const;
/// Set the iat "Issued at time" claim, use nullopt to delete the iat
void setIat(const std::optional<std::chrono::seconds>& iat);
void setIat(const std::optional<std::chrono::system_clock::time_point>& iat);
/// Get the exp "Expiration time" claim
/// @return the claim or nullopt, if not present
std::optional<std::chrono::seconds> getExp() const;
std::optional<std::chrono::system_clock::time_point> getExp() const;
/// Set the exp "Expiration time" claim, use nullopt to delete the exp
void setExp(const std::optional<std::chrono::seconds>& exp);
void setExp(const std::optional<std::chrono::system_clock::time_point>& exp);
/// Get the sub "Subject" claim
/// @return the claim or nullopt, if not present

View file

@ -1,4 +1,5 @@
set(SERVER_SOURCES
authinfo.cpp
config.cpp
control_server.cpp
control_session_tcp.cpp

88
server/authinfo.cpp Normal file
View file

@ -0,0 +1,88 @@
/***
This file is part of snapcast
Copyright (C) 2014-2024 Johannes Pohl
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
***/
// prototype/interface header file
#include "authinfo.hpp"
// local headers
#include "common/aixlog.hpp"
#include "common/base64.h"
#include "common/jwt.hpp"
#include "common/utils/string_utils.hpp"
// 3rd party headers
// standard headers
#include <chrono>
#include <fstream>
#include <string>
#include <string_view>
using namespace std;
static constexpr auto LOG_TAG = "AuthInfo";
AuthInfo::AuthInfo(std::string authheader)
{
LOG(INFO, LOG_TAG) << "Authorization: " << authheader << "\n";
std::string token(std::move(authheader));
static constexpr auto bearer = "bearer"sv;
auto pos = utils::string::tolower_copy(token).find(bearer);
if (pos != string::npos)
{
token = token.erase(0, pos + bearer.length());
utils::string::trim(token);
std::ifstream ifs("certs/snapserver.crt");
std::string certificate((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
Jwt jwt;
jwt.parse(token, certificate);
if (jwt.getExp().has_value())
expires_ = jwt.getExp().value();
username_ = jwt.getSub().value_or("");
LOG(INFO, LOG_TAG) << "Authorization token: " << token << ", user: " << username_ << ", claims: " << jwt.claims.dump() << "\n";
}
static constexpr auto basic = "basic"sv;
pos = utils::string::tolower_copy(token).find(basic);
if (pos != string::npos)
{
token = token.erase(0, pos + basic.length());
utils::string::trim(token);
username_ = base64_decode(token);
std::string password;
username_ = utils::string::split_left(username_, ':', password);
LOG(INFO, LOG_TAG) << "Authorization basic: " << token << ", user: " << username_ << ", password: " << password << "\n";
}
}
bool AuthInfo::valid() const
{
if (expires_.has_value())
{
LOG(INFO, LOG_TAG) << "Expires in " << std::chrono::duration_cast<std::chrono::seconds>(expires_.value() - std::chrono::system_clock::now()).count()
<< " sec\n";
return expires_ > std::chrono::system_clock::now();
}
return true;
}
const std::string& AuthInfo::username() const
{
return username_;
}

45
server/authinfo.hpp Normal file
View file

@ -0,0 +1,45 @@
/***
This file is part of snapcast
Copyright (C) 2014-2024 Johannes Pohl
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
***/
#pragma once
// local headers
#include "common/jwt.hpp"
// 3rd party headers
// standard headers
#include <chrono>
#include <optional>
#include <string>
class AuthInfo
{
public:
AuthInfo(std::string authheader);
virtual ~AuthInfo() = default;
bool valid() const;
const std::string& username() const;
private:
std::string username_;
std::optional<std::chrono::system_clock::time_point> expires_;
};

View file

@ -19,12 +19,14 @@
#pragma once
// local headers
#include "authinfo.hpp"
// 3rd party headers
// standard headers
#include <functional>
#include <memory>
#include <optional>
#include <string>
@ -63,6 +65,8 @@ public:
/// Sends a message to the client (asynchronous)
virtual void sendAsync(const std::string& message) = 0;
std::optional<AuthInfo> authinfo;
protected:
ControlMessageReceiver* message_receiver_;
};

View file

@ -19,20 +19,22 @@
// prototype/interface header file
#include "control_session_http.hpp"
// standard headers
#include <iostream>
#include <memory>
// local headers
#include "authinfo.hpp"
#include "common/aixlog.hpp"
#include "common/utils/file_utils.hpp"
#include "control_session_ws.hpp"
#include "stream_session_ws.hpp"
// 3rd party headers
#include <boost/asio/ssl/stream.hpp>
#include <boost/beast/http/buffer_body.hpp>
#include <boost/beast/http/file_body.hpp>
// local headers
#include "common/aixlog.hpp"
#include "common/utils/file_utils.hpp"
#include "control_session_ws.hpp"
#include "stream_session_ws.hpp"
// standard headers
#include <iostream>
#include <memory>
using namespace std;
namespace websocket = beast::websocket; // from <boost/beast/websocket.hpp>
@ -358,14 +360,16 @@ void ControlSessionHttp::handle_request(http::request<Body, http::basic_fields<A
void ControlSessionHttp::on_read(beast::error_code ec, std::size_t bytes_transferred)
{
// This means they closed the connection
if (ec == http::error::end_of_stream)
if ((ec == http::error::end_of_stream) || (ec == boost::asio::error::connection_reset))
{
boost::system::error_code res;
if (is_ssl_)
res = ssl_socket_->shutdown(res);
else
res = tcp_socket_->shutdown(tcp_socket::shutdown_send, ec);
if (res.failed())
ssl_socket_->async_shutdown(
[](const boost::system::error_code& error)
{
if (error.failed())
LOG(ERROR, LOG_TAG) << "Failed to shudown ssl socket: " << error << "\n";
});
else if (boost::system::error_code res = tcp_socket_->shutdown(tcp_socket::shutdown_send, ec); res.failed())
LOG(ERROR, LOG_TAG) << "Failed to shudown socket: " << res << "\n";
return;
}
@ -373,7 +377,7 @@ void ControlSessionHttp::on_read(beast::error_code ec, std::size_t bytes_transfe
// Handle the error, if any
if (ec)
{
LOG(ERROR, LOG_TAG) << "ControlSessionHttp::on_read error: " << ec.message() << "\n";
LOG(ERROR, LOG_TAG) << "ControlSessionHttp::on_read error: " << ec.message() << ", code: " << ec.value() << "\n";
return;
}
@ -444,6 +448,13 @@ void ControlSessionHttp::on_read(beast::error_code ec, std::size_t bytes_transfe
return;
}
std::string_view authheader = req_[beast::http::field::authorization];
if (!authheader.empty())
{
authinfo = AuthInfo(std::string(authheader));
}
// Send the response
handle_request(std::move(req_),
[this](auto&& response)

View file

@ -21,6 +21,7 @@
// local headers
#include "common/aixlog.hpp"
#include "common/jwt.hpp"
#include "common/message/client_info.hpp"
#include "common/message/hello.hpp"
#include "common/message/server_settings.hpp"
@ -30,6 +31,7 @@
// 3rd party headers
// standard headers
#include <chrono>
#include <iostream>
@ -402,6 +404,34 @@ void Server::processRequest(const jsonrpcpp::request_ptr request, const OnRespon
/// Notify others
notification = std::make_shared<jsonrpcpp::Notification>("Server.OnUpdate", jsonrpcpp::Parameter("server", server));
}
else if (request->method() == "Server.Authenticate")
{
// clang-format off
// Request: {"id":8,"jsonrpc":"2.0","method":"Server.Authenticate","params":{"user":"badaix","password":"secret"}}
// Response: {"id":8,"jsonrpc":"2.0","result":{"token":"<token>"}}
// clang-format on
if (request->params().has("token"))
{
auto token = request->params().get<std::string>("token");
LOG(INFO, LOG_TAG) << "Server.Authenticate, token: " << token << "\n";
result["token"] = token;
}
else if (request->params().has("user"))
{
auto user = request->params().get<std::string>("user");
Jwt jwt;
auto now = std::chrono::system_clock::now();
jwt.setIat(now);
jwt.setExp(now + 10h);
jwt.setSub(user);
std::ifstream ifs(settings_.ssl.private_key.c_str());
std::string private_key((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
std::optional<std::string> token = jwt.getToken(private_key);
result["token"] = token.value();
LOG(INFO, LOG_TAG) << "Server.Authenticate, user: " << user << ", password: " << request->params().get<std::string>("password")
<< ", jwt claims: " << jwt.claims.dump() << ", token: '" << token.value_or("") << "'\n";
}
}
else
throw jsonrpcpp::MethodNotFoundException(request->id());
}
@ -669,6 +699,11 @@ void Server::onMessageReceived(std::shared_ptr<ControlSession> controlSession, c
processRequest(request,
[this, controlSession, response_handler](jsonrpcpp::entity_ptr response, jsonrpcpp::notification_ptr notification)
{
if (controlSession->authinfo.has_value())
{
LOG(INFO, LOG_TAG) << "Request auth info - username: " << controlSession->authinfo->username()
<< ", valid: " << controlSession->authinfo->valid() << "\n";
}
saveConfig();
////cout << "Request: " << request->to_json().dump() << "\n";
if (notification)

View file

@ -30,6 +30,7 @@
#include <catch2/catch_test_macros.hpp>
// standard headers
#include <chrono>
#include <regex>
@ -154,7 +155,7 @@ TEST_CASE("JWT")
"-----END CERTIFICATE-----\n";
Jwt jwt;
jwt.setIat(std::chrono::seconds(1516239022));
jwt.setIat(std::chrono::system_clock::from_time_t(1516239022));
jwt.setSub("Badaix");
std::optional<std::string> token = jwt.getToken(key);
REQUIRE(token.has_value());
@ -168,7 +169,7 @@ TEST_CASE("JWT")
REQUIRE(jwt.getSub().has_value());
REQUIRE(jwt.getSub().value() == "Badaix");
REQUIRE(jwt.getIat().has_value());
REQUIRE(jwt.getIat().value() == std::chrono::seconds(1516239022));
REQUIRE(jwt.getIat().value() == std::chrono::system_clock::from_time_t(1516239022));
REQUIRE(!jwt.getExp().has_value());
}
}