Browse Source

Fix join on *nix.

Jeroen van Rijn 3 years ago
parent
commit
56e3b7cb7d
1 changed files with 14 additions and 5 deletions
  1. 14 5
      core/thread/thread_unix.odin

+ 14 - 5
core/thread/thread_unix.odin

@@ -7,6 +7,8 @@ import "core:intrinsics"
 import "core:sync"
 import "core:sys/unix"
 
+CAS :: intrinsics.atomic_compare_exchange_strong
+
 Thread_State :: enum u8 {
 	Started,
 	Joined,
@@ -98,7 +100,7 @@ _create :: proc(procedure: Thread_Proc, priority := Thread_Priority.Normal) -> ^
 }
 
 _start :: proc(t: ^Thread) {
-	sync.guard(&t.mutex)
+	// sync.guard(&t.mutex)
 	t.flags += { .Started }
 	sync.signal(&t.cond)
 }
@@ -108,15 +110,22 @@ _is_done :: proc(t: ^Thread) -> bool {
 }
 
 _join :: proc(t: ^Thread) {
-	sync.guard(&t.mutex)
+	// sync.guard(&t.mutex)
 
-	if .Joined in t.flags || unix.pthread_equal(unix.pthread_self(), t.unix_thread) {
+	if unix.pthread_equal(unix.pthread_self(), t.unix_thread) {
 		return
 	}
 
-	unix.pthread_join(t.unix_thread, nil)
+	// Preserve other flags besides `.Joined`, like `.Started`.
+	unjoined := intrinsics.atomic_load(&t.flags) - {.Joined}
+	joined   := unjoined + {.Joined}
 
-	t.flags += { .Joined }
+	// Try to set `t.flags` from unjoined to joined. If it returns joined,
+	// it means the previous value had that flag set and we can return.
+	if res, ok := CAS(&t.flags, unjoined, joined); res == joined && !ok {
+		return
+	}
+	unix.pthread_join(t.unix_thread, nil)
 }
 
 _join_multiple :: proc(threads: ..^Thread) {