Browse Source

Improved `ThreadPool` implementation

gingerBill 4 years ago
parent
commit
25c3fd48f0
2 changed files with 161 additions and 61 deletions
  1. 102 60
      src/thread_pool.cpp
  2. 59 1
      src/threading.cpp

+ 102 - 60
src/thread_pool.cpp

@@ -4,85 +4,127 @@
 typedef WORKER_TASK_PROC(WorkerTaskProc);
 
 struct WorkerTask {
-	WorkerTask *next_task;
+	WorkerTask *    next;
 	WorkerTaskProc *do_work;
-	void *data;
+	void *          data;
 };
 
 struct ThreadPool {
-	std::atomic<isize> outstanding_task_count;
-	WorkerTask *volatile next_task;
-	BlockingMutex task_list_mutex;
-	isize thread_count;
+	gbAllocator   allocator;
+	BlockingMutex mutex;
+	Condition     task_cond;
+	
+	Slice<Thread> threads;
+	
+	WorkerTask *task_queue;
+	
+	std::atomic<isize> ready;
 };
 
-void thread_pool_thread_entry(ThreadPool *pool) {
-	while (pool->outstanding_task_count) {
-		if (!pool->next_task) {
-			yield(); // No need to grab the mutex.
-		} else {
-			mutex_lock(&pool->task_list_mutex);
-
-			if (pool->next_task) {
-				WorkerTask *task = pool->next_task;
-				pool->next_task = task->next_task;
-				mutex_unlock(&pool->task_list_mutex);
-				task->do_work(task->data);
-				pool->outstanding_task_count.fetch_sub(1);
-				gb_free(heap_allocator(), task);
-			} else {
-				mutex_unlock(&pool->task_list_mutex);
-			}
-		}
+void thread_pool_init(ThreadPool *pool, gbAllocator const &a, isize thread_count, char const *worker_name) {
+	pool->allocator = a;
+	mutex_init(&pool->mutex);
+	condition_init(&pool->task_cond);
+	
+	slice_init(&pool->threads, a, thread_count);
+	for_array(i, pool->threads) {
+		Thread *t = &pool->threads[i];
+		thread_init(t);
 	}
 }
 
-#if defined(GB_SYSTEM_WINDOWS)
-	DWORD __stdcall thread_pool_thread_entry_platform(void *arg) {
-		thread_pool_thread_entry((ThreadPool *) arg);
-		return 0;
-	}
+void thread_pool_destroy(ThreadPool *pool) {
+	condition_broadcast(&pool->task_cond);
 
-	void thread_pool_start_thread(ThreadPool *pool) {
-		CloseHandle(CreateThread(NULL, 0, thread_pool_thread_entry_platform, pool, 0, NULL));
+	for_array(i, pool->threads) {
+		Thread *t = &pool->threads[i];
+		thread_join(t);
 	}
-#else
-	void *thread_pool_thread_entry_platform(void *arg) {
-		thread_pool_thread_entry((ThreadPool *) arg);
-		return NULL;
+	
+	
+	for_array(i, pool->threads) {
+		Thread *t = &pool->threads[i];
+		thread_destroy(t);
 	}
+	
+	gb_free(pool->allocator, pool->threads.data);
+	condition_destroy(&pool->task_cond);
+	mutex_destroy(&pool->mutex);
+}
 
-	void thread_pool_start_thread(ThreadPool *pool) {
-		pthread_t handle;
-		pthread_create(&handle, NULL, thread_pool_thread_entry_platform, pool);
-		pthread_detach(handle);
+bool thread_pool_queue_empty(ThreadPool *pool) {
+	return pool->task_queue == nullptr;
+}
+WorkerTask *thread_pool_queue_pop(ThreadPool *pool) {
+	GB_ASSERT(pool->task_queue != nullptr);
+	WorkerTask *task = pool->task_queue;
+	pool->task_queue = task->next;
+	return task;
+}
+void thread_pool_queue_push(ThreadPool *pool, WorkerTask *task) {
+	GB_ASSERT(task != nullptr);
+	task->next = pool->task_queue;
+	pool->task_queue = task;
+}
+
+bool thread_pool_add_task(ThreadPool *pool, WorkerTaskProc *proc, void *data) {
+	GB_ASSERT(proc != nullptr);
+	mutex_lock(&pool->mutex);
+	WorkerTask *task = gb_alloc_item(permanent_allocator(), WorkerTask);
+	if (task == nullptr) {
+		mutex_unlock(&pool->mutex);
+		GB_PANIC("Out of memory");
+		return false;
 	}
-#endif
+	task->do_work = proc;
+	task->data = data;
+		
+	thread_pool_queue_push(pool, task);
+	pool->ready++;
+	mutex_unlock(&pool->mutex);
+	condition_signal(&pool->task_cond);
+	return true;
+}	
 
-void thread_pool_init(ThreadPool *pool, gbAllocator const &a, isize thread_count, char const *worker_prefix) {
-	memset(pool, 0, sizeof(ThreadPool));
-	mutex_init(&pool->task_list_mutex);
-	pool->thread_count = thread_count;
+THREAD_PROC(thread_pool_thread_proc) {
+	ThreadPool *pool = cast(ThreadPool *)thread->user_data;
+	
+	for (;;) {
+		mutex_lock(&pool->mutex);
+		
+		while (pool->ready > 0 && thread_pool_queue_empty(pool)) {
+			condition_wait(&pool->task_cond, &pool->mutex);
+		}
+		if (pool->ready == 0 && thread_pool_queue_empty(pool)) {
+			mutex_unlock(&pool->mutex);
+			return 0;
+		}
+		
+		WorkerTask *task = thread_pool_queue_pop(pool);
+		mutex_unlock(&pool->mutex);
+	
+		task->do_work(task->data);
+		if (--pool->ready == 0) {
+			condition_broadcast(&pool->task_cond);
+		}
+	}
 }
 
-void thread_pool_destroy(ThreadPool *pool) {
-	mutex_destroy(&pool->task_list_mutex);
-}
 
 void thread_pool_wait(ThreadPool *pool) {
-	for (int i = 0; i < pool->thread_count; i++) {
-		thread_pool_start_thread(pool);
+	for_array(i, pool->threads) {
+		Thread *t = &pool->threads[i];
+		thread_start(t, thread_pool_thread_proc, pool);
+	}
+	
+	Thread dummy = {};
+	dummy.proc = thread_pool_thread_proc;
+	dummy.user_data = pool;
+	thread_pool_thread_proc(&dummy);
+	
+	for_array(i, pool->threads) {
+		Thread *t = &pool->threads[i];
+		thread_join(t);
 	}
-	thread_pool_thread_entry(pool);
 }
 
-void thread_pool_add_task(ThreadPool *pool, WorkerTaskProc *proc, void *data) {
-	WorkerTask *task = gb_alloc_item(heap_allocator(), WorkerTask);
-	task->do_work = proc;
-	task->data = data;
-	mutex_lock(&pool->task_list_mutex);
-	task->next_task = pool->next_task;
-	pool->next_task = task;
-	pool->outstanding_task_count.fetch_add(1);
-	mutex_unlock(&pool->task_list_mutex);
-}

+ 59 - 1
src/threading.cpp

@@ -1,6 +1,7 @@
 struct BlockingMutex;
 struct RecursiveMutex;
 struct Semaphore;
+struct Condition;
 struct Thread;
 
 #define THREAD_PROC(name) isize name(struct Thread *thread)
@@ -41,6 +42,14 @@ void semaphore_post   (Semaphore *s, i32 count);
 void semaphore_wait   (Semaphore *s);
 void semaphore_release(Semaphore *s) { semaphore_post(s, 1); }
 
+
+void condition_init(Condition *c);
+void condition_destroy(Condition *c);
+void condition_broadcast(Condition *c);
+void condition_signal(Condition *c);
+void condition_wait(Condition *c, BlockingMutex *m);
+void condition_wait_with_timeout(Condition *c, BlockingMutex *m, u32 timeout_in_ms);
+
 u32  thread_current_id(void);
 
 void thread_init            (Thread *t);
@@ -108,6 +117,27 @@ void yield_process(void);
 	void semaphore_wait(Semaphore *s) {
 		WaitForSingleObjectEx(s->win32_handle, INFINITE, FALSE);
 	}
+	
+	struct Condition {
+		CONDITION_VARIABLE cond;
+	};
+	
+	void condition_init(Condition *c) {
+	}
+	void condition_destroy(Condition *c) {	
+	}
+	void condition_broadcast(Condition *c) {
+		WakeAllConditionVariable(&c->cond);
+	}
+	void condition_signal(Condition *c) {
+		WakeConditionVariable(&c->cond);
+	}
+	void condition_wait(Condition *c, BlockingMutex *m) {
+		SleepConditionVariableSRW(&c->cond, &m->srwlock, INFINITE, 0);
+	}
+	void condition_wait_with_timeout(Condition *c, BlockingMutex *m, u32 timeout_in_ms) {
+		SleepConditionVariableSRW(&c->cond, &m->srwlock, timeout_in_ms, 0);
+	}
 
 #else
 	struct BlockingMutex {
@@ -170,8 +200,36 @@ void yield_process(void);
 		void semaphore_post   (Semaphore *s, i32 count) { while (count --> 0) sem_post(&s->unix_handle); }
 		void semaphore_wait   (Semaphore *s)            { int i; do { i = sem_wait(&s->unix_handle); } while (i == -1 && errno == EINTR); }
 	#else
-	#error
+	#error Implement Semaphore for this platform
 	#endif
+		
+		
+	struct Condition {
+		pthread_cond_t pthread_cond;
+	};
+	
+	void condition_init(Condition *c) {
+		pthread_cond_init(&c->pthread_cond, NULL);
+	}
+	void condition_destroy(Condition *c) {	
+		pthread_cond_destroy(&c->pthread_cond);
+	}
+	void condition_broadcast(Condition *c) {
+		pthread_cond_broadcast(&c->pthread_cond);
+	}
+	void condition_signal(Condition *c) {
+		pthread_cond_signal(&c->pthread_cond);
+	}
+	void condition_wait(Condition *c, BlockingMutex *m) {
+		pthread_cond_wait(&c->pthread_cond, &m->pthread_mutex);
+	}
+	void condition_wait_with_timeout(Condition *c, BlockingMutex *m, u32 timeout_in_ms) {
+		struct timespec abstime = {};
+		timespec.tv_sec = timeout_in_ms/1000;
+		timespec.tv_nsec = cast(long)(timeout_in_ms%1000)*1e6;
+		pthread_cond_timedwait(&c->pthread_cond, &m->pthread_mutex, &abstime);
+		
+	}
 #endif