Browse Source

handle steal-fail vs steal-empty

Colin Davidson 1 year ago
parent
commit
4420128dc1
1 changed files with 24 additions and 12 deletions
  1. 24 12
      src/thread_pool.cpp

+ 24 - 12
src/thread_pool.cpp

@@ -10,6 +10,12 @@ gb_internal void thread_pool_destroy(ThreadPool *pool);
 gb_internal bool thread_pool_add_task(ThreadPool *pool, WorkerTaskProc *proc, void *data);
 gb_internal bool thread_pool_add_task(ThreadPool *pool, WorkerTaskProc *proc, void *data);
 gb_internal void thread_pool_wait(ThreadPool *pool);
 gb_internal void thread_pool_wait(ThreadPool *pool);
 
 
+enum GrabState {
+	GrabSuccess = 0,
+	GrabEmpty   = 1,
+	GrabFailed  = 2,
+};
+
 struct ThreadPool {
 struct ThreadPool {
 	gbAllocator threads_allocator;
 	gbAllocator threads_allocator;
 	Slice<Thread> threads;
 	Slice<Thread> threads;
@@ -82,7 +88,7 @@ void thread_pool_queue_push(Thread *thread, WorkerTask task) {
 	futex_broadcast(&thread->pool->tasks_available);
 	futex_broadcast(&thread->pool->tasks_available);
 }
 }
 
 
-bool thread_pool_queue_take(Thread *thread, WorkerTask *task) {
+GrabState thread_pool_queue_take(Thread *thread, WorkerTask *task) {
 	ssize_t bot = thread->queue.bottom.load(std::memory_order_relaxed) - 1;
 	ssize_t bot = thread->queue.bottom.load(std::memory_order_relaxed) - 1;
 	TaskRingBuffer *cur_ring = thread->queue.ring.load(std::memory_order_relaxed);
 	TaskRingBuffer *cur_ring = thread->queue.ring.load(std::memory_order_relaxed);
 	thread->queue.bottom.store(bot, std::memory_order_relaxed);
 	thread->queue.bottom.store(bot, std::memory_order_relaxed);
@@ -98,28 +104,28 @@ bool thread_pool_queue_take(Thread *thread, WorkerTask *task) {
 			if (!thread->queue.top.compare_exchange_strong(top, top + 1, std::memory_order_seq_cst, std::memory_order_relaxed)) {
 			if (!thread->queue.top.compare_exchange_strong(top, top + 1, std::memory_order_seq_cst, std::memory_order_relaxed)) {
 				// Race failed
 				// Race failed
 				thread->queue.bottom.store(bot + 1, std::memory_order_relaxed);
 				thread->queue.bottom.store(bot + 1, std::memory_order_relaxed);
-				return false;
+				return GrabEmpty;
 			}
 			}
 
 
 			thread->queue.bottom.store(bot + 1, std::memory_order_relaxed);
 			thread->queue.bottom.store(bot + 1, std::memory_order_relaxed);
-			return true;
+			return GrabSuccess;
 		}
 		}
 
 
 		// We got a task without hitting a race
 		// We got a task without hitting a race
-		return true;
+		return GrabSuccess;
 	} else {
 	} else {
 		// Queue is empty
 		// Queue is empty
 		thread->queue.bottom.store(bot + 1, std::memory_order_relaxed);
 		thread->queue.bottom.store(bot + 1, std::memory_order_relaxed);
-		return false;
+		return GrabEmpty;
 	}
 	}
 }
 }
 
 
-bool thread_pool_queue_steal(Thread *thread, WorkerTask *task) {
+GrabState thread_pool_queue_steal(Thread *thread, WorkerTask *task) {
 	ssize_t top = thread->queue.top.load(std::memory_order_acquire);
 	ssize_t top = thread->queue.top.load(std::memory_order_acquire);
 	std::atomic_thread_fence(std::memory_order_seq_cst);
 	std::atomic_thread_fence(std::memory_order_seq_cst);
 	ssize_t bot = thread->queue.bottom.load(std::memory_order_acquire);
 	ssize_t bot = thread->queue.bottom.load(std::memory_order_acquire);
 
 
-	bool ret = false;
+	GrabState ret = GrabEmpty;
 	if (top < bot) {
 	if (top < bot) {
 		// Queue is not empty
 		// Queue is not empty
 		TaskRingBuffer *cur_ring = thread->queue.ring.load(std::memory_order_consume);
 		TaskRingBuffer *cur_ring = thread->queue.ring.load(std::memory_order_consume);
@@ -127,9 +133,9 @@ bool thread_pool_queue_steal(Thread *thread, WorkerTask *task) {
 
 
 		if (!thread->queue.top.compare_exchange_strong(top, top + 1, std::memory_order_seq_cst, std::memory_order_relaxed)) {
 		if (!thread->queue.top.compare_exchange_strong(top, top + 1, std::memory_order_seq_cst, std::memory_order_relaxed)) {
 			// Race failed
 			// Race failed
-			ret = false;
+			ret = GrabFailed;
 		} else {
 		} else {
-			ret = true;
+			ret = GrabSuccess;
 		}
 		}
 	}
 	}
 	return ret;
 	return ret;
@@ -149,7 +155,7 @@ gb_internal void thread_pool_wait(ThreadPool *pool) {
 
 
 	while (pool->tasks_left.load(std::memory_order_acquire)) {
 	while (pool->tasks_left.load(std::memory_order_acquire)) {
 		// if we've got tasks on our queue, run them
 		// if we've got tasks on our queue, run them
-		while (thread_pool_queue_take(current_thread, &task)) {
+		while (!thread_pool_queue_take(current_thread, &task)) {
 			task.do_work(task.data);
 			task.do_work(task.data);
 			pool->tasks_left.fetch_sub(1, std::memory_order_release);
 			pool->tasks_left.fetch_sub(1, std::memory_order_release);
 		}
 		}
@@ -178,7 +184,7 @@ gb_internal THREAD_PROC(thread_pool_thread_proc) {
 		usize finished_tasks = 0;
 		usize finished_tasks = 0;
 		i32 state;
 		i32 state;
 
 
-		while (thread_pool_queue_take(current_thread, &task)) {
+		while (!thread_pool_queue_take(current_thread, &task)) {
 			task.do_work(task.data);
 			task.do_work(task.data);
 			pool->tasks_left.fetch_sub(1, std::memory_order_release);
 			pool->tasks_left.fetch_sub(1, std::memory_order_release);
 
 
@@ -200,7 +206,13 @@ gb_internal THREAD_PROC(thread_pool_thread_proc) {
 
 
 				Thread *thread = &pool->threads.data[idx];
 				Thread *thread = &pool->threads.data[idx];
 				WorkerTask task;
 				WorkerTask task;
-				if (thread_pool_queue_steal(thread, &task)) {
+
+				GrabState ret = thread_pool_queue_steal(thread, &task);
+				if (ret == GrabFailed) {
+					goto main_loop_continue;
+				} else if (ret == GrabEmpty) {
+					continue;
+				} else if (ret == GrabSuccess) {
 					task.do_work(task.data);
 					task.do_work(task.data);
 					pool->tasks_left.fetch_sub(1, std::memory_order_release);
 					pool->tasks_left.fetch_sub(1, std::memory_order_release);