Use promise/future for sync messages

This commit is contained in:
badaix 2019-12-11 22:47:11 +01:00
parent 565da8c04a
commit d8a6e63691
2 changed files with 64 additions and 33 deletions

View file

@ -130,21 +130,23 @@ bool ClientConnection::send(const msg::BaseMessage* message)
} }
shared_ptr<msg::SerializedMessage> ClientConnection::sendRequest(const msg::BaseMessage* message, const chronos::msec& timeout) unique_ptr<msg::SerializedMessage> ClientConnection::sendRequest(const msg::BaseMessage* message, const chronos::msec& timeout)
{ {
shared_ptr<msg::SerializedMessage> response(nullptr); unique_ptr<msg::SerializedMessage> response(nullptr);
if (++reqId_ >= 10000) if (++reqId_ >= 10000)
reqId_ = 1; reqId_ = 1;
message->id = reqId_; message->id = reqId_;
// LOG(INFO) << "Req: " << message->id << "\n"; // LOG(INFO) << "Req: " << message->id << "\n";
shared_ptr<PendingRequest> pendingRequest(new PendingRequest(reqId_)); shared_ptr<PendingRequest> pendingRequest = make_shared<PendingRequest>(reqId_);
std::unique_lock<std::mutex> lock(pendingRequestsMutex_); { // scope for lock
pendingRequests_.insert(pendingRequest); std::unique_lock<std::mutex> lock(pendingRequestsMutex_);
send(message); pendingRequests_.insert(pendingRequest);
if (pendingRequest->cv.wait_for(lock, std::chrono::milliseconds(timeout)) == std::cv_status::no_timeout) send(message);
}
if ((response = pendingRequest->waitForResponse(std::chrono::milliseconds(timeout))) != nullptr)
{ {
response = pendingRequest->response;
sumTimeout_ = chronos::msec(0); sumTimeout_ = chronos::msec(0);
// LOG(INFO) << "Resp: " << pendingRequest->id << "\n"; // LOG(INFO) << "Resp: " << pendingRequest->id << "\n";
} }
@ -155,7 +157,11 @@ shared_ptr<msg::SerializedMessage> ClientConnection::sendRequest(const msg::Base
if (sumTimeout_ > chronos::sec(10)) if (sumTimeout_ > chronos::sec(10))
throw SnapException("sum timeout exceeded 10s"); throw SnapException("sum timeout exceeded 10s");
} }
pendingRequests_.erase(pendingRequest);
{ // scope for lock
std::unique_lock<std::mutex> lock(pendingRequestsMutex_);
pendingRequests_.erase(pendingRequest);
}
return response; return response;
} }
@ -174,27 +180,22 @@ void ClientConnection::getNextMessage()
// { // {
// std::lock_guard<std::mutex> socketLock(socketMutex_); // std::lock_guard<std::mutex> socketLock(socketMutex_);
socketRead(&buffer[0], baseMessage.size); socketRead(&buffer[0], baseMessage.size);
// }
tv t; tv t;
baseMessage.received = t; baseMessage.received = t;
// }
{ { // scope for lock
std::unique_lock<std::mutex> lock(pendingRequestsMutex_); std::unique_lock<std::mutex> lock(pendingRequestsMutex_);
// LOG(DEBUG) << "got lock - getNextMessage: " << baseMessage.type << ", size: " << baseMessage.size << ", id: " << baseMessage.id << ", for (auto req : pendingRequests_)
// refers: " << baseMessage.refersTo << "\n";
{ {
for (auto req : pendingRequests_) if (req->id() == baseMessage.refersTo)
{ {
if (req->id == baseMessage.refersTo) auto response = make_unique<msg::SerializedMessage>();
{ response->message = baseMessage;
req->response.reset(new msg::SerializedMessage()); response->buffer = (char*)malloc(baseMessage.size);
req->response->message = baseMessage; memcpy(response->buffer, &buffer[0], baseMessage.size);
req->response->buffer = (char*)malloc(baseMessage.size); req->setValue(std::move(response));
memcpy(req->response->buffer, &buffer[0], baseMessage.size); return;
lock.unlock();
req->cv.notify_one();
return;
}
} }
} }
} }

View file

@ -38,13 +38,43 @@ class ClientConnection;
/// Used to synchronize server requests (wait for server response) /// Used to synchronize server requests (wait for server response)
struct PendingRequest class PendingRequest
{ {
PendingRequest(uint16_t reqId) : id(reqId), response(nullptr){}; public:
PendingRequest(uint16_t reqId) : id_(reqId)
{
future_ = promise_.get_future();
};
uint16_t id; template <typename Rep, typename Period>
std::shared_ptr<msg::SerializedMessage> response; std::unique_ptr<msg::SerializedMessage> waitForResponse(const std::chrono::duration<Rep, Period>& timeout)
std::condition_variable cv; {
try
{
if (future_.wait_for(timeout) == std::future_status::ready)
return future_.get();
}
catch (...)
{
}
return nullptr;
}
void setValue(std::unique_ptr<msg::SerializedMessage> value)
{
promise_.set_value(std::move(value));
}
uint16_t id() const
{
return id_;
}
private:
uint16_t id_;
std::promise<std::unique_ptr<msg::SerializedMessage>> promise_;
std::future<std::unique_ptr<msg::SerializedMessage>> future_;
}; };
@ -79,16 +109,16 @@ public:
virtual bool send(const msg::BaseMessage* message); virtual bool send(const msg::BaseMessage* message);
/// Send request to the server and wait for answer /// Send request to the server and wait for answer
virtual std::shared_ptr<msg::SerializedMessage> sendRequest(const msg::BaseMessage* message, const chronos::msec& timeout = chronos::msec(1000)); virtual std::unique_ptr<msg::SerializedMessage> sendRequest(const msg::BaseMessage* message, const chronos::msec& timeout = chronos::msec(1000));
/// Send request to the server and wait for answer of type T /// Send request to the server and wait for answer of type T
template <typename T> template <typename T>
std::shared_ptr<T> sendReq(const msg::BaseMessage* message, const chronos::msec& timeout = chronos::msec(1000)) std::unique_ptr<T> sendReq(const msg::BaseMessage* message, const chronos::msec& timeout = chronos::msec(1000))
{ {
std::shared_ptr<msg::SerializedMessage> reply = sendRequest(message, timeout); std::unique_ptr<msg::SerializedMessage> reply = sendRequest(message, timeout);
if (!reply) if (!reply)
return nullptr; return nullptr;
std::shared_ptr<T> msg(new T); std::unique_ptr<T> msg(new T);
msg->deserialize(reply->message, reply->buffer); msg->deserialize(reply->message, reply->buffer);
return msg; return msg;
} }