Browse Source

Merge pull request #356 from paullouisageneau/fix-threadpool-workers-access

Fix unsynchronized access in thread pool
Paul-Louis Ageneau 4 years ago
parent
commit
faf3158609
2 changed files with 14 additions and 13 deletions
  1. 13 12
      src/threadpool.cpp
  2. 1 1
      src/threadpool.hpp

+ 13 - 12
src/threadpool.cpp

@@ -51,7 +51,7 @@ void ThreadPool::spawn(int count) {
 void ThreadPool::join() {
 void ThreadPool::join() {
 	{
 	{
 		std::unique_lock lock(mMutex);
 		std::unique_lock lock(mMutex);
-		mWaitingCondition.wait(lock, [&]() { return mWaitingWorkers == int(mWorkers.size()); });
+		mWaitingCondition.wait(lock, [&]() { return mBusyWorkers == 0; });
 		mJoining = true;
 		mJoining = true;
 		mTasksCondition.notify_all();
 		mTasksCondition.notify_all();
 	}
 	}
@@ -66,6 +66,8 @@ void ThreadPool::join() {
 }
 }
 
 
 void ThreadPool::run() {
 void ThreadPool::run() {
+	++mBusyWorkers;
+	scope_guard([&]() { --mBusyWorkers; });
 	while (runOne()) {
 	while (runOne()) {
 	}
 	}
 }
 }
@@ -81,24 +83,23 @@ bool ThreadPool::runOne() {
 std::function<void()> ThreadPool::dequeue() {
 std::function<void()> ThreadPool::dequeue() {
 	std::unique_lock lock(mMutex);
 	std::unique_lock lock(mMutex);
 	while (!mJoining) {
 	while (!mJoining) {
+		std::optional<clock::time_point> time;
 		if (!mTasks.empty()) {
 		if (!mTasks.empty()) {
-			if (mTasks.top().time <= clock::now()) {
+			time = mTasks.top().time;
+			if (*time <= clock::now()) {
 				auto func = std::move(mTasks.top().func);
 				auto func = std::move(mTasks.top().func);
 				mTasks.pop();
 				mTasks.pop();
 				return func;
 				return func;
 			}
 			}
-
-			++mWaitingWorkers;
-			mWaitingCondition.notify_all();
-			mTasksCondition.wait_until(lock, mTasks.top().time);
-
-		} else {
-			++mWaitingWorkers;
-			mWaitingCondition.notify_all();
-			mTasksCondition.wait(lock);
 		}
 		}
 
 
-		--mWaitingWorkers;
+		--mBusyWorkers;
+		scope_guard([&]() { ++mBusyWorkers; });
+		mWaitingCondition.notify_all();
+		if(time)
+			mTasksCondition.wait_until(lock, *time);
+		else
+			mTasksCondition.wait(lock);
 	}
 	}
 	return nullptr;
 	return nullptr;
 }
 }

+ 1 - 1
src/threadpool.hpp

@@ -72,7 +72,7 @@ protected:
 	std::function<void()> dequeue(); // returns null function if joining
 	std::function<void()> dequeue(); // returns null function if joining
 
 
 	std::vector<std::thread> mWorkers;
 	std::vector<std::thread> mWorkers;
-	int mWaitingWorkers = 0;
+	int mBusyWorkers = 0;
 	std::atomic<bool> mJoining = false;
 	std::atomic<bool> mJoining = false;
 
 
 	struct Task {
 	struct Task {