From c8fc3ae4b8b45560911bc23c9a7a69f8490a3cd1 Mon Sep 17 00:00:00 2001 From: BadAix Date: Sun, 13 Nov 2016 15:53:48 +0100 Subject: [PATCH] protect socket with mutex --- client/clientConnection.cpp | 15 +++++++++------ client/clientConnection.h | 3 ++- server/controlSession.cpp | 26 +++++++++++++++++++------- server/controlSession.h | 3 ++- server/streamSession.cpp | 13 +++++++++---- server/streamSession.h | 3 ++- 6 files changed, 43 insertions(+), 20 deletions(-) diff --git a/client/clientConnection.cpp b/client/clientConnection.cpp index c22dbf44..62b81020 100644 --- a/client/clientConnection.cpp +++ b/client/clientConnection.cpp @@ -42,7 +42,6 @@ ClientConnection::~ClientConnection() void ClientConnection::socketRead(void* _to, size_t _bytes) { -// std::unique_lock mlock(mutex_); size_t toRead = _bytes; size_t len = 0; do @@ -123,6 +122,7 @@ bool ClientConnection::send(const msg::BaseMessage* message) const { // std::unique_lock mlock(mutex_); //logD << "send: " << message->type << ", size: " << message->getSize() << "\n"; + std::lock_guard socketLock(socketMutex_); if (!connected()) return false; //logD << "send: " << message->type << ", size: " << message->getSize() << "\n"; @@ -145,10 +145,10 @@ shared_ptr ClientConnection::sendRequest(const msg::Base // logO << "Req: " << message->id << "\n"; shared_ptr pendingRequest(new PendingRequest(reqId_)); - std::unique_lock mlock(mutex_); + std::unique_lock lock(pendingRequestsMutex_); pendingRequests_.insert(pendingRequest); send(message); - if (pendingRequest->cv.wait_for(mlock, std::chrono::milliseconds(timeout)) == std::cv_status::no_timeout) + if (pendingRequest->cv.wait_for(lock, std::chrono::milliseconds(timeout)) == std::cv_status::no_timeout) { response = pendingRequest->response; sumTimeout_ = chronos::msec(0); @@ -176,12 +176,15 @@ void ClientConnection::getNextMessage() // logD << "getNextMessage: " << baseMessage.type << ", size: " << baseMessage.size << ", id: " << baseMessage.id << ", refers: " << baseMessage.refersTo << "\n"; if (baseMessage.size > buffer.size()) buffer.resize(baseMessage.size); - socketRead(&buffer[0], baseMessage.size); + { + std::lock_guard socketLock(socketMutex_); + socketRead(&buffer[0], baseMessage.size); + } tv t; baseMessage.received = t; { - std::unique_lock mlock(mutex_); + std::unique_lock lock(pendingRequestsMutex_); // logD << "got lock - getNextMessage: " << baseMessage.type << ", size: " << baseMessage.size << ", id: " << baseMessage.id << ", refers: " << baseMessage.refersTo << "\n"; { for (auto req: pendingRequests_) @@ -192,7 +195,7 @@ void ClientConnection::getNextMessage() req->response->message = baseMessage; req->response->buffer = (char*)malloc(baseMessage.size); memcpy(req->response->buffer, &buffer[0], baseMessage.size); - mlock.unlock(); + lock.unlock(); req->cv.notify_one(); return; } diff --git a/client/clientConnection.h b/client/clientConnection.h index 576d9469..d176ed48 100644 --- a/client/clientConnection.h +++ b/client/clientConnection.h @@ -108,11 +108,12 @@ protected: void getNextMessage(); asio::io_service io_service_; + mutable std::mutex socketMutex_; std::shared_ptr socket_; std::atomic active_; std::atomic connected_; MessageReceiver* messageReceiver_; - mutable std::mutex mutex_; + mutable std::mutex pendingRequestsMutex_; std::set> pendingRequests_; uint16_t reqId_; std::string host_; diff --git a/server/controlSession.cpp b/server/controlSession.cpp index 43896633..54b0e9df 100644 --- a/server/controlSession.cpp +++ b/server/controlSession.cpp @@ -41,7 +41,10 @@ ControlSession::~ControlSession() void ControlSession::start() { - active_ = true; + { + std::lock_guard activeLock(activeMutex_); + active_ = true; + } readerThread_ = new thread(&ControlSession::reader, this); writerThread_ = new thread(&ControlSession::writer, this); } @@ -49,13 +52,20 @@ void ControlSession::start() void ControlSession::stop() { - std::unique_lock mlock(mutex_); - active_ = false; + { + std::lock_guard activeLock(activeMutex_); + if (!active_) + return; + + active_ = false; + } + try { std::error_code ec; if (socket_) { + std::lock_guard socketLock(socketMutex_); socket_->shutdown(asio::ip::tcp::socket::shutdown_both, ec); if (ec) logE << "Error in socket shutdown: " << ec.message() << "\n"; socket_->close(ec); @@ -95,9 +105,12 @@ void ControlSession::sendAsync(const std::string& message) bool ControlSession::send(const std::string& message) const { // logO << "send: " << message->type << ", size: " << message->size << ", id: " << message->id << ", refers: " << message->refersTo << "\n"; - std::unique_lock mlock(mutex_); - if (!socket_ || !active_) - return false; + std::lock_guard socketLock(socketMutex_); + { + std::lock_guard activeLock(activeMutex_); + if (!socket_ || !active_) + return false; + } asio::streambuf streambuf; std::ostream request_stream(&streambuf); request_stream << message << "\r\n"; @@ -109,7 +122,6 @@ bool ControlSession::send(const std::string& message) const void ControlSession::reader() { - active_ = true; try { std::stringstream message; diff --git a/server/controlSession.h b/server/controlSession.h index 197bcd85..25541b28 100644 --- a/server/controlSession.h +++ b/server/controlSession.h @@ -76,7 +76,8 @@ protected: void writer(); std::atomic active_; - mutable std::mutex mutex_; + mutable std::mutex activeMutex_; + mutable std::mutex socketMutex_; std::thread* readerThread_; std::thread* writerThread_; std::shared_ptr socket_; diff --git a/server/streamSession.cpp b/server/streamSession.cpp index 8ad4d085..c9c84130 100644 --- a/server/streamSession.cpp +++ b/server/streamSession.cpp @@ -55,7 +55,7 @@ const PcmStreamPtr StreamSession::pcmStream() const void StreamSession::start() { { - std::lock_guard mlock(mutex_); + std::lock_guard activeLock(activeMutex_); active_ = true; } readerThread_.reset(new thread(&StreamSession::reader, this)); @@ -66,7 +66,7 @@ void StreamSession::start() void StreamSession::stop() { { - std::lock_guard mlock(mutex_); + std::lock_guard activeLock(activeMutex_); if (!active_) return; @@ -78,6 +78,7 @@ void StreamSession::stop() std::error_code ec; if (socket_) { + std::lock_guard socketLock(socketMutex_); socket_->shutdown(asio::ip::tcp::socket::shutdown_both, ec); if (ec) logE << "Error in socket shutdown: " << ec.message() << "\n"; socket_->close(ec); @@ -144,8 +145,9 @@ bool StreamSession::send(const msg::BaseMessage* message) const { //TODO on exception: set active = false // logO << "send: " << message->type << ", size: " << message->getSize() << ", id: " << message->id << ", refers: " << message->refersTo << "\n"; + std::lock_guard socketLock(socketMutex_); { - std::lock_guard mlock(mutex_); + std::lock_guard activeLock(activeMutex_); if (!socket_ || !active_) return false; } @@ -176,7 +178,10 @@ void StreamSession::getNextMessage() // logO << "getNextMessage: " << baseMessage.type << ", size: " << baseMessage.size << ", id: " << baseMessage.id << ", refers: " << baseMessage.refersTo << "\n"; if (baseMessage.size > buffer.size()) buffer.resize(baseMessage.size); - socketRead(&buffer[0], baseMessage.size); + { + std::lock_guard socketLock(socketMutex_); + socketRead(&buffer[0], baseMessage.size); + } tv t; baseMessage.received = t; diff --git a/server/streamSession.h b/server/streamSession.h index 6d1146ad..fb31bddf 100644 --- a/server/streamSession.h +++ b/server/streamSession.h @@ -89,11 +89,12 @@ protected: void reader(); void writer(); - mutable std::mutex mutex_; + mutable std::mutex activeMutex_; std::atomic active_; std::unique_ptr readerThread_; std::unique_ptr writerThread_; + mutable std::mutex socketMutex_; std::shared_ptr socket_; MessageReceiver* messageReceiver_; Queue> messages_;