Browse Source

Merge pull request #1379 from bobk-rey/fix-mutex-deadlock

Fix TCP mutex deadlock
Paul-Louis Ageneau 3 months ago
parent
commit
5d4afed2d1
3 changed files with 89 additions and 82 deletions
  1. 50 44
      src/impl/pollservice.cpp
  2. 38 38
      src/impl/tcptransport.cpp
  3. 1 0
      src/impl/tcptransport.hpp

+ 50 - 44
src/impl/pollservice.cpp

@@ -105,57 +105,63 @@ void PollService::prepare(std::vector<struct pollfd> &pfds, optional<clock::time
 }
 
 void PollService::process(std::vector<struct pollfd> &pfds) { // Dispatches poll results in pfds to the callbacks registered in mSocks
-	std::unique_lock lock(mMutex);
-	auto it = pfds.begin();
-	if (it != pfds.end()) {
-		mInterrupter->process(*it++);
-	}
-	while (it != pfds.end()) {
-		socket_t sock = it->fd;
-		auto jt = mSocks->find(sock);
-		if (jt != mSocks->end()) {
-			try {
-				auto &entry = jt->second;
-				const auto &params = entry.params;
-
-				if (it->revents & POLLNVAL || it->revents & POLLERR ||
-				    (it->revents & POLLHUP &&
-				     !(it->events & POLLIN))) { // MacOS sets POLLHUP on connection failure
-					PLOG_VERBOSE << "Poll error event";
-					auto callback = std::move(params.callback);
-					mSocks->erase(sock);
-					callback(Event::Error);
-
-				} else if (it->revents & POLLIN || it->revents & POLLOUT || it->revents & POLLHUP) {
-					entry.until = params.timeout
-					                  ? std::make_optional(clock::now() + *params.timeout)
-					                  : nullopt;
-
-					auto callback = params.callback;
-					if (it->revents & POLLIN ||
-					    it->revents & POLLHUP) { // Windows does not set POLLIN on close
-						PLOG_VERBOSE << "Poll in event";
-						callback(Event::In);
-					}
-					if (it->revents & POLLOUT) {
-						PLOG_VERBOSE << "Poll out event";
-						callback(Event::Out);
+	using Callback = decltype(std::declval<Params>().callback); // Type of the callback member stored in Params
+	std::vector<std::pair<Callback, Event>> todo; // Callbacks and their events, collected under the lock and invoked after it is released
+	{
+		std::unique_lock lock(mMutex); // Held only for the scan in this scope; released at the closing brace, before any callback runs
+		auto it = pfds.begin();
+		if (it != pfds.end()) {
+			mInterrupter->process(*it++); // The first pollfd entry is handed to the interrupter, not looked up in mSocks
+		}
+		while (it != pfds.end()) {
+			socket_t sock = it->fd;
+			auto jt = mSocks->find(sock);
+			if (jt != mSocks->end()) {
+				try {
+					auto &entry = jt->second;
+					const auto &params = entry.params;
+
+					if (it->revents & POLLNVAL || it->revents & POLLERR ||
+					    (it->revents & POLLHUP &&
+					     !(it->events & POLLIN))) { // MacOS sets POLLHUP on connection failure
+						PLOG_VERBOSE << "Poll error event";
+						todo.emplace_back(std::move(params.callback), Event::Error); // NOTE(review): params is a const reference, so this std::move likely copies - confirm
+						mSocks->erase(sock); // Error is terminal for this registration: the socket entry is dropped
+					} else if (it->revents & POLLIN || it->revents & POLLOUT || it->revents & POLLHUP) {
+						entry.until = params.timeout
+						                  ? std::make_optional(clock::now() + *params.timeout)
+						                  : nullopt;
+
+						const auto &callback = params.callback; // can't move, we may need it below
+						if (it->revents & POLLIN ||
+						    it->revents & POLLHUP) { // Windows does not set POLLIN on close
+							PLOG_VERBOSE << "Poll in event";
+							todo.emplace_back(callback, Event::In);
+						}
+						if (it->revents & POLLOUT) {
+							PLOG_VERBOSE << "Poll out event";
+							todo.emplace_back(callback, Event::Out);
+						}
+
+					} else if (entry.until && clock::now() >= *entry.until) { // No I/O event reported and the deadline has passed
+						PLOG_VERBOSE << "Poll timeout event";
+						todo.emplace_back(std::move(params.callback), Event::Timeout); // NOTE(review): same const-reference move as the error branch - likely a copy
+						mSocks->erase(sock); // Timeout also removes the socket entry
 					}
 
-				} else if (entry.until && clock::now() >= *entry.until) {
-					PLOG_VERBOSE << "Poll timeout event";
-					auto callback = std::move(params.callback);
+				} catch (const std::exception &e) { // Covers only the scan above; callbacks are no longer invoked inside this try
+					PLOG_WARNING << e.what();
 					mSocks->erase(sock);
-					callback(Event::Timeout);
 				}
-
-			} catch (const std::exception &e) {
-				PLOG_WARNING << e.what();
-				mSocks->erase(sock);
 			}
+
+			++it;
 		}
+	}
 
-		++it;
+	// Now perform the callbacks
+	for (auto &[callback, event] : todo) { // Runs after the lock scope above has ended, in the order events were queued
+		callback(event); // NOTE(review): outside the try/catch, unlike the removed code, so a throwing callback propagates out of process()
 	}
 }
 

+ 38 - 38
src/impl/tcptransport.cpp

@@ -228,44 +228,9 @@ void TcpTransport::attempt() {
 		return;
 	}
 
-	// Poll out event callback
-	auto callback = [this](PollService::Event event) {
-		try {
-			if (event == PollService::Event::Error)
-				throw std::runtime_error("TCP connection failed");
-
-			if (event == PollService::Event::Timeout)
-				throw std::runtime_error("TCP connection timed out");
-
-			if (event != PollService::Event::Out)
-				return;
-
-			int err = 0;
-			socklen_t errlen = sizeof(err);
-			if (::getsockopt(mSock, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&err),
-			                 &errlen) != 0)
-				throw std::runtime_error("Failed to get socket error code");
-
-			if (err != 0) {
-				std::ostringstream msg;
-				msg << "TCP connection failed, errno=" << err;
-				throw std::runtime_error(msg.str());
-			}
-
-			// Success
-			PLOG_INFO << "TCP connected";
-			changeState(State::Connected);
-			setPoll(PollService::Direction::In);
-
-		} catch (const std::exception &e) {
-			PLOG_DEBUG << e.what();
-			PollService::Instance().remove(mSock);
-			ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this));
-		}
-	};
-
 	const auto timeout = 10s;
-	PollService::Instance().add(mSock, {PollService::Direction::Out, timeout, std::move(callback)});
+	PollService::Instance().add(mSock, {PollService::Direction::Out, timeout, 
+	    weak_bind(&TcpTransport::processConnect, this, _1)});
 }
 
 void TcpTransport::createSocket(const struct sockaddr *addr, socklen_t addrlen) {
@@ -326,7 +291,7 @@ void TcpTransport::configureSocket() {
 void TcpTransport::setPoll(PollService::Direction direction) { // Registers mSock with the poll service for the given direction
 	PollService::Instance().add(
 	    mSock, {direction, direction == PollService::Direction::In ? mReadTimeout : nullopt, // mReadTimeout is applied only when polling for In; otherwise no timeout
-	            std::bind(&TcpTransport::process, this, _1)});
+	            weak_bind(&TcpTransport::process, this, _1)}); // NOTE(review): weak_bind is defined elsewhere; presumably it avoids calling into a destroyed transport - confirm
 }
 
 void TcpTransport::close() {
@@ -471,6 +436,41 @@ void TcpTransport::process(PollService::Event event) {
 	recv(nullptr);
 }
 
+void TcpTransport::processConnect(PollService::Event event) { // Handles a poll event for an in-progress connect: finishes the connection or schedules another attempt()
+	try {
+		if (event == PollService::Event::Error)
+			throw std::runtime_error("TCP connection failed");
+
+		if (event == PollService::Event::Timeout)
+			throw std::runtime_error("TCP connection timed out");
+
+		if (event != PollService::Event::Out) // Any event other than Out is ignored here
+			return;
+
+		int err = 0;
+		socklen_t errlen = sizeof(err);
+		if (::getsockopt(mSock, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&err), // Read the socket's pending error code into err
+						 &errlen) != 0)
+			throw std::runtime_error("Failed to get socket error code");
+
+		if (err != 0) { // A non-zero SO_ERROR value is reported as a failed connection
+			std::ostringstream msg;
+			msg << "TCP connection failed, errno=" << err;
+			throw std::runtime_error(msg.str());
+		}
+
+		// Success
+		PLOG_INFO << "TCP connected";
+		changeState(State::Connected);
+		setPoll(PollService::Direction::In); // Re-register the socket for In events now that it is connected
+
+	} catch (const std::exception &e) { // Every failure above lands here: log, unregister the socket, queue another attempt()
+		PLOG_DEBUG << e.what();
+		PollService::Instance().remove(mSock);
+		ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this)); // attempt() runs later via the thread pool rather than inline in this callback
+	}
+}
+
 } // namespace rtc::impl
 
 #endif

+ 1 - 0
src/impl/tcptransport.hpp

@@ -59,6 +59,7 @@ private:
 	void triggerBufferedAmount(size_t amount);
 
 	void process(PollService::Event event);
+	void processConnect(PollService::Event event);
 
 	const bool mIsActive;
 	string mHostname, mService;