Browse Source

Add task-stopping functionality to `thread.Pool`

Feoramund 1 year ago
parent
commit
558c330028
1 changed files with 115 additions and 15 deletions
  1. 115 15
      core/thread/thread_pool.odin

+ 115 - 15
core/thread/thread_pool.odin

@@ -44,6 +44,29 @@ Pool :: struct {
 	tasks_done: [dynamic]Task,
 }
 
+Pool_Thread_Data :: struct {
+	pool: ^Pool,
+	task: Task,
+}
+
+@(private="file")
+pool_thread_runner :: proc(t: ^Thread) {
+	data := cast(^Pool_Thread_Data)t.data
+	pool := data.pool
+
+	for intrinsics.atomic_load(&pool.is_running) {
+		sync.wait(&pool.sem_available)
+
+		if task, ok := pool_pop_waiting(pool); ok {
+			data.task = task
+			pool_do_work(pool, task)
+			data.task = {}
+		}
+	}
+
+	sync.post(&pool.sem_available, 1)
+}
+
 // Once initialized, the pool's memory address is not allowed to change until
 // it is destroyed. 
 //
@@ -58,21 +81,11 @@ pool_init :: proc(pool: ^Pool, allocator: mem.Allocator, thread_count: int) {
 	pool.is_running = true
 
 	for _, i in pool.threads {
-		t := create(proc(t: ^Thread) {
-			pool := (^Pool)(t.data)
-
-			for intrinsics.atomic_load(&pool.is_running) {
-				sync.wait(&pool.sem_available)
-
-				if task, ok := pool_pop_waiting(pool); ok {
-					pool_do_work(pool, task)
-				}
-			}
-
-			sync.post(&pool.sem_available, 1)
-		})
+		t := create(pool_thread_runner)
+		data := new(Pool_Thread_Data)
+		data.pool = pool
 		t.user_index = i
-		t.data = pool
+		t.data = data
 		pool.threads[i] = t
 	}
 }
@@ -82,6 +95,8 @@ pool_destroy :: proc(pool: ^Pool) {
 	delete(pool.tasks_done)
 
 	for &t in pool.threads {
+		data := cast(^Pool_Thread_Data)t.data
+		free(data, pool.allocator)
 		destroy(t)
 	}
 
@@ -103,7 +118,7 @@ pool_join :: proc(pool: ^Pool) {
 
 	yield()
 
-started_count: int
+	started_count: int
 	for started_count < len(pool.threads) {
 		started_count = 0
 		for t in pool.threads {
@@ -138,6 +153,91 @@ pool_add_task :: proc(pool: ^Pool, allocator: mem.Allocator, procedure: Task_Pro
 	sync.post(&pool.sem_available, 1)
 }
 
+// Forcibly stop a running task by its user index.
+//
+// This will terminate the underlying thread. Ideally, you should use some
+// means of communication to stop a task, as thread termination may leave
+// resources unclaimed.
+//
+// The thread will be restarted to accept new tasks.
+//
+// Returns true if the task was found and terminated.
+pool_stop_task :: proc(pool: ^Pool, user_index: int, exit_code: int = 1) -> bool {
+	sync.guard(&pool.mutex)
+
+	for t, i in pool.threads {
+		data := cast(^Pool_Thread_Data)t.data
+		if data.task.user_index == user_index && data.task.procedure != nil {
+			terminate(t, exit_code)
+
+			append(&pool.tasks_done, data.task)
+			intrinsics.atomic_add(&pool.num_done, 1)
+			intrinsics.atomic_sub(&pool.num_outstanding, 1)
+			intrinsics.atomic_sub(&pool.num_in_processing, 1)
+
+			destroy(t)
+
+			replacement := create(pool_thread_runner)
+			replacement.user_index = t.user_index
+			replacement.data = data
+			pool.threads[i] = replacement
+
+			start(replacement)
+			return true
+		}
+	}
+
+	return false
+}
+
+// Forcibly stop all running tasks.
+//
+// The same notes from `pool_stop_task` apply here.
+pool_stop_all_tasks :: proc(pool: ^Pool, exit_code: int = 1) {
+	sync.guard(&pool.mutex)
+
+	for t, i in pool.threads {
+		data := cast(^Pool_Thread_Data)t.data
+		if data.task.procedure != nil {
+			terminate(t, exit_code)
+
+			append(&pool.tasks_done, data.task)
+			intrinsics.atomic_add(&pool.num_done, 1)
+			intrinsics.atomic_sub(&pool.num_outstanding, 1)
+			intrinsics.atomic_sub(&pool.num_in_processing, 1)
+
+			destroy(t)
+
+			replacement := create(pool_thread_runner)
+			replacement.user_index = t.user_index
+			replacement.data = data
+			pool.threads[i] = replacement
+
+			start(replacement)
+		}
+	}
+}
+
+// Force the pool to stop all of its threads and put it into a state where
+// it will no longer run any more tasks.
+//
+// The pool must still be destroyed after this.
+pool_shutdown :: proc(pool: ^Pool, exit_code: int = 1) {
+	sync.guard(&pool.mutex)
+
+	for t in pool.threads {
+		terminate(t, exit_code)
+
+		data := cast(^Pool_Thread_Data)t.data
+		if data.task.procedure != nil {
+			append(&pool.tasks_done, data.task)
+			intrinsics.atomic_add(&pool.num_done, 1)
+			intrinsics.atomic_sub(&pool.num_outstanding, 1)
+			intrinsics.atomic_sub(&pool.num_in_processing, 1)
+		}
+	}
+}
+
 // Number of tasks waiting to be processed. Only informational, mostly for
 // debugging. Don't rely on this value being consistent with other num_*
 // values.