Browse Source

Merge pull request #339 from paullouisageneau/fix-scheduling

Proper fix for possible deadlock at exit
Paul-Louis Ageneau 4 years ago
parent
commit
524c56dee9
4 changed files with 38 additions and 67 deletions
  1. 6 27
      src/processor.cpp
  2. 5 5
      src/processor.hpp
  3. 22 23
      src/threadpool.cpp
  4. 5 12
      src/threadpool.hpp

+ 6 - 27
src/processor.cpp

@@ -25,39 +25,18 @@ Processor::Processor(size_t limit) : mTasks(limit) {}
 Processor::~Processor() { join(); }
 
 void Processor::join() {
-	// We need to detect situations where the thread pool does not execute a pending task at exit
-	std::optional<unsigned int> counter;
-	while (true) {
-		std::shared_future<void> pending;
-		{
-			std::unique_lock lock(mMutex);
-			if (!mPending                               // no pending task
-			    || (counter && *counter == mCounter)) { // or no scheduled task after the last one
-
-				// Processing is stopped, clear everything and return
-				mPending.reset();
-				while (!mTasks.empty())
-					mTasks.pop();
-
-				return;
-			}
-
-			pending = *mPending;
-			counter = mCounter;
-		}
-
-		// Wait for the pending task
-		pending.wait();
-	}
+	std::unique_lock lock(mMutex);
+	mCondition.wait(lock, [this]() { return !mPending && mTasks.empty(); });
 }
 
 void Processor::schedule() {
 	std::unique_lock lock(mMutex);
 	if (auto next = mTasks.tryPop()) {
-		mPending = ThreadPool::Instance().enqueue(std::move(*next)).share();
-		++mCounter;
+		ThreadPool::Instance().enqueue(std::move(*next));
 	} else {
-		mPending.reset(); // No more tasks
+		// No more tasks
+		mPending = false;
+		mCondition.notify_all();
 	}
 }
 

+ 5 - 5
src/processor.hpp

@@ -54,10 +54,10 @@ protected:
 	const init_token mInitToken = Init::Token();
 
 	Queue<std::function<void()>> mTasks;
-	std::optional<std::shared_future<void>> mPending; // future of the pending task
-	unsigned int mCounter = 0; // Number of scheduled tasks
+	bool mPending = false; // true iff a task is pending in the thread pool
 
 	mutable std::mutex mMutex;
+	std::condition_variable mCondition;
 };
 
 template <class F, class... Args> void Processor::enqueue(F &&f, Args &&...args) {
@@ -65,12 +65,12 @@ template <class F, class... Args> void Processor::enqueue(F &&f, Args &&...args)
 	auto bound = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
 	auto task = [this, bound = std::move(bound)]() mutable {
 		scope_guard guard(std::bind(&Processor::schedule, this)); // chain the next task
-		bound();
+		return bound();
 	};
 
 	if (!mPending) {
-		mPending = ThreadPool::Instance().enqueue(std::move(task)).share();
-		++mCounter;
+		ThreadPool::Instance().enqueue(std::move(task));
+		mPending = true;
 	} else {
 		mTasks.push(std::move(task));
 	}

+ 22 - 23
src/threadpool.cpp

@@ -21,10 +21,10 @@
 #include <cstdlib>
 
 namespace {
-	void joinThreadPoolInstance() {
-		rtc::ThreadPool::Instance().join();
-	}
-}
+
+void joinThreadPoolInstance() { rtc::ThreadPool::Instance().join(); }
+
+} // namespace
 
 namespace rtc {
 
@@ -33,9 +33,7 @@ ThreadPool &ThreadPool::Instance() {
 	return *instance;
 }
 
-ThreadPool::ThreadPool() {
-	std::atexit(joinThreadPoolInstance);
-}
+ThreadPool::ThreadPool() { std::atexit(joinThreadPoolInstance); }
 
 ThreadPool::~ThreadPool() {}
 
@@ -45,17 +43,21 @@ int ThreadPool::count() const {
 }
 
 void ThreadPool::spawn(int count) {
-	std::unique_lock lock(mWorkersMutex);
+	std::scoped_lock lock(mMutex, mWorkersMutex);
 	mJoining = false;
 	while (count-- > 0)
 		mWorkers.emplace_back(std::bind(&ThreadPool::run, this));
 }
 
 void ThreadPool::join() {
-	std::unique_lock lock(mWorkersMutex);
-	mJoining = true;
-	mCondition.notify_all();
+	{
+		std::unique_lock lock(mMutex);
+		mWaitingCondition.wait(lock, [&]() { return mWaitingWorkers == int(mWorkers.size()); });
+		mJoining = true;
+		mTasksCondition.notify_all();
+	}
 
+	std::unique_lock lock(mWorkersMutex);
 	for (auto &w : mWorkers)
 		w.join();
 
@@ -77,7 +79,7 @@ bool ThreadPool::runOne() {
 
 std::function<void()> ThreadPool::dequeue() {
 	std::unique_lock lock(mMutex);
-	while (true) {
+	while (!mJoining) {
 		if (!mTasks.empty()) {
 			if (mTasks.top().time <= clock::now()) {
 				auto func = std::move(mTasks.top().func);
@@ -85,21 +87,18 @@ std::function<void()> ThreadPool::dequeue() {
 				return func;
 			}
 
-			if (mJoining)
-				break;
+			++mWaitingWorkers;
+			mWaitingCondition.notify_all();
+			mTasksCondition.wait_until(lock, mTasks.top().time);
 
-			mCondition.wait_until(lock, mTasks.top().time);
 		} else {
-			if (mJoining)
-				break;
-
-			mCondition.wait(lock);
+			++mWaitingWorkers;
+			mWaitingCondition.notify_all();
+			mTasksCondition.wait(lock);
 		}
-	}
-
-	while (!mTasks.empty())
-		mTasks.pop();
 
+		--mWaitingWorkers;
+	}
 	return nullptr;
 }
 

+ 5 - 12
src/threadpool.hpp

@@ -72,7 +72,8 @@ protected:
 	std::function<void()> dequeue(); // returns null function if joining
 
 	std::vector<std::thread> mWorkers;
-	std::atomic<bool> mJoining = false;
+	int mWaitingWorkers = 0;
+	bool mJoining = false;
 
 	struct Task {
 		clock::time_point time;
@@ -82,8 +83,8 @@ protected:
 	};
 	std::priority_queue<Task, std::deque<Task>, std::greater<Task>> mTasks;
 
+	std::condition_variable mTasksCondition, mWaitingCondition;
 	mutable std::mutex mMutex, mWorkersMutex;
-	std::condition_variable mCondition;
 };
 
 template <class F, class... Args>
@@ -100,16 +101,8 @@ auto ThreadPool::schedule(clock::duration delay, F &&f, Args &&...args)
 template <class F, class... Args>
 auto ThreadPool::schedule(clock::time_point time, F &&f, Args &&...args)
     -> invoke_future_t<F, Args...> {
-	using R = std::invoke_result_t<std::decay_t<F>, std::decay_t<Args>...>;
 	std::unique_lock lock(mMutex);
-	if (mJoining) {
-		std::promise<R> promise;
-		std::future<R> result = promise.get_future();
-		promise.set_exception(std::make_exception_ptr(
-		    std::runtime_error("Scheduled a task while joining the thread pool")));
-		return result;
-	}
-
+	using R = std::invoke_result_t<std::decay_t<F>, std::decay_t<Args>...>;
 	auto bound = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
 	auto task = std::make_shared<std::packaged_task<R()>>([bound = std::move(bound)]() mutable {
 		try {
@@ -122,7 +115,7 @@ auto ThreadPool::schedule(clock::time_point time, F &&f, Args &&...args)
 	std::future<R> result = task->get_future();
 
 	mTasks.push({time, [task = std::move(task), token = Init::Token()]() { return (*task)(); }});
-	mCondition.notify_one();
+	mTasksCondition.notify_one();
 	return result;
 }