2
0

ThreadHive.cpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. // Copyright (C) 2009-2022, Panagiotis Christopoulos Charitos and contributors.
  2. // All rights reserved.
  3. // Code licensed under the BSD License.
  4. // http://www.anki3d.org/LICENSE
  5. #include <AnKi/Util/ThreadHive.h>
  6. #include <cstring>
  7. #include <cstdio>
  8. namespace anki {
  // Compile-time switch for the hive's trace output. Leave it at 0 for normal builds.
  #define ANKI_ENABLE_HIVE_DEBUG_PRINT 0

  // ANKI_HIVE_DEBUG_PRINT(...) takes printf-style arguments. With the switch off it expands to a no-op expression
  // and its arguments are neither evaluated nor type-checked; with the switch on it forwards them to printf.
  #if !ANKI_ENABLE_HIVE_DEBUG_PRINT
  #  define ANKI_HIVE_DEBUG_PRINT(...) ((void)0)
  #else
  #  define ANKI_HIVE_DEBUG_PRINT(...) printf(__VA_ARGS__)
  #endif
  15. class ThreadHive::Thread
  16. {
  17. public:
  18. U32 m_id; ///< An ID
  19. anki::Thread m_thread; ///< Runs the workingFunc
  20. ThreadHive* m_hive;
  21. /// Constructor
  22. Thread(U32 id, ThreadHive* hive, Bool pinToCore)
  23. : m_id(id)
  24. , m_thread("anki_threadhive")
  25. , m_hive(hive)
  26. {
  27. ANKI_ASSERT(hive);
  28. m_thread.start(this, threadCallback, ThreadCoreAffinityMask(false).set(m_id, pinToCore));
  29. }
  30. private:
  31. /// Thread callaback
  32. static Error threadCallback(anki::ThreadCallbackInfo& info)
  33. {
  34. Thread& self = *static_cast<Thread*>(info.m_userData);
  35. self.m_hive->threadRun(self.m_id);
  36. return Error::NONE;
  37. }
  38. };
  39. class ThreadHive::Task
  40. {
  41. public:
  42. Task* m_next; ///< Next in the list.
  43. ThreadHiveTaskCallback m_cb; ///< Callback that defines the task.
  44. void* m_arg; ///< Args for the callback.
  45. ThreadHiveSemaphore* m_waitSemaphore;
  46. ThreadHiveSemaphore* m_signalSemaphore;
  47. };
  48. ThreadHive::ThreadHive(U32 threadCount, GenericMemoryPoolAllocator<U8> alloc, Bool pinToCores)
  49. : m_slowAlloc(alloc)
  50. , m_alloc(alloc.getMemoryPool().getAllocationCallback(), alloc.getMemoryPool().getAllocationCallbackUserData(),
  51. 1024 * 4)
  52. , m_threadCount(threadCount)
  53. {
  54. m_threads = reinterpret_cast<Thread*>(m_slowAlloc.allocate(sizeof(Thread) * threadCount));
  55. for(U32 i = 0; i < threadCount; ++i)
  56. {
  57. ::new(&m_threads[i]) Thread(i, this, pinToCores);
  58. }
  59. }
  60. ThreadHive::~ThreadHive()
  61. {
  62. if(m_threads)
  63. {
  64. {
  65. LockGuard<Mutex> lock(m_mtx);
  66. m_quit = true;
  67. // Wake the threads
  68. m_cvar.notifyAll();
  69. }
  70. // Join and destroy
  71. U32 threadCount = m_threadCount;
  72. while(threadCount-- != 0)
  73. {
  74. [[maybe_unused]] const Error err = m_threads[threadCount].m_thread.join();
  75. m_threads[threadCount].~Thread();
  76. }
  77. m_slowAlloc.deallocate(static_cast<void*>(m_threads), m_threadCount * sizeof(Thread));
  78. }
  79. }
  80. void ThreadHive::submitTasks(ThreadHiveTask* tasks, const U32 taskCount)
  81. {
  82. ANKI_ASSERT(tasks && taskCount > 0);
  83. // Allocate tasks
  84. Task* const htasks = m_alloc.newArray<Task>(taskCount);
  85. // Initialize tasks
  86. Task* prevTask = nullptr;
  87. for(U32 i = 0; i < taskCount; ++i)
  88. {
  89. const ThreadHiveTask& inTask = tasks[i];
  90. Task& outTask = htasks[i];
  91. outTask.m_next = nullptr;
  92. outTask.m_cb = inTask.m_callback;
  93. outTask.m_arg = inTask.m_argument;
  94. outTask.m_waitSemaphore = inTask.m_waitSemaphore;
  95. outTask.m_signalSemaphore = inTask.m_signalSemaphore;
  96. // Connect tasks
  97. if(prevTask)
  98. {
  99. prevTask->m_next = &outTask;
  100. }
  101. prevTask = &outTask;
  102. }
  103. // Push work
  104. {
  105. LockGuard<Mutex> lock(m_mtx);
  106. if(m_head != nullptr)
  107. {
  108. ANKI_ASSERT(m_tail && m_head);
  109. m_tail->m_next = &htasks[0];
  110. m_tail = &htasks[taskCount - 1];
  111. }
  112. else
  113. {
  114. ANKI_ASSERT(m_tail == nullptr);
  115. m_head = &htasks[0];
  116. m_tail = &htasks[taskCount - 1];
  117. }
  118. m_pendingTasks += taskCount;
  119. ANKI_HIVE_DEBUG_PRINT("submit tasks\n");
  120. }
  121. // Notify all threads
  122. m_cvar.notifyAll();
  123. }
  124. void ThreadHive::threadRun(U32 threadId)
  125. {
  126. Task* task = nullptr;
  127. while(!waitForWork(threadId, task))
  128. {
  129. // Run the task
  130. ANKI_ASSERT(task && task->m_cb);
  131. ANKI_HIVE_DEBUG_PRINT("tid: %lu will exec %p (udata: %p)\n", threadId, static_cast<void*>(task),
  132. static_cast<void*>(task->m_arg));
  133. task->m_cb(task->m_arg, threadId, *this, task->m_signalSemaphore);
  134. #if ANKI_EXTRA_CHECKS
  135. task->m_cb = nullptr;
  136. #endif
  137. // Signal the semaphore as early as possible
  138. if(task->m_signalSemaphore)
  139. {
  140. [[maybe_unused]] const U32 out = task->m_signalSemaphore->m_atomic.fetchSub(1);
  141. ANKI_ASSERT(out > 0u);
  142. ANKI_HIVE_DEBUG_PRINT("\tsem is %u\n", out - 1u);
  143. }
  144. }
  145. ANKI_HIVE_DEBUG_PRINT("tid: %lu thread quits!\n", threadId);
  146. }
  147. Bool ThreadHive::waitForWork([[maybe_unused]] U32 threadId, Task*& task)
  148. {
  149. LockGuard<Mutex> lock(m_mtx);
  150. ANKI_HIVE_DEBUG_PRINT("tid: %lu locking\n", threadId);
  151. // Complete the previous task
  152. if(task)
  153. {
  154. --m_pendingTasks;
  155. if(task->m_signalSemaphore || m_pendingTasks == 0)
  156. {
  157. // A dependency maybe got resolved or we are out of tasks. Wake them all
  158. ANKI_HIVE_DEBUG_PRINT("tid: %lu wake all\n", threadId);
  159. m_cvar.notifyAll();
  160. }
  161. }
  162. while(!m_quit && (task = getNewTask()) == nullptr)
  163. {
  164. ANKI_HIVE_DEBUG_PRINT("tid: %lu waiting\n", threadId);
  165. // Wait if there is no work.
  166. m_cvar.wait(m_mtx);
  167. }
  168. return m_quit;
  169. }
  170. ThreadHive::Task* ThreadHive::getNewTask()
  171. {
  172. Task* prevTask = nullptr;
  173. Task* task = m_head;
  174. while(task)
  175. {
  176. // Check if there are dependencies
  177. const Bool allDepsCompleted = task->m_waitSemaphore == nullptr || task->m_waitSemaphore->m_atomic.load() == 0;
  178. if(allDepsCompleted)
  179. {
  180. // Found something, pop it
  181. if(prevTask)
  182. {
  183. prevTask->m_next = task->m_next;
  184. }
  185. if(m_head == task)
  186. {
  187. m_head = task->m_next;
  188. }
  189. if(m_tail == task)
  190. {
  191. m_tail = prevTask;
  192. }
  193. #if ANKI_EXTRA_CHECKS
  194. task->m_next = nullptr;
  195. #endif
  196. break;
  197. }
  198. prevTask = task;
  199. task = task->m_next;
  200. }
  201. return task;
  202. }
  203. void ThreadHive::waitAllTasks()
  204. {
  205. ANKI_HIVE_DEBUG_PRINT("mt: waiting all\n");
  206. LockGuard<Mutex> lock(m_mtx);
  207. while(m_pendingTasks > 0)
  208. {
  209. m_cvar.wait(m_mtx);
  210. }
  211. m_head = nullptr;
  212. m_tail = nullptr;
  213. m_alloc.getMemoryPool().reset();
  214. ANKI_HIVE_DEBUG_PRINT("mt: done waiting all\n");
  215. }
  216. } // end namespace anki