The ThreadHive is tested and ready to shred

Panagiotis Christopoulos Charitos, 9 years ago · commit 67c3431657
3 changed files with 206 additions and 127 deletions
  1. include/anki/util/ThreadHive.h (+21 -19)
  2. src/util/ThreadHive.cpp (+118 -101)
  3. tests/util/ThreadHive.cpp (+67 -7)

+ 21 - 19
include/anki/util/ThreadHive.h

@@ -70,20 +70,20 @@ public:
 	void waitAllTasks();

 private:
-	static const U MAX_DEPS = 4;
+	static const U MAX_DEPS = 2;

 	/// Lightweight task.
 	class Task
 	{
 	public:
-		ThreadHiveTaskCallback m_cb;
-		void* m_arg;
+		Task* m_next; ///< Next in the list.

-		union
-		{
-			Array<ThreadHiveDependencyHandle, MAX_DEPS> m_deps;
-			U64 m_depsU64;
-		};
+		ThreadHiveTaskCallback m_cb; ///< Callback that defines the task.
+		void* m_arg; ///< Args for the callback.
+
+		U16 m_depCount;
+		Array<ThreadHiveDependencyHandle, MAX_DEPS> m_deps;
+		Bool8 m_othersDepend; ///< Other tasks depend on this one.

 		Bool done() const
 		{
@@ -91,30 +91,32 @@ private:
 		}
 	};

-	static_assert(sizeof(Task) == sizeof(void*) * 2 + 8, "Too big");
+	static_assert(sizeof(Task) == (sizeof(void*) * 3 + 8), "Wrong size");

 	GenericMemoryPoolAllocator<U8> m_alloc;
 	ThreadHiveThread* m_threads = nullptr;
 	U32 m_threadCount = 0;

-	DArray<Task> m_queue; ///< Task queue.
-	I32 m_head = 0; ///< Head of m_queue.
-	I32 m_tail = -1; ///< Tail of m_queue.
-	U64 m_workingThreadsMask = 0; ///< Mask with the threads that have work.
+	DArray<Task> m_storage; ///< Task storage.
+	Task* m_head = nullptr; ///< Head of the task list.
+	Task* m_tail = nullptr; ///< Tail of the task list.
 	Bool m_quit = false;
-	U64 m_waitingThreadsMask = 0;
+	U32 m_pendingTasks = 0;
+	U32 m_allocatedTasks = 0;

 	Mutex m_mtx; ///< Protect the queue
 	ConditionVariable m_cvar;

-	Bool m_mainThreadStopWaiting = false;
-	Mutex m_mainThreadMtx;
-	ConditionVariable m_mainThreadCvar;
-
 	void run(U threadId);

+	Bool waitForWork(
+		U threadId, Task*& task, ThreadHiveTaskCallback& cb, void*& arg);
+
 	/// Get new work from the queue.
-	ThreadHiveTaskCallback getNewWork(void*& arg);
+	Task* getNewTask(ThreadHiveTaskCallback& cb, void*& arg);
+
+	/// Complete a task.
+	void completeTask(U taskId);
 };
 /// @}

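For orientation, here is a minimal, self-contained sketch of the Task record this header lands on. The AnKi types are replaced with standard C++ stand-ins, and ThreadHiveDependencyHandle is assumed to be a 16-bit index into the task storage; under those assumptions the arithmetic of the new static_assert works out (3 pointers plus 8 bytes of dependency bookkeeping). The done() body is inferred from the diff's use of m_cb = nullptr as the completion mark.

#include <array>
#include <cstdint>

class ThreadHive; // Forward declaration; the real class is in the header above.

// Assumption: a dependency handle is a 16-bit index into the task storage.
using ThreadHiveDependencyHandle = std::uint16_t;
using ThreadHiveTaskCallback =
	void (*)(void* arg, std::uint32_t threadId, ThreadHive& hive);

class Task
{
public:
	Task* m_next;                // Intrusive singly-linked list.
	ThreadHiveTaskCallback m_cb; // nullptr doubles as the "done" flag.
	void* m_arg;

	std::uint16_t m_depCount;                         // 2 bytes
	std::array<ThreadHiveDependencyHandle, 2> m_deps; // 4 bytes
	bool m_othersDepend;                              // 1 byte (+1 padding)

	bool done() const
	{
		return m_cb == nullptr;
	}
};

// 3 pointers + 8 bytes of bookkeeping on a 64-bit target, matching the
// static_assert in the diff above.
static_assert(sizeof(Task) == sizeof(void*) * 3 + 8, "Wrong size");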
+ 118 - 101
src/util/ThreadHive.cpp

@@ -14,7 +14,7 @@ namespace anki
 // Misc                                                                        =
 //==============================================================================

-#define ANKI_ENABLE_HIVE_DEBUG_PRINT 1
+#define ANKI_ENABLE_HIVE_DEBUG_PRINT 0

 #if ANKI_ENABLE_HIVE_DEBUG_PRINT
 #define ANKI_HIVE_DEBUG_PRINT(...) printf(__VA_ARGS__)
@@ -67,13 +67,13 @@ ThreadHive::ThreadHive(U threadCount, GenericMemoryPoolAllocator<U8> alloc)
 		::new(&m_threads[i]) ThreadHiveThread(i, this);
 	}

-	m_queue.create(m_alloc, 1024);
+	m_storage.create(m_alloc, 1024);
 }

 //==============================================================================
 ThreadHive::~ThreadHive()
 {
-	m_queue.destroy(m_alloc);
+	m_storage.destroy(m_alloc);

 	if(m_threads)
 	{
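One thing the new scheme relies on implicitly: m_storage is created once with 1024 slots and m_allocatedTasks only resets in waitAllTasks(), so at most 1024 tasks can be submitted between two waits. A hedged one-line guard that would make the cap explicit at the top of submitTasks' locked section (assuming DArray exposes getSize() like the WArray used elsewhere in this commit):

// Assumption: DArray::getSize() exists; the 1024 cap itself is implied by
// the create(m_alloc, 1024) call above.
ANKI_ASSERT(m_allocatedTasks + taskCount <= m_storage.getSize());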
@@ -104,35 +104,59 @@ void ThreadHive::submitTasks(ThreadHiveTask* tasks, U taskCount)
 {
 	ANKI_ASSERT(tasks && taskCount > 0);

-	// Create the tasks to temp memory to decrease thread contention
-	Array<Task, 64> tempTasks;
-	for(U i = 0; i < taskCount; ++i)
+	U allocatedTasks;
+
+	// Push work
 	{
-		tempTasks[i].m_cb = tasks[i].m_callback;
-		tempTasks[i].m_arg = tasks[i].m_argument;
-		tempTasks[i].m_depsU64 = 0;
+		LockGuard<Mutex> lock(m_mtx);

-		ANKI_ASSERT(tasks[i].m_inDependencies.getSize() <= MAX_DEPS
-			&& "For now only limited deps");
-		for(U j = 0; j < tasks[i].m_inDependencies.getSize(); ++j)
+		for(U i = 0; i < taskCount; ++i)
 		{
-			tempTasks[i].m_deps[j] = tasks[i].m_inDependencies[j];
-		}
-	}
+			const auto& inTask = tasks[i];
+			Task& outTask = m_storage[m_allocatedTasks];
+
+			outTask.m_cb = inTask.m_callback;
+			outTask.m_arg = inTask.m_argument;
+			outTask.m_depCount = 0;
+			outTask.m_next = nullptr;
+			outTask.m_othersDepend = false;
+
+			// Set the dependencies
+			ANKI_ASSERT(inTask.m_inDependencies.getSize() <= MAX_DEPS
+				&& "For now only limited deps");
+			for(U j = 0; j < inTask.m_inDependencies.getSize(); ++j)
+			{
+				ThreadHiveDependencyHandle dep = inTask.m_inDependencies[j];
+				ANKI_ASSERT(dep < m_allocatedTasks);

-	// Push work
-	I firstTaskIdx;
+				if(!m_storage[dep].done())
+				{
+					outTask.m_deps[outTask.m_depCount++] = dep;
+					m_storage[dep].m_othersDepend = true;
+				}
+			}

-	{
-		LockGuard<Mutex> lock(m_mtx);
+			// Push to the list
+			if(m_head == nullptr)
+			{
+				ANKI_ASSERT(m_tail == nullptr);
+				m_head = &m_storage[m_allocatedTasks];
+				m_tail = m_head;
+			}
+			else
+			{
+				ANKI_ASSERT(m_tail && m_head);
+				m_tail->m_next = &outTask;
+				m_tail = &outTask;
+			}

-		// "Allocate" storage for tasks
-		firstTaskIdx = m_tail + 1;
-		m_tail += taskCount;
+			++m_allocatedTasks;
+		}

-		// Store tasks
-		memcpy(&m_queue[firstTaskIdx], &tempTasks[0], sizeof(Task) * taskCount);
+		allocatedTasks = m_allocatedTasks;
+		m_pendingTasks += taskCount;

+		ANKI_HIVE_DEBUG_PRINT("submit tasks\n");
 		// Notify all threads
 		m_cvar.notifyAll();
 	}
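The handle semantics are the subtle part of this hunk: m_outDependency is now simply the index the task received in m_storage, and submitTasks drops any dependency that has already completed. A hedged usage sketch of how handles chain submissions, given a ThreadHive hive and a test-style context (producerCb, consumerCb and ctx are hypothetical; ThreadHiveTask, WArray and submitTasks come from this diff):

// Hypothetical callbacks/context, for illustration only.
ThreadHiveTask produce;
produce.m_callback = producerCb;
produce.m_argument = &ctx;
hive.submitTasks(&produce, 1); // Fills in produce.m_outDependency.

ThreadHiveTask consume;
consume.m_callback = consumerCb;
consume.m_argument = &ctx;
// Run only after "produce" completes.
consume.m_inDependencies =
	WArray<ThreadHiveDependencyHandle>(&produce.m_outDependency, 1);
hive.submitTasks(&consume, 1);

hive.waitAllTasks(); // Also resets the storage, so don't reuse old handles.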
@@ -140,141 +164,134 @@ void ThreadHive::submitTasks(ThreadHiveTask* tasks, U taskCount)
 	// Set the out dependencies
 	for(U i = 0; i < taskCount; ++i)
 	{
-		tasks[i].m_outDependency = firstTaskIdx + i;
+		tasks[i].m_outDependency = allocatedTasks - taskCount + i;
 	}
 }

 //==============================================================================
 void ThreadHive::run(U threadId)
 {
-	U64 threadMask = 1 << threadId;
+	Task* task = nullptr;
+	ThreadHiveTaskCallback cb = nullptr;
+	void* arg = nullptr;

-	while(1)
+	while(!waitForWork(threadId, task, cb, arg))
 	{
-		// Wait for something
-		ThreadHiveTaskCallback cb;
-		void* arg;
-		Bool quit;
-
-		{
-			LockGuard<Mutex> lock(m_mtx);
-
-			ANKI_HIVE_DEBUG_PRINT("tid: %lu locking\n", threadId);
-
-			while(!m_quit && (cb = getNewWork(arg)) == nullptr)
-			{
-				ANKI_HIVE_DEBUG_PRINT("tid: %lu waiting, cb %p\n",
-					threadId,
-					reinterpret_cast<void*>(cb));
-
-				m_waitingThreadsMask |= threadMask;
+		// Run the task
+		cb(arg, threadId, *this);
+		ANKI_HIVE_DEBUG_PRINT("tid: %lu executed\n", threadId);
+	}

-				if(__builtin_popcount(m_waitingThreadsMask) == m_threadCount)
-				{
-					ANKI_HIVE_DEBUG_PRINT("tid: %lu all threads done. 0x%lu\n",
-						threadId,
-						m_waitingThreadsMask);
-					LockGuard<Mutex> lock2(m_mainThreadMtx);
+	ANKI_HIVE_DEBUG_PRINT("tid: %lu thread quits!\n", threadId);
+}

-					m_mainThreadStopWaiting = true;
+//==============================================================================
+Bool ThreadHive::waitForWork(
+	U threadId, Task*& task, ThreadHiveTaskCallback& cb, void*& arg)
+{
+	cb = nullptr;
+	arg = nullptr;

-					// Everyone is waiting. Wake the main thread
-					m_mainThreadCvar.notifyOne();
-				}
+	LockGuard<Mutex> lock(m_mtx);

-				// Wait if there is no work.
-				m_cvar.wait(m_mtx);
-			}
+	ANKI_HIVE_DEBUG_PRINT("tid: %lu locking\n", threadId);

-			m_waitingThreadsMask &= ~threadMask;
-			quit = m_quit;
-		}
+	// Complete the previous task
+	if(task)
+	{
+		task->m_cb = nullptr;
+		--m_pendingTasks;

-		if(quit)
+		if(task->m_othersDepend || m_pendingTasks == 0)
 		{
-			break;
+			// A dependency got resolved or we are out of tasks. Wake them all
+			ANKI_HIVE_DEBUG_PRINT("tid: %lu wake all\n", threadId);
+			m_cvar.notifyAll();
 		}
+	}

-		// Run the task
-		cb(arg, threadId, *this);
-		ANKI_HIVE_DEBUG_PRINT("dit: %lu executed\n", threadId);
+	while(!m_quit && (task = getNewTask(cb, arg)) == nullptr)
+	{
+		ANKI_HIVE_DEBUG_PRINT("tid: %lu waiting\n", threadId);
+
+		// Wait if there is no work.
+		m_cvar.wait(m_mtx);
 	}

-	ANKI_HIVE_DEBUG_PRINT("dit: %lu thread quits!\n", threadId);
+	return m_quit;
 }

 //==============================================================================
-ThreadHiveTaskCallback ThreadHive::getNewWork(void*& arg)
+ThreadHive::Task* ThreadHive::getNewTask(ThreadHiveTaskCallback& cb, void*& arg)
 {
-	ThreadHiveTaskCallback cb = nullptr;
+	cb = nullptr;

-	for(I i = m_head; cb == nullptr && i <= m_tail; ++i)
+	Task* prevTask = nullptr;
+	Task* task = m_head;
+	while(task)
 	{
-		Task& task = m_queue[i];
-		if(!task.done())
+		if(!task->done())
 		{
 			// We may have a candidate

 			// Check if there are dependencies
 			Bool allDepsCompleted = true;
-			if(task.m_depsU64 != 0)
+			for(U j = 0; j < task->m_depCount; ++j)
 			{
-				for(U j = 0; j < MAX_DEPS; ++j)
-				{
-					I32 dep = task.m_deps[j];
+				U dep = task->m_deps[j];

-					if(dep < m_head || dep > m_tail || !m_queue[dep].done())
-					{
-						allDepsCompleted = false;
-						break;
-					}
+				if(!m_storage[dep].done())
+				{
+					allDepsCompleted = false;
+					break;
 				}
 			}

 			if(allDepsCompleted)
 			{
-				// Found something
-				cb = task.m_cb;
-				arg = task.m_arg;
-
-				// "Complete" the task
-				task.m_cb = nullptr;
+				// Found something, pop it
+				cb = task->m_cb;
+				arg = task->m_arg;

-				if(ANKI_UNLIKELY(m_head == m_tail))
+				if(prevTask)
 				{
-					// Reset it
-					m_head = 0;
-					m_tail = -1;
+					prevTask->m_next = task->m_next;
 				}
-				else if(i == m_head)
+
+				if(m_head == task)
 				{
-					// Pop front
-					++m_head;
+					m_head = task->m_next;
 				}
-				else if(i == m_tail)
+
+				if(m_tail == task)
 				{
-					// Pop back
-					--m_tail;
+					m_tail = prevTask;
 				}
+				break;
 			}
 		}
+
+		prevTask = task;
+		task = task->m_next;
 	}

-	return cb;
+	return task;
 }

 //==============================================================================
 void ThreadHive::waitAllTasks()
 {
 	ANKI_HIVE_DEBUG_PRINT("mt: waiting all\n");
-	LockGuard<Mutex> lock(m_mainThreadMtx);

-	while(!m_mainThreadStopWaiting)
+	LockGuard<Mutex> lock(m_mtx);
+	while(m_pendingTasks > 0)
 	{
-		m_mainThreadCvar.wait(m_mainThreadMtx);
+		m_cvar.wait(m_mtx);
 	}

-	m_mainThreadStopWaiting = false;
+	m_head = nullptr;
+	m_tail = nullptr;
+	m_allocatedTasks = 0;

 	ANKI_HIVE_DEBUG_PRINT("mt: done waiting all\n");
 }
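The structural change in this file is that completing the previous task and fetching the next one now happen under a single lock inside waitForWork(). A minimal sketch of that shape, with std::mutex/std::condition_variable standing in for AnKi's Mutex/ConditionVariable and a stripped-down Task (an illustration of the pattern, not the engine code):

#include <condition_variable>
#include <mutex>

struct Task
{
	void (*cb)(void*) = nullptr; // nullptr means "done".
	void* arg = nullptr;
	bool othersDepend = false;
};

struct MiniHive
{
	std::mutex mtx;
	std::condition_variable cvar;
	bool quit = false;
	unsigned pendingTasks = 0;

	// Stand-in for the ready-task list walk in getNewTask() above.
	Task* getNewTask() { return nullptr; }

	// Mirrors waitForWork(): complete the previous task, wake dependents,
	// then block until new work arrives or quit is requested.
	bool waitForWork(Task*& task)
	{
		std::unique_lock<std::mutex> lock(mtx);
		if(task)
		{
			task->cb = nullptr; // Mark done so dependents become runnable.
			--pendingTasks;
			if(task->othersDepend || pendingTasks == 0)
			{
				cvar.notify_all(); // Wake waiting workers and waitAllTasks().
			}
		}
		while(!quit && (task = getNewTask()) == nullptr)
		{
			cvar.wait(lock);
		}
		return quit;
	}

	void run()
	{
		Task* task = nullptr;
		while(!waitForWork(task))
		{
			task->cb(task->arg); // Execute outside the lock.
		}
	}
};

The single-mutex design also explains why waitAllTasks() can drop its private mutex and condition variable: the last completing worker sees m_pendingTasks hit zero and notifies everyone, including the waiting main thread.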

+ 67 - 7
tests/util/ThreadHive.cpp

@@ -24,8 +24,8 @@ public:

 	union
 	{
-		Atomic<U32> m_countAtomic;
-		U32 m_count;
+		Atomic<I32> m_countAtomic;
+		I32 m_count;
 	};
 };

@@ -51,14 +51,15 @@ static void taskToWaitOn(void* arg, U32, ThreadHive& hive)
 	ThreadHiveTestContext* ctx = static_cast<ThreadHiveTestContext*>(arg);
 	std::this_thread::sleep_for(std::chrono::seconds(1));
 	ctx->m_count = 10;
-	std::this_thread::sleep_for(std::chrono::seconds(1));
+	std::this_thread::sleep_for(std::chrono::milliseconds(100));
 }

 //==============================================================================
-static void taskToWait(void* arg, U32, ThreadHive& hive)
+static void taskToWait(void* arg, U32 threadId, ThreadHive& hive)
 {
 	ThreadHiveTestContext* ctx = static_cast<ThreadHiveTestContext*>(arg);
-	ANKI_TEST_EXPECT_EQ(ctx->m_count, 10);
+	U prev = ctx->m_countAtomic.fetchAdd(1);
+	ANKI_TEST_EXPECT_GEQ(prev, 10);
 }

 //==============================================================================
@@ -69,10 +70,11 @@ ANKI_TEST(Util, ThreadHive)
 	ThreadHive hive(threadCount, alloc);

 	// Simple test
+	if(1)
 	{
 		ThreadHiveTestContext ctx;
 		ctx.m_countAtomic.set(0);
-		const U INITIAL_TASK_COUNT = 10;
+		const U INITIAL_TASK_COUNT = 100;

 		for(U i = 0; i < INITIAL_TASK_COUNT; ++i)
 		{
@@ -85,7 +87,7 @@ ANKI_TEST(Util, ThreadHive)
 	}

 	// Dependency tests
-	if(0)
+	if(1)
 	{
 		ThreadHiveTestContext ctx;
 		ctx.m_count = 0;
@@ -109,7 +111,65 @@ ANKI_TEST(Util, ThreadHive)

 		hive.submitTasks(&dtasks[0], DEP_TASKS);

+		// Again
+		ThreadHiveTask dtasks2[DEP_TASKS];
+		for(U i = 0; i < DEP_TASKS; ++i)
+		{
+			dtasks2[i].m_callback = taskToWait;
+			dtasks2[i].m_argument = &ctx;
+			dtasks2[i].m_inDependencies = WArray<ThreadHiveDependencyHandle>(
+				&dtasks[i].m_outDependency, 1);
+		}
+
+		hive.submitTasks(&dtasks2[0], DEP_TASKS);
+
 		hive.waitAllTasks();
+
+		ANKI_TEST_EXPECT_EQ(ctx.m_countAtomic.get(), DEP_TASKS * 2 + 10);
+	}
+
+	// Fuzzy test
+	if(1)
+	{
+		ThreadHiveTestContext ctx;
+		ctx.m_count = 0;
+
+		I number = 0;
+		ThreadHiveDependencyHandle dep = 0;
+
+		const U SUBMISSION_COUNT = 100;
+		const U TASK_COUNT = 100;
+		for(U i = 0; i < SUBMISSION_COUNT; ++i)
+		{
+			for(U j = 0; j < TASK_COUNT; ++j)
+			{
+				Bool cb = rand() % 2;
+
+				number = (cb) ? number + 2 : number - 2;
+
+				ThreadHiveTask task;
+				task.m_callback = (cb) ? incNumber : decNumber;
+				task.m_argument = &ctx;
+
+				if((rand() % 3) == 0 && j > 0)
+				{
+					task.m_inDependencies =
+						WArray<ThreadHiveDependencyHandle>(&dep, 1);
+				}
+
+				hive.submitTasks(&task, 1);
+
+				if((rand() % 7) == 0)
+				{
+					dep = task.m_outDependency;
+				}
+			}
+
+			dep = 0;
+			hive.waitAllTasks();
+		}
+
+		ANKI_TEST_EXPECT_EQ(ctx.m_countAtomic.get(), number);
 	}
 }
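The fuzzy test drives incNumber and decNumber callbacks whose definitions fall outside this hunk. Given the test's bookkeeping (number moves by plus or minus 2 per task and is compared against m_countAtomic at the end), a plausible sketch of what they look like:

// Hedged sketch; the real definitions are outside this hunk. Each callback
// moves the shared atomic by +/-2 to match "number + 2 : number - 2" above.
static void incNumber(void* arg, U32, ThreadHive&)
{
	static_cast<ThreadHiveTestContext*>(arg)->m_countAtomic.fetchAdd(2);
}

static void decNumber(void* arg, U32, ThreadHive&)
{
	static_cast<ThreadHiveTestContext*>(arg)->m_countAtomic.fetchAdd(-2);
}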