Fix data race with pending requests

This commit is contained in:
badaix 2020-05-02 11:58:31 +02:00
parent bee9b2390c
commit 38cddf6424
3 changed files with 50 additions and 31 deletions

View file

@ -125,7 +125,7 @@ void ClientConnection::disconnect()
socket_.close(ec);
if (ec)
LOG(ERROR, LOG_TAG) << "Error in socket close: " << ec.message() << endl;
pendingRequests_.clear();
boost::asio::post(strand_, [this]() { pendingRequests_.clear(); });
LOG(DEBUG, LOG_TAG) << "Disconnected\n";
}
@ -171,13 +171,17 @@ void ClientConnection::send(const msg::message_ptr& message, const ResultHandler
void ClientConnection::sendRequest(const msg::message_ptr& message, const chronos::usec& timeout, const MessageHandler<msg::BaseMessage>& handler)
{
// LOG(INFO, LOG_TAG) << "Req: " << message->id << "\n";
boost::asio::post(strand_, [this, message, timeout, handler]() {
pendingRequests_.erase(
std::remove_if(pendingRequests_.begin(), pendingRequests_.end(), [](std::weak_ptr<PendingRequest> request) { return request.expired(); }),
pendingRequests_.end());
unique_ptr<msg::BaseMessage> response(nullptr);
if (++reqId_ >= 10000)
reqId_ = 1;
message->id = reqId_;
pendingRequests_.insert(make_unique<PendingRequest>(io_context_, strand_, reqId_, timeout, handler));
auto request = make_shared<PendingRequest>(io_context_, strand_, reqId_, handler);
pendingRequests_.push_back(request);
request->startTimer(timeout);
send(message, [handler](const boost::system::error_code& ec) {
if (ec)
handler(ec, nullptr);
@ -232,17 +236,19 @@ void ClientConnection::getNextMessage(const MessageHandler<msg::BaseMessage>& ha
return;
}
auto iter = std::find_if(
pendingRequests_.begin(), pendingRequests_.end(),
[this](const std::unique_ptr<PendingRequest>& request) { return request->id() == base_message_.refersTo; });
auto response = msg::factory::createMessage(base_message_, buffer_.data());
if (iter != pendingRequests_.end())
for (const auto& request : pendingRequests_)
{
(*iter)->setValue(std::move(response));
pendingRequests_.erase(iter);
if (auto req = request.lock())
{
if (req->id() == base_message_.refersTo)
{
req->setValue(std::move(response));
getNextMessage(handler);
return;
}
}
}
if (handler)
handler(ec, std::move(response));

View file

@ -44,26 +44,12 @@ template <typename Message>
using MessageHandler = std::function<void(const boost::system::error_code&, std::unique_ptr<Message>)>;
/// Used to synchronize server requests (wait for server response)
class PendingRequest
class PendingRequest : public std::enable_shared_from_this<PendingRequest>
{
public:
PendingRequest(boost::asio::io_context& io_context, boost::asio::io_context::strand& strand, uint16_t reqId, const chronos::usec& timeout,
PendingRequest(boost::asio::io_context& io_context, boost::asio::io_context::strand& strand, uint16_t reqId,
const MessageHandler<msg::BaseMessage>& handler)
: id_(reqId), timer_(io_context), strand_(strand), handler_(handler)
{
timer_.expires_after(timeout);
timer_.async_wait(boost::asio::bind_executor(strand_, [this](boost::system::error_code ec) {
if (!handler_)
return;
if (!ec)
{
handler_(boost::asio::error::timed_out, nullptr);
handler_ = nullptr;
}
else if (ec != boost::asio::error::operation_aborted)
handler_(ec, nullptr);
}));
};
: id_(reqId), timer_(io_context), strand_(strand), handler_(handler){};
virtual ~PendingRequest()
{
@ -83,6 +69,28 @@ public:
return id_;
}
void startTimer(const chronos::usec& timeout)
{
timer_.expires_after(timeout);
timer_.async_wait(boost::asio::bind_executor(strand_, [ this, self = shared_from_this() ](boost::system::error_code ec) {
if (!handler_)
return;
if (!ec)
{
handler_(boost::asio::error::timed_out, nullptr);
handler_ = nullptr;
}
else if (ec != boost::asio::error::operation_aborted)
handler_(ec, nullptr);
}));
}
bool operator<(const PendingRequest& other) const
{
return (id_ < other.id());
}
private:
uint16_t id_;
boost::asio::steady_timer timer_;
@ -153,7 +161,7 @@ protected:
boost::asio::io_context& io_context_;
tcp::resolver resolver_;
tcp::socket socket_;
std::set<std::unique_ptr<PendingRequest>> pendingRequests_;
std::vector<std::weak_ptr<PendingRequest>> pendingRequests_;
uint16_t reqId_;
ClientSettings::Server server_;

View file

@ -316,9 +316,14 @@ int main(int argc, char** argv)
auto meta(metaStderr ? std::make_unique<MetaStderrAdapter>() : std::make_unique<MetadataAdapter>());
auto controller = make_shared<Controller>(io_context, settings, std::move(meta));
controller->start();
// std::thread t([&] { io_context.run(); });
int num_threads = 0;
std::vector<std::thread> threads;
for (int n = 0; n < num_threads; ++n)
threads.emplace_back([&] { io_context.run(); });
io_context.run();
// t.join();
for (auto& t : threads)
t.join();
}
catch (const std::exception& e)
{