Browse source files

WorkerThreadPool: Avoid deadlocks when CommandQueueMT is involved

This commit lets CommandQueueMT play nicely with the WorkerThreadPool to avoid
non-progressable situations caused by an interdependence between both. While a
command queue is being flushed, it allows the WTP to release its lock while tasks
are being awaited, so they can make progress in case they, in turn, need to post
to the command queue.
Pedro J. Estébanez, 1 year ago
parent
commit
ae418f9469

+ 28 - 0
core/object/worker_thread_pool.cpp

@@ -33,6 +33,7 @@
 #include "core/object/script_language.h"
 #include "core/object/script_language.h"
 #include "core/os/os.h"
 #include "core/os/os.h"
 #include "core/os/thread_safe.h"
 #include "core/os/thread_safe.h"
+#include "core/templates/command_queue_mt.h"
 
 
 void WorkerThreadPool::Task::free_template_userdata() {
 void WorkerThreadPool::Task::free_template_userdata() {
 	ERR_FAIL_NULL(template_userdata);
 	ERR_FAIL_NULL(template_userdata);
@@ -43,6 +44,8 @@ void WorkerThreadPool::Task::free_template_userdata() {
 
 
 WorkerThreadPool *WorkerThreadPool::singleton = nullptr;
 WorkerThreadPool *WorkerThreadPool::singleton = nullptr;
 
 
+thread_local CommandQueueMT *WorkerThreadPool::flushing_cmd_queue = nullptr;
+
 void WorkerThreadPool::_process_task(Task *p_task) {
 void WorkerThreadPool::_process_task(Task *p_task) {
 	int pool_thread_index = thread_ids[Thread::get_caller_id()];
 	int pool_thread_index = thread_ids[Thread::get_caller_id()];
 	ThreadData &curr_thread = threads[pool_thread_index];
 	ThreadData &curr_thread = threads[pool_thread_index];
@@ -428,7 +431,15 @@ Error WorkerThreadPool::wait_for_task_completion(TaskID p_task_id) {
 
 
 					if (!task_to_process) {
 					if (!task_to_process) {
 						caller_pool_thread->awaited_task = task;
 						caller_pool_thread->awaited_task = task;
+
+						if (flushing_cmd_queue) {
+							flushing_cmd_queue->unlock();
+						}
 						caller_pool_thread->cond_var.wait(lock);
 						caller_pool_thread->cond_var.wait(lock);
+						if (flushing_cmd_queue) {
+							flushing_cmd_queue->lock();
+						}
+
 						DEV_ASSERT(exit_threads || caller_pool_thread->signaled || task->completed);
 						DEV_ASSERT(exit_threads || caller_pool_thread->signaled || task->completed);
 						caller_pool_thread->awaited_task = nullptr;
 						caller_pool_thread->awaited_task = nullptr;
 					}
 					}
@@ -540,7 +551,14 @@ void WorkerThreadPool::wait_for_group_task_completion(GroupID p_group) {
 
 
 	{
 	{
 		Group *group = *groupp;
 		Group *group = *groupp;
+
+		if (flushing_cmd_queue) {
+			flushing_cmd_queue->unlock();
+		}
 		group->done_semaphore.wait();
 		group->done_semaphore.wait();
+		if (flushing_cmd_queue) {
+			flushing_cmd_queue->lock();
+		}
 
 
 		uint32_t max_users = group->tasks_used + 1; // Add 1 because the thread waiting for it is also user. Read before to avoid another thread freeing task after increment.
 		uint32_t max_users = group->tasks_used + 1; // Add 1 because the thread waiting for it is also user. Read before to avoid another thread freeing task after increment.
 		uint32_t finished_users = group->finished.increment(); // fetch happens before inc, so increment later.
 		uint32_t finished_users = group->finished.increment(); // fetch happens before inc, so increment later.
@@ -563,6 +581,16 @@ int WorkerThreadPool::get_thread_index() {
 	return singleton->thread_ids.has(tid) ? singleton->thread_ids[tid] : -1;
 	return singleton->thread_ids.has(tid) ? singleton->thread_ids[tid] : -1;
 }
 }
 
 
+void WorkerThreadPool::thread_enter_command_queue_mt_flush(CommandQueueMT *p_queue) {
+	ERR_FAIL_COND(flushing_cmd_queue != nullptr);
+	flushing_cmd_queue = p_queue;
+}
+
+void WorkerThreadPool::thread_exit_command_queue_mt_flush() {
+	ERR_FAIL_NULL(flushing_cmd_queue);
+	flushing_cmd_queue = nullptr;
+}
+
 void WorkerThreadPool::init(int p_thread_count, float p_low_priority_task_ratio) {
 void WorkerThreadPool::init(int p_thread_count, float p_low_priority_task_ratio) {
 	ERR_FAIL_COND(threads.size() > 0);
 	ERR_FAIL_COND(threads.size() > 0);
 	if (p_thread_count < 0) {
 	if (p_thread_count < 0) {

+ 7 - 0
core/object/worker_thread_pool.h

@@ -41,6 +41,8 @@
 #include "core/templates/rid.h"
 #include "core/templates/rid.h"
 #include "core/templates/safe_refcount.h"
 #include "core/templates/safe_refcount.h"
 
 
+class CommandQueueMT;
+
 class WorkerThreadPool : public Object {
 class WorkerThreadPool : public Object {
 	GDCLASS(WorkerThreadPool, Object)
 	GDCLASS(WorkerThreadPool, Object)
 public:
 public:
@@ -135,6 +137,8 @@ private:
 
 
 	static WorkerThreadPool *singleton;
 	static WorkerThreadPool *singleton;
 
 
+	static thread_local CommandQueueMT *flushing_cmd_queue;
+
 	TaskID _add_task(const Callable &p_callable, void (*p_func)(void *), void *p_userdata, BaseTemplateUserdata *p_template_userdata, bool p_high_priority, const String &p_description);
 	TaskID _add_task(const Callable &p_callable, void (*p_func)(void *), void *p_userdata, BaseTemplateUserdata *p_template_userdata, bool p_high_priority, const String &p_description);
 	GroupID _add_group_task(const Callable &p_callable, void (*p_func)(void *, uint32_t), void *p_userdata, BaseTemplateUserdata *p_template_userdata, int p_elements, int p_tasks, bool p_high_priority, const String &p_description);
 	GroupID _add_group_task(const Callable &p_callable, void (*p_func)(void *, uint32_t), void *p_userdata, BaseTemplateUserdata *p_template_userdata, int p_elements, int p_tasks, bool p_high_priority, const String &p_description);
 
 
@@ -197,6 +201,9 @@ public:
 	static WorkerThreadPool *get_singleton() { return singleton; }
 	static WorkerThreadPool *get_singleton() { return singleton; }
 	static int get_thread_index();
 	static int get_thread_index();
 
 
+	static void thread_enter_command_queue_mt_flush(CommandQueueMT *p_queue);
+	static void thread_exit_command_queue_mt_flush();
+
 	void init(int p_thread_count = -1, float p_low_priority_task_ratio = 0.3);
 	void init(int p_thread_count = -1, float p_low_priority_task_ratio = 0.3);
 	void finish();
 	void finish();
 	WorkerThreadPool();
 	WorkerThreadPool();

+ 31 - 19
core/templates/command_queue_mt.h

@@ -31,6 +31,7 @@
 #ifndef COMMAND_QUEUE_MT_H
 #ifndef COMMAND_QUEUE_MT_H
 #define COMMAND_QUEUE_MT_H
 #define COMMAND_QUEUE_MT_H
 
 
+#include "core/object/worker_thread_pool.h"
 #include "core/os/memory.h"
 #include "core/os/memory.h"
 #include "core/os/mutex.h"
 #include "core/os/mutex.h"
 #include "core/os/semaphore.h"
 #include "core/os/semaphore.h"
@@ -306,15 +307,15 @@ class CommandQueueMT {
 
 
 	struct CommandBase {
 	struct CommandBase {
 		virtual void call() = 0;
 		virtual void call() = 0;
-		virtual void post() {}
-		virtual ~CommandBase() {}
+		virtual SyncSemaphore *get_sync_semaphore() { return nullptr; }
+		virtual ~CommandBase() = default; // Won't be called.
 	};
 	};
 
 
 	struct SyncCommand : public CommandBase {
 	struct SyncCommand : public CommandBase {
 		SyncSemaphore *sync_sem = nullptr;
 		SyncSemaphore *sync_sem = nullptr;
 
 
-		virtual void post() override {
-			sync_sem->sem.post();
+		virtual SyncSemaphore *get_sync_semaphore() override {
+			return sync_sem;
 		}
 		}
 	};
 	};
 
 
@@ -340,6 +341,7 @@ class CommandQueueMT {
 	SyncSemaphore sync_sems[SYNC_SEMAPHORES];
 	SyncSemaphore sync_sems[SYNC_SEMAPHORES];
 	Mutex mutex;
 	Mutex mutex;
 	Semaphore *sync = nullptr;
 	Semaphore *sync = nullptr;
+	uint64_t flush_read_ptr = 0;
 
 
 	template <class T>
 	template <class T>
 	T *allocate() {
 	T *allocate() {
@@ -362,31 +364,41 @@ class CommandQueueMT {
 	void _flush() {
 	void _flush() {
 		lock();
 		lock();
 
 
-		uint64_t read_ptr = 0;
-		uint64_t limit = command_mem.size();
-
-		while (read_ptr < limit) {
-			uint64_t size = *(uint64_t *)&command_mem[read_ptr];
-			read_ptr += 8;
-			CommandBase *cmd = reinterpret_cast<CommandBase *>(&command_mem[read_ptr]);
-
-			cmd->call(); //execute the function
-			cmd->post(); //release in case it needs sync/ret
-			cmd->~CommandBase(); //should be done, so erase the command
-
-			read_ptr += size;
+		WorkerThreadPool::thread_enter_command_queue_mt_flush(this);
+		while (flush_read_ptr < command_mem.size()) {
+			uint64_t size = *(uint64_t *)&command_mem[flush_read_ptr];
+			flush_read_ptr += 8;
+			CommandBase *cmd = reinterpret_cast<CommandBase *>(&command_mem[flush_read_ptr]);
+
+			SyncSemaphore *sync_sem = cmd->get_sync_semaphore();
+			cmd->call();
+			if (sync_sem) {
+				sync_sem->sem.post(); // Release in case it needs sync/ret.
+			}
+
+			if (unlikely(flush_read_ptr == 0)) {
+				// A reentrant call flushed.
+				DEV_ASSERT(command_mem.is_empty());
+				unlock();
+				return;
+			}
+
+			flush_read_ptr += size;
 		}
 		}
+		WorkerThreadPool::thread_exit_command_queue_mt_flush();
 
 
 		command_mem.clear();
 		command_mem.clear();
+		flush_read_ptr = 0;
 		unlock();
 		unlock();
 	}
 	}
 
 
-	void lock();
-	void unlock();
 	void wait_for_flush();
 	void wait_for_flush();
 	SyncSemaphore *_alloc_sync_sem();
 	SyncSemaphore *_alloc_sync_sem();
 
 
 public:
 public:
+	void lock();
+	void unlock();
+
 	/* NORMAL PUSH COMMANDS */
 	/* NORMAL PUSH COMMANDS */
 	DECL_PUSH(0)
 	DECL_PUSH(0)
 	SPACE_SEP_LIST(DECL_PUSH, 15)
 	SPACE_SEP_LIST(DECL_PUSH, 15)