Browse source code

WorkerThreadPool: Support daemon-like tasks (via yield semantics)

Pedro J. Estébanez 1 year ago
parent commit 1b104ffcd8
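
The gist: a long-running, daemon-like task can now park itself with yield(), releasing its pool thread to run other tasks until another thread wakes it with notify_yield_over(). A minimal sketch of the pattern (hypothetical names; it mirrors the test added at the bottom of this commit):

	// Hypothetical daemon-like task: loops until told to exit, yielding its
	// pool thread between iterations so the pool doesn't lose a thread to it.
	static SafeFlag exit_flag;

	static void my_daemon(void *p_arg) {
		while (!exit_flag.is_set()) {
			// ...do one unit of work...
			WorkerThreadPool::get_singleton()->yield(); // Park; the thread returns to the task queue.
		}
	}

	// From any other thread:
	WorkerThreadPool::TaskID id = WorkerThreadPool::get_singleton()->add_native_task(my_daemon, nullptr, true);
	// Whenever there is new work for the daemon:
	WorkerThreadPool::get_singleton()->notify_yield_over(id);
	// Shutdown sequence (as in the test): raise the flag, wake it one last time, then wait.
	exit_flag.set();
	WorkerThreadPool::get_singleton()->notify_yield_over(id);
	WorkerThreadPool::get_singleton()->wait_for_task_completion(id);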

+ 95 - 59
core/object/worker_thread_pool.cpp

@@ -35,6 +35,8 @@
 #include "core/os/thread_safe.h"
 #include "core/templates/command_queue_mt.h"
 
+WorkerThreadPool::Task *const WorkerThreadPool::ThreadData::YIELDING = (Task *)1;
+
 void WorkerThreadPool::Task::free_template_userdata() {
 	ERR_FAIL_NULL(template_userdata);
 	ERR_FAIL_NULL(native_func_userdata);
@@ -391,83 +393,117 @@ Error WorkerThreadPool::wait_for_task_completion(TaskID p_task_id) {
 	task_mutex.unlock();
 
 	if (caller_pool_thread) {
-		while (true) {
-			Task *task_to_process = nullptr;
-			{
-				MutexLock lock(task_mutex);
-				bool was_signaled = caller_pool_thread->signaled;
-				caller_pool_thread->signaled = false;
-
-				if (task->completed) {
-					// This thread was awaken also for some reason, but it's about to exit.
-					// Let's find out what may be pending and forward the requests.
-					if (!exit_threads && was_signaled) {
-						uint32_t to_process = task_queue.first() ? 1 : 0;
-						uint32_t to_promote = caller_pool_thread->current_task->low_priority && low_priority_task_queue.first() ? 1 : 0;
-						if (to_process || to_promote) {
-							// This thread must be left alone since it won't loop again.
-							caller_pool_thread->signaled = true;
-							_notify_threads(caller_pool_thread, to_process, to_promote);
-						}
-					}
+		_wait_collaboratively(caller_pool_thread, task);
+		task->waiting_pool--;
+		if (task->waiting_pool == 0 && task->waiting_user == 0) {
+			tasks.erase(p_task_id);
+			task_allocator.free(task);
+		}
+	} else {
+		task->done_semaphore.wait();
+		task_mutex.lock();
+		task->waiting_user--;
+		if (task->waiting_pool == 0 && task->waiting_user == 0) {
+			tasks.erase(p_task_id);
+			task_allocator.free(task);
+		}
+		task_mutex.unlock();
+	}
 
-					task->waiting_pool--;
-					if (task->waiting_pool == 0 && task->waiting_user == 0) {
-						tasks.erase(p_task_id);
-						task_allocator.free(task);
-					}
+	return OK;
+}
 
-					break;
-				}
+void WorkerThreadPool::_wait_collaboratively(ThreadData *p_caller_pool_thread, Task *p_task) {
+	// Keep processing tasks until the condition to stop waiting is met.
 
-				if (!exit_threads) {
-					// This is a thread from the pool. It shouldn't just idle.
-					// Let's try to process other tasks while we wait.
+#define IS_WAIT_OVER (unlikely(p_task == ThreadData::YIELDING) ? p_caller_pool_thread->yield_is_over : p_task->completed)
 
-					if (caller_pool_thread->current_task->low_priority && low_priority_task_queue.first()) {
-						if (_try_promote_low_priority_task()) {
-							_notify_threads(caller_pool_thread, 1, 0);
-						}
+	while (true) {
+		Task *task_to_process = nullptr;
+		{
+			MutexLock lock(task_mutex);
+			bool was_signaled = p_caller_pool_thread->signaled;
+			p_caller_pool_thread->signaled = false;
+
+			if (IS_WAIT_OVER) {
+				p_caller_pool_thread->yield_is_over = false;
+				if (!exit_threads && was_signaled) {
+				// This thread was awakened for some additional reason, but it's about to exit.
+					// Let's find out what may be pending and forward the requests.
+					uint32_t to_process = task_queue.first() ? 1 : 0;
+					uint32_t to_promote = p_caller_pool_thread->current_task->low_priority && low_priority_task_queue.first() ? 1 : 0;
+					if (to_process || to_promote) {
+						// This thread must be left alone since it won't loop again.
+						p_caller_pool_thread->signaled = true;
+						_notify_threads(p_caller_pool_thread, to_process, to_promote);
 					}
+				}
+
+				break;
+			}
 
-					if (singleton->task_queue.first()) {
-						task_to_process = task_queue.first()->self();
-						task_queue.remove(task_queue.first());
+			if (!exit_threads) {
+				if (p_caller_pool_thread->current_task->low_priority && low_priority_task_queue.first()) {
+					if (_try_promote_low_priority_task()) {
+						_notify_threads(p_caller_pool_thread, 1, 0);
 					}
+				}
 
-					if (!task_to_process) {
-						caller_pool_thread->awaited_task = task;
+				if (singleton->task_queue.first()) {
+					task_to_process = task_queue.first()->self();
+					task_queue.remove(task_queue.first());
+				}
 
-						if (flushing_cmd_queue) {
-							flushing_cmd_queue->unlock();
-						}
-						caller_pool_thread->cond_var.wait(lock);
-						if (flushing_cmd_queue) {
-							flushing_cmd_queue->lock();
-						}
+				if (!task_to_process) {
+					p_caller_pool_thread->awaited_task = p_task;
 
-						DEV_ASSERT(exit_threads || caller_pool_thread->signaled || task->completed);
-						caller_pool_thread->awaited_task = nullptr;
+					if (flushing_cmd_queue) {
+						flushing_cmd_queue->unlock();
+					}
+					p_caller_pool_thread->cond_var.wait(lock);
+					if (flushing_cmd_queue) {
+						flushing_cmd_queue->lock();
 					}
-				}
-			}
 
-			if (task_to_process) {
-				_process_task(task_to_process);
+					DEV_ASSERT(exit_threads || p_caller_pool_thread->signaled || IS_WAIT_OVER);
+					p_caller_pool_thread->awaited_task = nullptr;
+				}
 			}
 		}
-	} else {
-		task->done_semaphore.wait();
-		task_mutex.lock();
-		task->waiting_user--;
-		if (task->waiting_pool == 0 && task->waiting_user == 0) {
-			tasks.erase(p_task_id);
-			task_allocator.free(task);
+
+		if (task_to_process) {
+			_process_task(task_to_process);
 		}
+	}
+}
+
+void WorkerThreadPool::yield() {
+	int th_index = get_thread_index();
+	ERR_FAIL_COND_MSG(th_index == -1, "This function can only be called from a worker thread.");
+	_wait_collaboratively(&threads[th_index], ThreadData::YIELDING);
+}
+
+void WorkerThreadPool::notify_yield_over(TaskID p_task_id) {
+	task_mutex.lock();
+	Task **taskp = tasks.getptr(p_task_id);
+	if (!taskp) {
 		task_mutex.unlock();
+		ERR_FAIL_MSG("Invalid Task ID.");
 	}
+	Task *task = *taskp;
 
-	return OK;
+#ifdef DEBUG_ENABLED
+	if (task->pool_thread_index == get_thread_index()) {
+		WARN_PRINT("A worker thread is attempting to notify itself. That makes no sense.");
+	}
+#endif
+
+	ThreadData &td = threads[task->pool_thread_index];
+	td.yield_is_over = true;
+	td.signaled = true;
+	td.cond_var.notify_one();
+
+	task_mutex.unlock();
 }
 
 WorkerThreadPool::GroupID WorkerThreadPool::_add_group_task(const Callable &p_callable, void (*p_func)(void *, uint32_t), void *p_userdata, BaseTemplateUserdata *p_template_userdata, int p_elements, int p_tasks, bool p_high_priority, const String &p_description) {
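
Both wait paths now funnel through _wait_collaboratively(): given a real Task pointer, it loops until task->completed; given the sentinel ThreadData::YIELDING ((Task *)1, a non-null address that can never be a real allocation), it loops until yield_is_over is set by notify_yield_over(). The IS_WAIT_OVER macro selects between the two conditions. A self-contained reduction of that dispatch, with stand-in types since the real structs are private:

	struct Task {
		bool completed = false;
	};

	struct ThreadData {
		static Task *const YIELDING; // Sentinel: non-null, never dereferenced.
		bool yield_is_over = false;
	};

	Task *const ThreadData::YIELDING = (Task *)1;

	// The stop condition _wait_collaboratively() re-checks on every loop iteration.
	bool wait_is_over(const ThreadData &p_td, const Task *p_task) {
		return p_task == ThreadData::YIELDING ? p_td.yield_is_over : p_task->completed;
	}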

+ 16 - 3
core/object/worker_thread_pool.h

@@ -107,13 +107,21 @@ private:
 	BinaryMutex task_mutex;
 
 	struct ThreadData {
+		static Task *const YIELDING; // Too bad constexpr doesn't work here.
+
 		uint32_t index = 0;
 		Thread thread;
-		bool ready_for_scripting = false;
-		bool signaled = false;
+		bool ready_for_scripting : 1;
+		bool signaled : 1;
+		bool yield_is_over : 1;
 		Task *current_task = nullptr;
-		Task *awaited_task = nullptr; // Null if not awaiting the condition variable. Special value for idle-waiting.
+		Task *awaited_task = nullptr; // Null if not awaiting the condition variable, or special value (YIELDING).
 		ConditionVariable cond_var;
+
+		ThreadData() :
+				ready_for_scripting(false),
+				signaled(false),
+				yield_is_over(false) {}
 	};
 
 	TightLocalVector<ThreadData> threads;
@@ -177,6 +185,8 @@ private:
 		}
 	};
 
+	void _wait_collaboratively(ThreadData *p_caller_pool_thread, Task *p_task);
+
 protected:
 	static void _bind_methods();
 
@@ -196,6 +206,9 @@ public:
 	bool is_task_completed(TaskID p_task_id) const;
 	Error wait_for_task_completion(TaskID p_task_id);
 
+	void yield();
+	void notify_yield_over(TaskID p_task_id);
+
 	template <typename C, typename M, typename U>
 	GroupID add_template_group_task(C *p_instance, M p_method, U p_userdata, int p_elements, int p_tasks = -1, bool p_high_priority = false, const String &p_description = String()) {
 		typedef GroupUserData<C, M, U> GroupUD;
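
Two C++ details explain the shape of the header changes. First, (Task *)1 involves a reinterpret-style cast, which is not a constant expression, so the sentinel cannot be constexpr (hence the comment) and is instead a static const pointer defined out-of-line in the .cpp. Second, bit-field members cannot have default member initializers before C++20, which is why the flags lost their "= false" and gained an explicit default constructor. For illustration:

	// Pre-C++20: a bit-field cannot have an in-class initializer...
	struct Flags {
		bool ready : 1;
		bool signaled : 1;
		Flags() :
				ready(false),
				signaled(false) {}
	};

	// ...whereas since C++20 this would be legal:
	// struct Flags {
	// 	bool ready : 1 = false;
	// 	bool signaled : 1 = false;
	// };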

+ 34 - 10
tests/core/templates/test_command_queue.h

@@ -33,6 +33,7 @@
 
 #include "core/config/project_settings.h"
 #include "core/math/random_number_generator.h"
+#include "core/object/worker_thread_pool.h"
 #include "core/os/os.h"
 #include "core/os/thread.h"
 #include "core/templates/command_queue_mt.h"
@@ -100,7 +101,7 @@ public:
 	ThreadWork reader_threadwork;
 	ThreadWork writer_threadwork;
 
-	CommandQueueMT command_queue = CommandQueueMT(true);
+	CommandQueueMT command_queue;
 
 	enum TestMsgType {
 		TEST_MSG_FUNC1_TRANSFORM,
@@ -119,6 +120,7 @@ public:
 	bool exit_threads = false;
 
 	Thread reader_thread;
+	WorkerThreadPool::TaskID reader_task_id = WorkerThreadPool::INVALID_TASK_ID;
 	Thread writer_thread;
 
 	int func1_count = 0;
@@ -148,11 +150,16 @@ public:
 	void reader_thread_loop() {
 		reader_threadwork.thread_wait_for_work();
 		while (!exit_threads) {
-			if (message_count_to_read < 0) {
+			if (reader_task_id == WorkerThreadPool::INVALID_TASK_ID) {
 				command_queue.flush_all();
-			}
-			for (int i = 0; i < message_count_to_read; i++) {
-				command_queue.wait_and_flush();
+			} else {
+				if (message_count_to_read < 0) {
+					command_queue.flush_all();
+				}
+				for (int i = 0; i < message_count_to_read; i++) {
+					WorkerThreadPool::get_singleton()->yield();
+					command_queue.wait_and_flush();
+				}
 			}
 			message_count_to_read = 0;
 
@@ -216,8 +223,13 @@ public:
 		sts->writer_thread_loop();
 	}
 
-	void init_threads() {
-		reader_thread.start(&SharedThreadState::static_reader_thread_loop, this);
+	void init_threads(bool p_use_thread_pool_sync = false) {
+		if (p_use_thread_pool_sync) {
+			reader_task_id = WorkerThreadPool::get_singleton()->add_native_task(&SharedThreadState::static_reader_thread_loop, this, true);
+			command_queue.set_pump_task_id(reader_task_id);
+		} else {
+			reader_thread.start(&SharedThreadState::static_reader_thread_loop, this);
+		}
 		writer_thread.start(&SharedThreadState::static_writer_thread_loop, this);
 	}
 	void destroy_threads() {
@@ -225,16 +237,20 @@ public:
 		reader_threadwork.main_start_work();
 		writer_threadwork.main_start_work();
 
-		reader_thread.wait_to_finish();
+		if (reader_task_id != WorkerThreadPool::INVALID_TASK_ID) {
+			WorkerThreadPool::get_singleton()->wait_for_task_completion(reader_task_id);
+		} else {
+			reader_thread.wait_to_finish();
+		}
 		writer_thread.wait_to_finish();
 	}
 };
 
-TEST_CASE("[CommandQueue] Test Queue Basics") {
+static void test_command_queue_basic(bool p_use_thread_pool_sync) {
 	const char *COMMAND_QUEUE_SETTING = "memory/limits/command_queue/multithreading_queue_size_kb";
 	ProjectSettings::get_singleton()->set_setting(COMMAND_QUEUE_SETTING, 1);
 	SharedThreadState sts;
-	sts.init_threads();
+	sts.init_threads(p_use_thread_pool_sync);
 
 	sts.add_msg_to_write(SharedThreadState::TEST_MSG_FUNC1_TRANSFORM);
 	sts.writer_threadwork.main_start_work();
@@ -272,6 +288,14 @@ TEST_CASE("[CommandQueue] Test Queue Basics") {
 			ProjectSettings::get_singleton()->property_get_revert(COMMAND_QUEUE_SETTING));
 }
 
+TEST_CASE("[CommandQueue] Test Queue Basics") {
+	test_command_queue_basic(false);
+}
+
+TEST_CASE("[CommandQueue] Test Queue Basics with WorkerThreadPool sync.") {
+	test_command_queue_basic(true);
+}
+
 TEST_CASE("[CommandQueue] Test Queue Wrapping to same spot.") {
 	const char *COMMAND_QUEUE_SETTING = "memory/limits/command_queue/multithreading_queue_size_kb";
 	ProjectSettings::get_singleton()->set_setting(COMMAND_QUEUE_SETTING, 1);
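
In the new pool-backed variant, the reader runs as a low-priority pool task instead of a dedicated Thread, so it must not hold a pool thread hostage while idle: it yields before each wait_and_flush(), and set_pump_task_id() tells the command queue which task to wake (via notify_yield_over()) when a command is pushed. The wiring, simplified from the test (reader_loop and state are hypothetical stand-ins for the actual members):

	// Reader as a pool task.
	WorkerThreadPool::TaskID reader_task_id =
			WorkerThreadPool::get_singleton()->add_native_task(&reader_loop, &state, true);
	command_queue.set_pump_task_id(reader_task_id); // Pushes will wake the yielded reader.

	// Inside the reader, per expected message:
	WorkerThreadPool::get_singleton()->yield(); // Park until a command is pushed.
	command_queue.wait_and_flush(); // Then drain it.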

+ 67 - 0
tests/core/threads/test_worker_thread_pool.h

@@ -38,6 +38,7 @@
 namespace TestWorkerThreadPool {
 
 static LocalVector<SafeNumeric<int>> counter;
+static SafeFlag exit;
 
 static void static_test(void *p_arg) {
 	counter[(uint64_t)p_arg].increment();
@@ -106,6 +107,72 @@ TEST_CASE("[WorkerThreadPool] Process elements using group tasks") {
 	}
 }
 
+static void static_test_daemon(void *p_arg) {
+	while (!exit.is_set()) {
+		counter[0].add(1);
+		WorkerThreadPool::get_singleton()->yield();
+	}
+}
+
+static void static_busy_task(void *p_arg) {
+	while (!exit.is_set()) {
+		OS::get_singleton()->delay_usec(1);
+	}
+}
+
+static void static_legit_task(void *p_arg) {
+	*((bool *)p_arg) = counter[0].get() > 0;
+	counter[1].add(1);
+}
+
+TEST_CASE("[WorkerThreadPool] Run a yielding daemon as the only hope for other tasks to run") {
+	exit.clear();
+	counter.clear();
+	counter.resize(2);
+
+	WorkerThreadPool::TaskID daemon_task_id = WorkerThreadPool::get_singleton()->add_native_task(static_test_daemon, nullptr, true);
+
+	int num_threads = WorkerThreadPool::get_singleton()->get_thread_count();
+
+	// Keep all the other threads busy.
+	LocalVector<WorkerThreadPool::TaskID> task_ids;
+	for (int i = 0; i < num_threads - 1; i++) {
+		task_ids.push_back(WorkerThreadPool::get_singleton()->add_native_task(static_busy_task, nullptr, true));
+	}
+
+	LocalVector<WorkerThreadPool::TaskID> legit_task_ids;
+	LocalVector<bool> legit_task_needed_yield;
+	int legit_tasks_count = num_threads * 4;
+	legit_task_needed_yield.resize(legit_tasks_count);
+	for (int i = 0; i < legit_tasks_count; i++) {
+		legit_task_needed_yield[i] = false;
+		task_ids.push_back(WorkerThreadPool::get_singleton()->add_native_task(static_legit_task, &legit_task_needed_yield[i], i >= legit_tasks_count / 2));
+	}
+
+	while (counter[1].get() != legit_tasks_count) {
+		OS::get_singleton()->delay_usec(1);
+	}
+
+	exit.set();
+	for (uint32_t i = 0; i < task_ids.size(); i++) {
+		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_ids[i]);
+	}
+	WorkerThreadPool::get_singleton()->notify_yield_over(daemon_task_id);
+	WorkerThreadPool::get_singleton()->wait_for_task_completion(daemon_task_id);
+
+	CHECK_MESSAGE(counter[0].get() > 0, "Daemon task should have looped at least once.");
+	CHECK_MESSAGE(counter[1].get() == legit_tasks_count, "All legit tasks should have been able to run.");
+
+	bool all_needed_yield = true;
+	for (int i = 0; i < legit_tasks_count; i++) {
+		if (!legit_task_needed_yield[i]) {
+			all_needed_yield = false;
+			break;
+		}
+	}
+	CHECK_MESSAGE(all_needed_yield, "All legit tasks should have needed the daemon yielding to run.");
+}
+
 } // namespace TestWorkerThreadPool
 
 #endif // TEST_WORKER_THREAD_POOL_H