Browse Source

Merge pull request #5289 from JackMordaunt/jfm-sync_chan_refactor

Jfm sync chan refactor
Laytan 3 months ago
parent
commit
fc7fc4d5cd
2 changed files with 251 additions and 31 deletions
  1. 74 31
      core/sync/chan/chan.odin
  2. 177 0
      tests/core/sync/chan/test_core_sync_chan.odin

+ 74 - 31
core/sync/chan/chan.odin

@@ -7,6 +7,14 @@ import "core:mem"
 import "core:sync"
 import "core:sync"
 import "core:math/rand"
 import "core:math/rand"
 
 
+when ODIN_TEST {
+/*
+Hook for testing _try_select_raw allowing the test harness to manipulate the
+channels prior to the select actually operating on them.
+*/
+__try_select_raw_pause : proc() = nil
+}
+
 /*
 /*
 Determines what operations `Chan` supports.
 Determines what operations `Chan` supports.
 */
 */
@@ -1105,15 +1113,27 @@ can_send :: proc "contextless" (c: ^Raw_Chan) -> bool {
 	return c.w_waiting == 0
 	return c.w_waiting == 0
 }
 }
 
 
+/*
+Specifies the direction of the selected channel.
+*/
+Select_Status :: enum {
+	None,
+	Recv,
+	Send,
+}
+
 
 
 /*
 /*
-Attempts to either send or receive messages on the specified channels.
+Attempts to either send or receive messages on the specified channels without blocking.
 
 
-`select_raw` first identifies which channels have messages ready to be received
+`try_select_raw` first identifies which channels have messages ready to be received
 and which are available for sending. It then randomly selects one operation
 and which are available for sending. It then randomly selects one operation
 (either a send or receive) to perform.
 (either a send or receive) to perform.
 
 
+If no channels have messages ready, the procedure is a noop.
+
 Note: Each message in `send_msgs` corresponds to the send channel at the same index in `sends`.
 Note: Each message in `send_msgs` corresponds to the send channel at the same index in `sends`.
+If the message is nil, corresponding send channel will be skipped.
 
 
 **Inputs**
 **Inputs**
 - `recv`: A slice of channels to read from
 - `recv`: A slice of channels to read from
@@ -1145,18 +1165,18 @@ Example:
 		// where the value from the read should be stored
 		// where the value from the read should be stored
 		received_value: int
 		received_value: int
 
 
-		idx, ok := chan.select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
+		idx, ok := chan.try_select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
 		fmt.println("SELECT:        ", idx, ok)
 		fmt.println("SELECT:        ", idx, ok)
 		fmt.println("RECEIVED VALUE ", received_value)
 		fmt.println("RECEIVED VALUE ", received_value)
 
 
-		idx, ok = chan.select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
+		idx, ok = chan.try_select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
 		fmt.println("SELECT:        ", idx, ok)
 		fmt.println("SELECT:        ", idx, ok)
 		fmt.println("RECEIVED VALUE ", received_value)
 		fmt.println("RECEIVED VALUE ", received_value)
 
 
 		// closing of a channel also affects the select operation
 		// closing of a channel also affects the select operation
 		chan.close(c)
 		chan.close(c)
 
 
-		idx, ok = chan.select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
+		idx, ok = chan.try_select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
 		fmt.println("SELECT:        ", idx, ok)
 		fmt.println("SELECT:        ", idx, ok)
 	}
 	}
 
 
@@ -1170,7 +1190,7 @@ Output:
 
 
 */
 */
 @(require_results)
 @(require_results)
-select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []rawptr, recv_out: rawptr) -> (select_idx: int, ok: bool) #no_bounds_check {
+try_select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []rawptr, recv_out: rawptr) -> (select_idx: int, status: Select_Status) #no_bounds_check {
 	Select_Op :: struct {
 	Select_Op :: struct {
 		idx:     int, // local to the slice that was given
 		idx:     int, // local to the slice that was given
 		is_recv: bool,
 		is_recv: bool,
@@ -1178,43 +1198,66 @@ select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []
 
 
 	candidate_count := builtin.len(recvs)+builtin.len(sends)
 	candidate_count := builtin.len(recvs)+builtin.len(sends)
 	candidates := ([^]Select_Op)(intrinsics.alloca(candidate_count*size_of(Select_Op), align_of(Select_Op)))
 	candidates := ([^]Select_Op)(intrinsics.alloca(candidate_count*size_of(Select_Op), align_of(Select_Op)))
-	count := 0
 
 
-	for c, i in recvs {
-		if can_recv(c) {
-			candidates[count] = {
-				is_recv = true,
-				idx     = i,
+	try_loop: for {
+		count := 0
+
+		for c, i in recvs {
+			if can_recv(c) {
+				candidates[count] = {
+					is_recv = true,
+					idx     = i,
+				}
+				count += 1
 			}
 			}
-			count += 1
 		}
 		}
-	}
 
 
-	for c, i in sends {
-		if can_send(c) {
-			candidates[count] = {
-				is_recv = false,
-				idx     = i,
+		for c, i in sends {
+			if i > builtin.len(send_msgs)-1 || send_msgs[i] == nil {
+				continue
+			}
+			if can_send(c)  {
+				candidates[count] = {
+					is_recv = false,
+					idx     = i,
+				}
+				count += 1
 			}
 			}
-			count += 1
 		}
 		}
-	}
 
 
-	if count == 0 {
-		return
-	}
+		if count == 0 {
+			return -1, .None
+		}
+
+		when ODIN_TEST {
+			if __try_select_raw_pause != nil {
+				__try_select_raw_pause()
+			}
+		}
 
 
-	select_idx = rand.int_max(count) if count > 0 else 0
+		candidate_idx := rand.int_max(count) if count > 0 else 0
 
 
-	sel := candidates[select_idx]
-	if sel.is_recv {
-		ok = recv_raw(recvs[sel.idx], recv_out)
-	} else {
-		ok = send_raw(sends[sel.idx], send_msgs[sel.idx])
+		sel := candidates[candidate_idx]
+		if sel.is_recv {
+			status = .Recv
+			if !try_recv_raw(recvs[sel.idx], recv_out) {
+				continue try_loop
+			}
+		} else {
+			status = .Send
+			if !try_send_raw(sends[sel.idx], send_msgs[sel.idx]) {
+				continue try_loop
+			}
+		}
+
+		return sel.idx, status
 	}
 	}
-	return
 }
 }
 
 
+@(require_results, deprecated = "use try_select_raw")
+select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []rawptr, recv_out: rawptr) -> (select_idx: int, status: Select_Status) #no_bounds_check {
+	return try_select_raw(recvs, sends, send_msgs, recv_out)
+}
 
 
 /*
 /*
 `Raw_Queue` is a non-thread-safe queue implementation designed to store messages
 `Raw_Queue` is a non-thread-safe queue implementation designed to store messages

+ 177 - 0
tests/core/sync/chan/test_core_sync_chan.odin

@@ -272,3 +272,180 @@ test_accept_message_from_closed_buffered_chan :: proc(t: ^testing.T) {
 	testing.expect_value(t, result, 64)
 	testing.expect_value(t, result, 64)
 	testing.expect(t, ok)
 	testing.expect(t, ok)
 }
 }
+
+// Ensures that if any input channel is eligible to receive or send, the try_select_raw
+// operation will process it.
+@test
+test_try_select_raw_happy :: proc(t: ^testing.T) {
+	testing.set_fail_timeout(t, FAIL_TIME)
+
+	recv1, recv1_err := chan.create(chan.Chan(int), context.allocator)
+
+	assert(recv1_err == nil, "allocation failed")
+	defer chan.destroy(recv1)
+
+	recv2, recv2_err := chan.create(chan.Chan(int), 1, context.allocator)
+
+	assert(recv2_err == nil, "allocation failed")
+	defer chan.destroy(recv2)
+
+	send1, send1_err := chan.create(chan.Chan(int), 1, context.allocator)
+
+	assert(send1_err == nil, "allocation failed")
+	defer chan.destroy(send1)
+
+	msg := 42
+
+	// Preload recv2 to make it eligible for selection.
+	testing.expect_value(t, chan.send(recv2, msg), true)
+
+	recvs := [?]^chan.Raw_Chan{recv1, recv2}
+	sends := [?]^chan.Raw_Chan{send1}
+	msgs := [?]rawptr{&msg}
+	received_value: int
+
+	iteration_count := 0
+	did_none_count := 0
+	did_send_count := 0
+	did_receive_count := 0
+
+	// This loop is expected to iterate three times. Twice to do the receive and
+	// send operations, and a third time to exit.
+	receive_loop: for {
+
+		iteration_count += 1
+
+		idx, status := chan.try_select_raw(recvs[:], sends[:], msgs[:], &received_value)
+
+		switch status {
+		case .None:
+			did_none_count += 1
+			break receive_loop
+
+		case .Recv:
+			did_receive_count += 1
+			testing.expect_value(t, idx, 1)
+			testing.expect_value(t, received_value, msg)
+			received_value = 0
+
+		case .Send:
+			did_send_count += 1
+			testing.expect_value(t, idx, 0)
+			v, ok := chan.try_recv(send1)
+			testing.expect_value(t, ok, true)
+			testing.expect_value(t, v, msg)
+			msgs[0] = nil // nil out the message to avoid constantly resending the same value.
+		}
+	}
+
+	testing.expect_value(t, iteration_count, 3)
+	testing.expect_value(t, did_none_count, 1)
+	testing.expect_value(t, did_receive_count, 1)
+	testing.expect_value(t, did_send_count, 1)
+}
+
+// Ensures that if no input channels are eligible to receive or send, the
+// try_select_raw operation does not block.
+@test
+test_try_select_raw_default_state :: proc(t: ^testing.T) {
+	testing.set_fail_timeout(t, FAIL_TIME)
+
+	recv1, recv1_err := chan.create(chan.Chan(int), context.allocator)
+
+	assert(recv1_err == nil, "allocation failed")
+	defer chan.destroy(recv1)
+
+	recv2, recv2_err := chan.create(chan.Chan(int), context.allocator)
+
+	assert(recv2_err == nil, "allocation failed")
+	defer chan.destroy(recv2)
+
+	recvs := [?]^chan.Raw_Chan{recv1, recv2}
+	received_value: int
+
+	idx, status := chan.try_select_raw(recvs[:], nil, nil, &received_value)
+
+	testing.expect_value(t, idx, -1)
+	testing.expect_value(t, status, chan.Select_Status.None)
+}
+
+// Ensures that the operation will not block even if the input channels are
+// consumed by a competing thread; that is, a value is received from another
+// thread between calls to can_{send,recv} and try_{send,recv}_raw.
+@test
+test_try_select_raw_no_toctou :: proc(t: ^testing.T) {
+	testing.set_fail_timeout(t, FAIL_TIME)
+
+	// Trigger will be used to coordinate between the thief and the try_select.
+	trigger, trigger_err := chan.create(chan.Chan(any), context.allocator)
+
+	assert(trigger_err == nil, "allocation failed")
+	defer chan.destroy(trigger)
+
+	@(static)
+	__global_context_for_test: rawptr
+
+	__global_context_for_test = &trigger
+	defer __global_context_for_test = nil
+
+	// Setup the pause proc. This will be invoked after the input channels are
+	// checked for eligibility but before any channel operations are attempted.
+	chan.__try_select_raw_pause = proc() {
+		trigger := (cast(^chan.Chan(any))(__global_context_for_test))^
+
+		// Notify the thief that we are paused so that it can steal the value.
+		 _ = chan.send(trigger, "signal")
+
+		// Wait for comfirmation of the burglary.
+		_, _ = chan.recv(trigger)
+	}
+
+	defer chan.__try_select_raw_pause = nil
+
+	recv1, recv1_err := chan.create(chan.Chan(int), 1, context.allocator)
+
+	assert(recv1_err == nil, "allocation failed")
+	defer chan.destroy(recv1)
+
+	Context :: struct {
+		recv1: chan.Chan(int),
+		trigger: chan.Chan(any),
+	}
+
+	ctx := Context{
+		recv1 = recv1,
+		trigger = trigger,
+	}
+
+	// Spin up a thread that will steal the value from the input channel after
+	// try_select has already considered it eligible for selection.
+	thief := thread.create_and_start_with_poly_data(ctx, proc(ctx: Context) {
+		// Wait for eligibility check.
+		_, _ = chan.recv(ctx.trigger)
+
+		// Steal the value.
+		v, ok := chan.recv(ctx.recv1)
+
+		assert(ok, "recv1: expected to receive a value")
+		assert(v == 42, "recv1: unexpected receive value")
+
+		// Notify select that we have stolen the value and that it can proceed.
+		_ = chan.send(ctx.trigger, "signal")
+	})
+
+	recvs := [?]^chan.Raw_Chan{recv1}
+	received_value: int
+
+	// Ensure channel is eligible prior to entering the select.
+	testing.expect_value(t, chan.send(recv1, 42), true)
+
+	// Execute the try_select_raw, assert that we don't block, and that we receive
+	// .None status since the value was stolen by the other thread.
+	idx, status := chan.try_select_raw(recvs[:], nil, nil, &received_value)
+
+	testing.expect_value(t, idx, -1)
+	testing.expect_value(t, status, chan.Select_Status.None)
+
+	thread.join(thief)
+	thread.destroy(thief)
+}