// taskschedulerinternal.h

// Copyright 2009-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "../../include/embree4/rtcore.h"
#include "../sys/platform.h"
#include "../sys/alloc.h"
#include "../sys/barrier.h"
#include "../sys/thread.h"
#include "../sys/mutex.h"
#include "../sys/condition.h"
#include "../sys/ref.h"
#include "../sys/atomic.h"
#include "../math/range.h"

#include <list>

namespace embree
{
  /* The tasking system exports some symbols to be used by the tutorials. Thus we
     hide it also in the API namespace when requested. */
  RTC_NAMESPACE_BEGIN

  struct TaskScheduler : public RefCount
  {
    ALIGNED_STRUCT_(64);
    friend class Device;

    static const size_t TASK_STACK_SIZE = 4*1024;       //!< task structure stack
    static const size_t CLOSURE_STACK_SIZE = 512*1024;  //!< stack for task closures

    struct Thread;

    /*! virtual interface for all tasks */
    struct TaskFunction {
      virtual void execute() = 0;
    };

    struct TaskGroupContext {
      TaskGroupContext() : cancellingException(nullptr) {}

      std::exception_ptr cancellingException;
    };

    /*! builds a task interface from a closure */
    template<typename Closure>
    struct ClosureTaskFunction : public TaskFunction
    {
      Closure closure;
      __forceinline ClosureTaskFunction (const Closure& closure) : closure(closure) {}
      void execute() { closure(); }
    };
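
    /* Illustrative sketch (not part of the original header): a lambda is
       type-erased into a TaskFunction by placement-constructing a
       ClosureTaskFunction in the closure stack of a TaskQueue; push_right()
       below does exactly this:

         auto body = [] { work(); };                 // 'work' is hypothetical
         void* mem = queue.alloc(sizeof(ClosureTaskFunction<decltype(body)>));
         TaskFunction* fn = new (mem) ClosureTaskFunction<decltype(body)>(body);
         fn->execute();                              // invokes body()
    */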
    struct __aligned(64) Task
    {
      /*! states a task can be in */
      enum { DONE, INITIALIZED };

      /*! switch from one state to another */
      __forceinline void switch_state(int from, int to)
      {
        __memory_barrier();
        MAYBE_UNUSED bool success = state.compare_exchange_strong(from,to);
        assert(success);
      }

      /*! try to switch from one state to another */
      __forceinline bool try_switch_state(int from, int to) {
        __memory_barrier();
        return state.compare_exchange_strong(from,to);
      }
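
      /* State protocol: a freshly constructed Task starts in DONE; spawning
         switches it DONE -> INITIALIZED, and a stealing thread claims it by
         winning the CAS INITIALIZED -> DONE in try_steal() below, so at most
         one thread ever takes ownership of a given task. */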
      /*! increment/decrement dependency counter */
      void add_dependencies(int n) {
        dependencies+=n;
      }

      /*! initialize all tasks to DONE state by default */
      __forceinline Task()
        : state(DONE) {}

      /*! construction of new task */
      __forceinline Task (TaskFunction* closure, Task* parent, TaskGroupContext* context, size_t stackPtr, size_t N)
        : dependencies(1), stealable(true), closure(closure), parent(parent), context(context), stackPtr(stackPtr), N(N)
      {
        if (parent) parent->add_dependencies(+1);
        switch_state(DONE,INITIALIZED);
      }

      /*! construction of stolen task, stealing thread will decrement initial dependency */
      __forceinline Task (TaskFunction* closure, Task* parent, TaskGroupContext* context)
        : dependencies(1), stealable(false), closure(closure), parent(parent), context(context), stackPtr(-1), N(1)
      {
        switch_state(DONE,INITIALIZED);
      }

      /*! try to steal this task */
      bool try_steal(Task& child)
      {
        if (!stealable) return false;
        if (!try_switch_state(INITIALIZED,DONE)) return false;
        new (&child) Task(closure, this, context);
        return true;
      }

      /*! run this task */
      dll_export void run(Thread& thread);

      void run_internal(Thread& thread);

    public:
      std::atomic<int> state;         //!< state this task is in
      std::atomic<int> dependencies;  //!< dependencies to wait for
      std::atomic<bool> stealable;    //!< true if task can be stolen
      TaskFunction* closure;          //!< the closure to execute
      Task* parent;                   //!< parent task to signal when we are finished
      TaskGroupContext* context;
      size_t stackPtr;                //!< stack location where closure is stored
      size_t N;                       //!< approximate size of the task
    };
    struct TaskQueue
    {
      TaskQueue ()
        : left(0), right(0), stackPtr(0) {}

      __forceinline void* alloc(size_t bytes, size_t align = 64)
      {
        size_t ofs = bytes + ((align - stackPtr) & (align-1));
        //if (stackPtr + ofs > CLOSURE_STACK_SIZE)
        //  throw std::runtime_error("closure stack overflow");
        if (stackPtr + ofs > CLOSURE_STACK_SIZE) {
          abort();
        }
        stackPtr += ofs;
        return &stack[stackPtr-bytes];
      }
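
      /* Worked example (illustrative): with stackPtr == 100 and align == 64,
         (align - stackPtr) & (align-1) == (64-100) & 63 == 28 (size_t arithmetic
         wraps modulo 2^64), so the allocation is padded by 28 bytes and the
         returned pointer lands at offset 128, the next 64-byte boundary. */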
      template<typename Closure>
      __forceinline void push_right(Thread& thread, const size_t size, const Closure& closure, TaskGroupContext* context)
      {
        //if (right >= TASK_STACK_SIZE)
        //  throw std::runtime_error("task stack overflow");
        if (right >= TASK_STACK_SIZE) {
          abort();
        }

        /* allocate new task on right side of stack */
        size_t oldStackPtr = stackPtr;
        TaskFunction* func = new (alloc(sizeof(ClosureTaskFunction<Closure>))) ClosureTaskFunction<Closure>(closure);
        new (&tasks[right.load()]) Task(func,thread.task,context,oldStackPtr,size);
        right++;

        /* also move left pointer */
        if (left >= right-1) left = right-1;
      }
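
      /* Design note: this queue behaves as a work-stealing deque. The owning
         thread pushes and executes at the right end (LIFO, cache-friendly),
         while idle threads steal from the left end, which in a divide-and-conquer
         workload tends to hand them the largest remaining chunks of work. */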
      dll_export bool execute_local(Thread& thread, Task* parent);
      bool execute_local_internal(Thread& thread, Task* parent);
      bool steal(Thread& thread);
      size_t getTaskSizeAtLeft();

      bool empty() { return right == 0; }

    public:
      /* task stack */
      Task tasks[TASK_STACK_SIZE];
      __aligned(64) std::atomic<size_t> left;   //!< threads steal from left
      __aligned(64) std::atomic<size_t> right;  //!< new tasks are added to the right

      /* closure stack */
      __aligned(64) char stack[CLOSURE_STACK_SIZE];
      size_t stackPtr;
    };
    /*! thread local structure for each thread */
    struct Thread
    {
      ALIGNED_STRUCT_(64);

      Thread (size_t threadIndex, const Ref<TaskScheduler>& scheduler)
        : threadIndex(threadIndex), task(nullptr), scheduler(scheduler) {}

      __forceinline size_t threadCount() {
        return scheduler->threadCounter;
      }

      size_t threadIndex;            //!< ID of this thread
      TaskQueue tasks;               //!< local task queue
      Task* task;                    //!< current active task
      Ref<TaskScheduler> scheduler;  //!< pointer to task scheduler
    };
    /*! pool of worker threads */
    struct ThreadPool
    {
      ThreadPool (bool set_affinity);
      ~ThreadPool ();

      /*! starts the threads */
      dll_export void startThreads();

      /*! sets number of threads to use */
      void setNumThreads(size_t numThreads, bool startThreads = false);

      /*! adds a task scheduler object for scheduling */
      dll_export void add(const Ref<TaskScheduler>& scheduler);

      /*! removes the task scheduler object again */
      dll_export void remove(const Ref<TaskScheduler>& scheduler);

      /*! returns number of threads of the thread pool */
      size_t size() const { return numThreads; }

      /*! main loop for all threads */
      void thread_loop(size_t threadIndex);

    private:
      std::atomic<size_t> numThreads;
      std::atomic<size_t> numThreadsRunning;
      bool set_affinity;
      std::atomic<bool> running;
      std::vector<thread_t> threads;

    private:
      MutexSys mutex;
      ConditionSys condition;
      std::list<Ref<TaskScheduler> > schedulers;
    };
    TaskScheduler ();
    ~TaskScheduler ();

    /*! initializes the task scheduler */
    static void create(size_t numThreads, bool set_affinity, bool start_threads);

    /*! destroys the task scheduler again */
    static void destroy();

    /*! lets new worker threads join the tasking system */
    void join();
    void reset();

    /*! lets a worker thread allocate a thread index */
    dll_export ssize_t allocThreadIndex();

    /*! waits until some number of threads is available (threadCount includes the main thread) */
    void wait_for_threads(size_t threadCount);

    /*! thread loop for all worker threads */
    void thread_loop(size_t threadIndex);

    /*! steals a task from a different thread */
    bool steal_from_other_threads(Thread& thread);

    template<typename Predicate, typename Body>
    static void steal_loop(Thread& thread, const Predicate& pred, const Body& body);

    /* spawn a new task at the top of the thread's task stack */
    template<typename Closure>
    void spawn_root(const Closure& closure, TaskGroupContext* context, size_t size = 1, bool useThreadPool = true)
    {
      if (useThreadPool) startThreads();

      size_t threadIndex = allocThreadIndex();
      std::unique_ptr<Thread> mthread(new Thread(threadIndex,this)); // too large for stack allocation
      Thread& thread = *mthread;
      assert(threadLocal[threadIndex].load() == nullptr);
      threadLocal[threadIndex] = &thread;
      Thread* oldThread = swapThread(&thread);
      thread.tasks.push_right(thread,size,closure,context);
      {
        Lock<MutexSys> lock(mutex);
        anyTasksRunning++;
        hasRootTask = true;
        condition.notify_all();
      }

      if (useThreadPool) addScheduler(this);

      while (thread.tasks.execute_local(thread,nullptr));
      anyTasksRunning--;
      if (useThreadPool) removeScheduler(this);

      threadLocal[threadIndex] = nullptr;
      swapThread(oldThread);

      /* remember exception to throw */
      std::exception_ptr except = nullptr;
      if (context->cancellingException != nullptr) except = context->cancellingException;

      /* wait for all threads to terminate */
      threadCounter--;
      while (threadCounter > 0) yield();
      context->cancellingException = nullptr;

      /* re-throw proper exception */
      if (except != nullptr) {
        std::rethrow_exception(except);
      }
    }
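
    /* Illustrative sketch (not part of the original header): spawn_root() is the
       path spawn() below takes when the caller is not yet a scheduler thread.
       The calling thread becomes a temporary worker and only returns once the
       whole task tree has finished:

         TaskGroupContext context;
         instance()->spawn_root([] { rootWork(); },&context); // rootWork() is hypothetical
    */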
    /* spawn a new task at the top of the thread's task stack */
    template<typename Closure>
    static __forceinline void spawn(size_t size, const Closure& closure, TaskGroupContext* context)
    {
      Thread* thread = TaskScheduler::thread();
      if (likely(thread != nullptr)) thread->tasks.push_right(*thread,size,closure,context);
      else instance()->spawn_root(closure,context,size);
    }

    /* spawn a new task at the top of the thread's task stack */
    template<typename Closure>
    static __forceinline void spawn(const Closure& closure, TaskGroupContext* taskGroupContext) {
      spawn(1,closure,taskGroupContext);
    }

    /* spawn a new task set */
    template<typename Index, typename Closure>
    static void spawn(const Index begin, const Index end, const Index blockSize, const Closure& closure, TaskGroupContext* context)
    {
      spawn(end-begin, [=]()
      {
        if (end-begin <= blockSize) {
          return closure(range<Index>(begin,end));
        }
        const Index center = (begin+end)/2;
        spawn(begin,center,blockSize,closure,context);
        spawn(center,end  ,blockSize,closure,context);
        wait();
      },context);
    }
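
    /* Usage sketch (illustrative; 'data' and 'f' are hypothetical): the task-set
       spawn recursively halves [begin,end) until each piece is at most blockSize
       elements, so a range of 1024 with blockSize 64 bottoms out in 16 leaf tasks:

         TaskGroupContext context;
         TaskScheduler::spawn(0,1024,64,[&](const range<int>& r) {
           for (int i=r.begin(); i<r.end(); i++) data[i] = f(i);
         },&context);
         TaskScheduler::wait(); // help execute subtasks until all have finished
    */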
    /* work on spawned subtasks and wait until all have finished */
    dll_export static void wait();

    /* returns the ID of the current thread */
    dll_export static size_t threadID();

    /* returns the index (0..threadCount-1) of the current thread */
    dll_export static size_t threadIndex();

    /* returns the total number of threads */
    dll_export static size_t threadCount();

  private:

    /* returns the thread local task list of this worker thread */
    dll_export static Thread* thread();

    /* sets the thread local task list of this worker thread */
    dll_export static Thread* swapThread(Thread* thread);

    /*! returns the taskscheduler object to be used by the master thread */
    dll_export static TaskScheduler* instance();

    /*! starts the threads */
    dll_export static void startThreads();

    /*! adds a task scheduler object for scheduling */
    dll_export static void addScheduler(const Ref<TaskScheduler>& scheduler);

    /*! removes the task scheduler object again */
    dll_export static void removeScheduler(const Ref<TaskScheduler>& scheduler);

  private:
    std::vector<atomic<Thread*>> threadLocal;
    std::atomic<size_t> threadCounter;
    std::atomic<size_t> anyTasksRunning;
    std::atomic<bool> hasRootTask;
    MutexSys mutex;
    ConditionSys condition;

  private:
    static size_t g_numThreads;
    static __thread TaskScheduler* g_instance;
    static __thread Thread* thread_local_thread;
    static ThreadPool* threadPool;
  };
  RTC_NAMESPACE_END

#if defined(RTC_NAMESPACE)
  using RTC_NAMESPACE::TaskScheduler;
#endif
}
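
/* Lifecycle sketch (illustrative, not part of the original header): a typical
   client creates the scheduler once, spawns work, and tears it down at exit:

     embree::TaskScheduler::create(8,true,true); // 8 threads, set affinity, start immediately
     // ... TaskScheduler::spawn()/wait() from worker or application threads ...
     embree::TaskScheduler::destroy();
*/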