Browse Source

Proper fix for thread pool deadlock at exit

Paul-Louis Ageneau 4 years ago
parent
commit
efe12f0b73
2 changed files with 26 additions and 22 deletions
  1. 22 19
      src/threadpool.cpp
  2. 4 3
      src/threadpool.hpp

+ 22 - 19
src/threadpool.cpp

@@ -21,10 +21,10 @@
 #include <cstdlib>
 
 namespace {
-	void joinThreadPoolInstance() {
-		rtc::ThreadPool::Instance().join();
-	}
-}
+
+void joinThreadPoolInstance() { rtc::ThreadPool::Instance().join(); }
+
+} // namespace
 
 namespace rtc {
 
@@ -33,9 +33,7 @@ ThreadPool &ThreadPool::Instance() {
 	return *instance;
 }
 
-ThreadPool::ThreadPool() {
-	std::atexit(joinThreadPoolInstance);
-}
+ThreadPool::ThreadPool() { std::atexit(joinThreadPoolInstance); }
 
 ThreadPool::~ThreadPool() {}
 
@@ -45,17 +43,21 @@ int ThreadPool::count() const {
 }
 
 void ThreadPool::spawn(int count) {
-	std::unique_lock lock(mWorkersMutex);
+	std::scoped_lock lock(mMutex, mWorkersMutex);
 	mJoining = false;
 	while (count-- > 0)
 		mWorkers.emplace_back(std::bind(&ThreadPool::run, this));
 }
 
 void ThreadPool::join() {
-	std::unique_lock lock(mWorkersMutex);
-	mJoining = true;
-	mCondition.notify_all();
+	{
+		std::unique_lock lock(mMutex);
+		mWaitingCondition.wait(lock, [&]() { return mWaitingWorkers == int(mWorkers.size()); });
+		mJoining = true;
+		mTasksCondition.notify_all();
+	}
 
+	std::unique_lock lock(mWorkersMutex);
 	for (auto &w : mWorkers)
 		w.join();
 
@@ -77,7 +79,7 @@ bool ThreadPool::runOne() {
 
 std::function<void()> ThreadPool::dequeue() {
 	std::unique_lock lock(mMutex);
-	while (true) {
+	while (!mJoining) {
 		if (!mTasks.empty()) {
 			if (mTasks.top().time <= clock::now()) {
 				auto func = std::move(mTasks.top().func);
@@ -85,16 +87,17 @@ std::function<void()> ThreadPool::dequeue() {
 				return func;
 			}
 
-			if (mJoining)
-				break;
+			++mWaitingWorkers;
+			mWaitingCondition.notify_all();
+			mTasksCondition.wait_until(lock, mTasks.top().time);
 
-			mCondition.wait_until(lock, mTasks.top().time);
 		} else {
-			if (mJoining)
-				break;
-
-			mCondition.wait(lock);
+			++mWaitingWorkers;
+			mWaitingCondition.notify_all();
+			mTasksCondition.wait(lock);
 		}
+
+		--mWaitingWorkers;
 	}
 	return nullptr;
 }

+ 4 - 3
src/threadpool.hpp

@@ -72,7 +72,8 @@ protected:
 	std::function<void()> dequeue(); // returns null function if joining
 
 	std::vector<std::thread> mWorkers;
-	std::atomic<bool> mJoining = false;
+	int mWaitingWorkers = 0;
+	bool mJoining = false;
 
 	struct Task {
 		clock::time_point time;
@@ -82,8 +83,8 @@ protected:
 	};
 	std::priority_queue<Task, std::deque<Task>, std::greater<Task>> mTasks;
 
+	std::condition_variable mTasksCondition, mWaitingCondition;
 	mutable std::mutex mMutex, mWorkersMutex;
-	std::condition_variable mCondition;
 };
 
 template <class F, class... Args>
@@ -114,7 +115,7 @@ auto ThreadPool::schedule(clock::time_point time, F &&f, Args &&...args)
 	std::future<R> result = task->get_future();
 
 	mTasks.push({time, [task = std::move(task), token = Init::Token()]() { return (*task)(); }});
-	mCondition.notify_one();
+	mTasksCondition.notify_one();
 	return result;
 }