ThreadHive.cpp 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. // Copyright (C) 2009-present, Panagiotis Christopoulos Charitos and contributors.
  2. // All rights reserved.
  3. // Code licensed under the BSD License.
  4. // http://www.anki3d.org/LICENSE
  5. #include <AnKi/Util/ThreadHive.h>
  6. #include <AnKi/Util/String.h>
  7. #include <cstring>
  8. #include <cstdio>
  9. namespace anki {
  10. Atomic<U32> ThreadHive::m_uuid = {0};
  11. #define ANKI_ENABLE_HIVE_DEBUG_PRINT 0
  12. #if ANKI_ENABLE_HIVE_DEBUG_PRINT
  13. # define ANKI_HIVE_DEBUG_PRINT(...) printf(__VA_ARGS__)
  14. #else
  15. # define ANKI_HIVE_DEBUG_PRINT(...) ((void)0)
  16. #endif
  17. class ThreadHive::Thread
  18. {
  19. public:
  20. U32 m_id; ///< An ID
  21. anki::Thread m_thread; ///< Runs the workingFunc
  22. ThreadHive* m_hive;
  23. /// Constructor
  24. Thread(U32 id, ThreadHive* hive, Bool pinToCore, CString threadName)
  25. : m_id(id)
  26. , m_thread(threadName.cstr())
  27. , m_hive(hive)
  28. {
  29. ANKI_ASSERT(hive);
  30. m_thread.start(this, threadCallback, ThreadCoreAffinityMask(false).set(m_id, pinToCore));
  31. }
  32. private:
  33. /// Thread callaback
  34. static Error threadCallback(anki::ThreadCallbackInfo& info)
  35. {
  36. Thread& self = *static_cast<Thread*>(info.m_userData);
  37. self.m_hive->threadRun(self.m_id);
  38. return Error::kNone;
  39. }
  40. };
  41. class ThreadHive::Task
  42. {
  43. public:
  44. Task* m_next; ///< Next in the list.
  45. ThreadHiveTaskCallback m_cb; ///< Callback that defines the task.
  46. void* m_arg; ///< Args for the callback.
  47. ThreadHiveSemaphore* m_waitSemaphore;
  48. ThreadHiveSemaphore* m_signalSemaphore;
  49. };
  50. ThreadHive::ThreadHive(U32 threadCount, Bool pinToCores)
  51. : m_pool(stackPoolAllocate, nullptr, 4_KB)
  52. , m_threadCount(threadCount)
  53. {
  54. m_threads = static_cast<Thread*>(DefaultMemoryPool::getSingleton().allocate(sizeof(Thread) * threadCount, alignof(Thread)));
  55. const U32 uuid = m_uuid.fetchAdd(1);
  56. for(U32 i = 0; i < threadCount; ++i)
  57. {
  58. Array<Char, 32> threadName;
  59. snprintf(&threadName[0], threadName.getSize(), "Hive#%u/#%u", uuid, i);
  60. ::new(&m_threads[i]) Thread(i, this, pinToCores, &threadName[0]);
  61. }
  62. }
  63. ThreadHive::~ThreadHive()
  64. {
  65. if(m_threads)
  66. {
  67. {
  68. LockGuard<Mutex> lock(m_mtx);
  69. m_quit = true;
  70. // Wake the threads
  71. m_cvar.notifyAll();
  72. }
  73. // Join and destroy
  74. U32 threadCount = m_threadCount;
  75. while(threadCount-- != 0)
  76. {
  77. [[maybe_unused]] const Error err = m_threads[threadCount].m_thread.join();
  78. m_threads[threadCount].~Thread();
  79. }
  80. DefaultMemoryPool::getSingleton().free(static_cast<void*>(m_threads));
  81. }
  82. }
  83. void ThreadHive::submitTasks(ThreadHiveTask* tasks, const U32 taskCount)
  84. {
  85. ANKI_ASSERT(tasks && taskCount > 0);
  86. // Allocate tasks
  87. Task* const htasks = newArray<Task>(m_pool, taskCount);
  88. // Initialize tasks
  89. Task* prevTask = nullptr;
  90. for(U32 i = 0; i < taskCount; ++i)
  91. {
  92. const ThreadHiveTask& inTask = tasks[i];
  93. Task& outTask = htasks[i];
  94. outTask.m_next = nullptr;
  95. outTask.m_cb = inTask.m_callback;
  96. outTask.m_arg = inTask.m_argument;
  97. outTask.m_waitSemaphore = inTask.m_waitSemaphore;
  98. outTask.m_signalSemaphore = inTask.m_signalSemaphore;
  99. // Connect tasks
  100. if(prevTask)
  101. {
  102. prevTask->m_next = &outTask;
  103. }
  104. prevTask = &outTask;
  105. }
  106. // Push work
  107. {
  108. LockGuard<Mutex> lock(m_mtx);
  109. if(m_head != nullptr)
  110. {
  111. ANKI_ASSERT(m_tail && m_head);
  112. m_tail->m_next = &htasks[0];
  113. m_tail = &htasks[taskCount - 1];
  114. }
  115. else
  116. {
  117. ANKI_ASSERT(m_tail == nullptr);
  118. m_head = &htasks[0];
  119. m_tail = &htasks[taskCount - 1];
  120. }
  121. m_pendingTasks += taskCount;
  122. ANKI_HIVE_DEBUG_PRINT("submit tasks\n");
  123. }
  124. // Notify all threads
  125. m_cvar.notifyAll();
  126. }
  127. void ThreadHive::threadRun(U32 threadId)
  128. {
  129. Task* task = nullptr;
  130. while(!waitForWork(threadId, task))
  131. {
  132. // Run the task
  133. ANKI_ASSERT(task && task->m_cb);
  134. ANKI_HIVE_DEBUG_PRINT("tid: %lu will exec %p (udata: %p)\n", threadId, static_cast<void*>(task), static_cast<void*>(task->m_arg));
  135. task->m_cb(task->m_arg, threadId, *this, task->m_signalSemaphore);
  136. #if ANKI_EXTRA_CHECKS
  137. task->m_cb = nullptr;
  138. #endif
  139. // Signal the semaphore as early as possible
  140. if(task->m_signalSemaphore)
  141. {
  142. [[maybe_unused]] const U32 out = task->m_signalSemaphore->m_atomic.fetchSub(1);
  143. ANKI_ASSERT(out > 0u);
  144. ANKI_HIVE_DEBUG_PRINT("\tsem is %u\n", out - 1u);
  145. }
  146. }
  147. ANKI_HIVE_DEBUG_PRINT("tid: %lu thread quits!\n", threadId);
  148. }
  149. Bool ThreadHive::waitForWork([[maybe_unused]] U32 threadId, Task*& task)
  150. {
  151. LockGuard<Mutex> lock(m_mtx);
  152. ANKI_HIVE_DEBUG_PRINT("tid: %lu locking\n", threadId);
  153. // Complete the previous task
  154. if(task)
  155. {
  156. --m_pendingTasks;
  157. if(task->m_signalSemaphore || m_pendingTasks == 0)
  158. {
  159. // A dependency maybe got resolved or we are out of tasks. Wake them all
  160. ANKI_HIVE_DEBUG_PRINT("tid: %lu wake all\n", threadId);
  161. m_cvar.notifyAll();
  162. }
  163. }
  164. while(!m_quit && (task = getNewTask()) == nullptr)
  165. {
  166. ANKI_HIVE_DEBUG_PRINT("tid: %lu waiting\n", threadId);
  167. // Wait if there is no work.
  168. m_cvar.wait(m_mtx);
  169. }
  170. return m_quit;
  171. }
  172. ThreadHive::Task* ThreadHive::getNewTask()
  173. {
  174. Task* prevTask = nullptr;
  175. Task* task = m_head;
  176. while(task)
  177. {
  178. // Check if there are dependencies
  179. const Bool allDepsCompleted = task->m_waitSemaphore == nullptr || task->m_waitSemaphore->m_atomic.load() == 0;
  180. if(allDepsCompleted)
  181. {
  182. // Found something, pop it
  183. if(prevTask)
  184. {
  185. prevTask->m_next = task->m_next;
  186. }
  187. if(m_head == task)
  188. {
  189. m_head = task->m_next;
  190. }
  191. if(m_tail == task)
  192. {
  193. m_tail = prevTask;
  194. }
  195. #if ANKI_EXTRA_CHECKS
  196. task->m_next = nullptr;
  197. #endif
  198. break;
  199. }
  200. prevTask = task;
  201. task = task->m_next;
  202. }
  203. return task;
  204. }
  205. void ThreadHive::waitAllTasks()
  206. {
  207. ANKI_HIVE_DEBUG_PRINT("mt: waiting all\n");
  208. LockGuard<Mutex> lock(m_mtx);
  209. while(m_pendingTasks > 0)
  210. {
  211. m_cvar.wait(m_mtx);
  212. }
  213. m_head = nullptr;
  214. m_tail = nullptr;
  215. m_pool.reset();
  216. ANKI_HIVE_DEBUG_PRINT("mt: done waiting all\n");
  217. }
  218. } // end namespace anki