ThreadHive.cpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. // Copyright (C) 2009-2021, Panagiotis Christopoulos Charitos and contributors.
  2. // All rights reserved.
  3. // Code licensed under the BSD License.
  4. // http://www.anki3d.org/LICENSE
  5. #include <AnKi/Util/ThreadHive.h>
  6. #include <cstring>
  7. #include <cstdio>
  8. namespace anki
  9. {
  10. #define ANKI_ENABLE_HIVE_DEBUG_PRINT 0
  11. #if ANKI_ENABLE_HIVE_DEBUG_PRINT
  12. # define ANKI_HIVE_DEBUG_PRINT(...) printf(__VA_ARGS__)
  13. #else
  14. # define ANKI_HIVE_DEBUG_PRINT(...) ((void)0)
  15. #endif
  16. class ThreadHive::Thread
  17. {
  18. public:
  19. U32 m_id; ///< An ID
  20. anki::Thread m_thread; ///< Runs the workingFunc
  21. ThreadHive* m_hive;
  22. /// Constructor
  23. Thread(U32 id, ThreadHive* hive, Bool pinToCore)
  24. : m_id(id)
  25. , m_thread("anki_threadhive")
  26. , m_hive(hive)
  27. {
  28. ANKI_ASSERT(hive);
  29. m_thread.start(this, threadCallback, ThreadCoreAffinityMask(false).set(m_id, pinToCore));
  30. }
  31. private:
  32. /// Thread callaback
  33. static Error threadCallback(anki::ThreadCallbackInfo& info)
  34. {
  35. Thread& self = *static_cast<Thread*>(info.m_userData);
  36. self.m_hive->threadRun(self.m_id);
  37. return Error::NONE;
  38. }
  39. };
  40. class ThreadHive::Task
  41. {
  42. public:
  43. Task* m_next; ///< Next in the list.
  44. ThreadHiveTaskCallback m_cb; ///< Callback that defines the task.
  45. void* m_arg; ///< Args for the callback.
  46. ThreadHiveSemaphore* m_waitSemaphore;
  47. ThreadHiveSemaphore* m_signalSemaphore;
  48. };
  49. ThreadHive::ThreadHive(U32 threadCount, GenericMemoryPoolAllocator<U8> alloc, Bool pinToCores)
  50. : m_slowAlloc(alloc)
  51. , m_alloc(alloc.getMemoryPool().getAllocationCallback(), alloc.getMemoryPool().getAllocationCallbackUserData(),
  52. 1024 * 4)
  53. , m_threadCount(threadCount)
  54. {
  55. m_threads = reinterpret_cast<Thread*>(m_slowAlloc.allocate(sizeof(Thread) * threadCount));
  56. for(U32 i = 0; i < threadCount; ++i)
  57. {
  58. ::new(&m_threads[i]) Thread(i, this, pinToCores);
  59. }
  60. }
  61. ThreadHive::~ThreadHive()
  62. {
  63. if(m_threads)
  64. {
  65. {
  66. LockGuard<Mutex> lock(m_mtx);
  67. m_quit = true;
  68. // Wake the threads
  69. m_cvar.notifyAll();
  70. }
  71. // Join and destroy
  72. U32 threadCount = m_threadCount;
  73. while(threadCount-- != 0)
  74. {
  75. Error err = m_threads[threadCount].m_thread.join();
  76. (void)err;
  77. m_threads[threadCount].~Thread();
  78. }
  79. m_slowAlloc.deallocate(static_cast<void*>(m_threads), m_threadCount * sizeof(Thread));
  80. }
  81. }
  82. void ThreadHive::submitTasks(ThreadHiveTask* tasks, const U32 taskCount)
  83. {
  84. ANKI_ASSERT(tasks && taskCount > 0);
  85. // Allocate tasks
  86. Task* const htasks = m_alloc.newArray<Task>(taskCount);
  87. // Initialize tasks
  88. Task* prevTask = nullptr;
  89. for(U32 i = 0; i < taskCount; ++i)
  90. {
  91. const ThreadHiveTask& inTask = tasks[i];
  92. Task& outTask = htasks[i];
  93. outTask.m_next = nullptr;
  94. outTask.m_cb = inTask.m_callback;
  95. outTask.m_arg = inTask.m_argument;
  96. outTask.m_waitSemaphore = inTask.m_waitSemaphore;
  97. outTask.m_signalSemaphore = inTask.m_signalSemaphore;
  98. // Connect tasks
  99. if(prevTask)
  100. {
  101. prevTask->m_next = &outTask;
  102. }
  103. prevTask = &outTask;
  104. }
  105. // Push work
  106. {
  107. LockGuard<Mutex> lock(m_mtx);
  108. if(m_head != nullptr)
  109. {
  110. ANKI_ASSERT(m_tail && m_head);
  111. m_tail->m_next = &htasks[0];
  112. m_tail = &htasks[taskCount - 1];
  113. }
  114. else
  115. {
  116. ANKI_ASSERT(m_tail == nullptr);
  117. m_head = &htasks[0];
  118. m_tail = &htasks[taskCount - 1];
  119. }
  120. m_pendingTasks += taskCount;
  121. ANKI_HIVE_DEBUG_PRINT("submit tasks\n");
  122. }
  123. // Notify all threads
  124. m_cvar.notifyAll();
  125. }
  126. void ThreadHive::threadRun(U32 threadId)
  127. {
  128. Task* task = nullptr;
  129. while(!waitForWork(threadId, task))
  130. {
  131. // Run the task
  132. ANKI_ASSERT(task && task->m_cb);
  133. ANKI_HIVE_DEBUG_PRINT("tid: %lu will exec %p (udata: %p)\n", threadId, static_cast<void*>(task),
  134. static_cast<void*>(task->m_arg));
  135. task->m_cb(task->m_arg, threadId, *this, task->m_signalSemaphore);
  136. #if ANKI_EXTRA_CHECKS
  137. task->m_cb = nullptr;
  138. #endif
  139. // Signal the semaphore as early as possible
  140. if(task->m_signalSemaphore)
  141. {
  142. const U32 out = task->m_signalSemaphore->m_atomic.fetchSub(1);
  143. (void)out;
  144. ANKI_ASSERT(out > 0u);
  145. ANKI_HIVE_DEBUG_PRINT("\tsem is %u\n", out - 1u);
  146. }
  147. }
  148. ANKI_HIVE_DEBUG_PRINT("tid: %lu thread quits!\n", threadId);
  149. }
  150. Bool ThreadHive::waitForWork(U32 threadId, Task*& task)
  151. {
  152. LockGuard<Mutex> lock(m_mtx);
  153. ANKI_HIVE_DEBUG_PRINT("tid: %lu locking\n", threadId);
  154. // Complete the previous task
  155. if(task)
  156. {
  157. --m_pendingTasks;
  158. if(task->m_signalSemaphore || m_pendingTasks == 0)
  159. {
  160. // A dependency maybe got resolved or we are out of tasks. Wake them all
  161. ANKI_HIVE_DEBUG_PRINT("tid: %lu wake all\n", threadId);
  162. m_cvar.notifyAll();
  163. }
  164. }
  165. while(!m_quit && (task = getNewTask()) == nullptr)
  166. {
  167. ANKI_HIVE_DEBUG_PRINT("tid: %lu waiting\n", threadId);
  168. // Wait if there is no work.
  169. m_cvar.wait(m_mtx);
  170. }
  171. return m_quit;
  172. }
  173. ThreadHive::Task* ThreadHive::getNewTask()
  174. {
  175. Task* prevTask = nullptr;
  176. Task* task = m_head;
  177. while(task)
  178. {
  179. // Check if there are dependencies
  180. const Bool allDepsCompleted = task->m_waitSemaphore == nullptr || task->m_waitSemaphore->m_atomic.load() == 0;
  181. if(allDepsCompleted)
  182. {
  183. // Found something, pop it
  184. if(prevTask)
  185. {
  186. prevTask->m_next = task->m_next;
  187. }
  188. if(m_head == task)
  189. {
  190. m_head = task->m_next;
  191. }
  192. if(m_tail == task)
  193. {
  194. m_tail = prevTask;
  195. }
  196. #if ANKI_EXTRA_CHECKS
  197. task->m_next = nullptr;
  198. #endif
  199. break;
  200. }
  201. prevTask = task;
  202. task = task->m_next;
  203. }
  204. return task;
  205. }
  206. void ThreadHive::waitAllTasks()
  207. {
  208. ANKI_HIVE_DEBUG_PRINT("mt: waiting all\n");
  209. LockGuard<Mutex> lock(m_mtx);
  210. while(m_pendingTasks > 0)
  211. {
  212. m_cvar.wait(m_mtx);
  213. }
  214. m_head = nullptr;
  215. m_tail = nullptr;
  216. m_alloc.getMemoryPool().reset();
  217. ANKI_HIVE_DEBUG_PRINT("mt: done waiting all\n");
  218. }
  219. } // end namespace anki