Browse Source

Merge pull request #5289 from JackMordaunt/jfm-sync_chan_refactor

Jfm sync chan refactor
Laytan 3 tháng trước cách đây
mục cha
commit
fc7fc4d5cd
2 tập tin đã thay đổi với 251 bổ sung31 xóa
  1. 74 31
      core/sync/chan/chan.odin
  2. 177 0
      tests/core/sync/chan/test_core_sync_chan.odin

+ 74 - 31
core/sync/chan/chan.odin

@@ -7,6 +7,14 @@ import "core:mem"
 import "core:sync"
 import "core:math/rand"
 
+when ODIN_TEST {
+/*
+Hook for testing _try_select_raw allowing the test harness to manipulate the
+channels prior to the select actually operating on them.
+*/
+__try_select_raw_pause : proc() = nil
+}
+
 /*
 Determines what operations `Chan` supports.
 */
@@ -1105,15 +1113,27 @@ can_send :: proc "contextless" (c: ^Raw_Chan) -> bool {
 	return c.w_waiting == 0
 }
 
+/*
+Specifies the direction of the selected channel.
+*/
+Select_Status :: enum {
+	None,
+	Recv,
+	Send,
+}
+
 
 /*
-Attempts to either send or receive messages on the specified channels.
+Attempts to either send or receive messages on the specified channels without blocking.
 
-`select_raw` first identifies which channels have messages ready to be received
+`try_select_raw` first identifies which channels have messages ready to be received
 and which are available for sending. It then randomly selects one operation
 (either a send or receive) to perform.
 
+If no channels have messages ready, the procedure is a noop.
+
 Note: Each message in `send_msgs` corresponds to the send channel at the same index in `sends`.
+If the message is nil, corresponding send channel will be skipped.
 
 **Inputs**
 - `recv`: A slice of channels to read from
@@ -1145,18 +1165,18 @@ Example:
 		// where the value from the read should be stored
 		received_value: int
 
-		idx, ok := chan.select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
+		idx, ok := chan.try_select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
 		fmt.println("SELECT:        ", idx, ok)
 		fmt.println("RECEIVED VALUE ", received_value)
 
-		idx, ok = chan.select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
+		idx, ok = chan.try_select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
 		fmt.println("SELECT:        ", idx, ok)
 		fmt.println("RECEIVED VALUE ", received_value)
 
 		// closing of a channel also affects the select operation
 		chan.close(c)
 
-		idx, ok = chan.select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
+		idx, ok = chan.try_select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value)
 		fmt.println("SELECT:        ", idx, ok)
 	}
 
@@ -1170,7 +1190,7 @@ Output:
 
 */
 @(require_results)
-select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []rawptr, recv_out: rawptr) -> (select_idx: int, ok: bool) #no_bounds_check {
+try_select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []rawptr, recv_out: rawptr) -> (select_idx: int, status: Select_Status) #no_bounds_check {
 	Select_Op :: struct {
 		idx:     int, // local to the slice that was given
 		is_recv: bool,
@@ -1178,43 +1198,66 @@ select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []
 
 	candidate_count := builtin.len(recvs)+builtin.len(sends)
 	candidates := ([^]Select_Op)(intrinsics.alloca(candidate_count*size_of(Select_Op), align_of(Select_Op)))
-	count := 0
 
-	for c, i in recvs {
-		if can_recv(c) {
-			candidates[count] = {
-				is_recv = true,
-				idx     = i,
+	try_loop: for {
+		count := 0
+
+		for c, i in recvs {
+			if can_recv(c) {
+				candidates[count] = {
+					is_recv = true,
+					idx     = i,
+				}
+				count += 1
 			}
-			count += 1
 		}
-	}
 
-	for c, i in sends {
-		if can_send(c) {
-			candidates[count] = {
-				is_recv = false,
-				idx     = i,
+		for c, i in sends {
+			if i > builtin.len(send_msgs)-1 || send_msgs[i] == nil {
+				continue
+			}
+			if can_send(c)  {
+				candidates[count] = {
+					is_recv = false,
+					idx     = i,
+				}
+				count += 1
 			}
-			count += 1
 		}
-	}
 
-	if count == 0 {
-		return
-	}
+		if count == 0 {
+			return -1, .None
+		}
+
+		when ODIN_TEST {
+			if __try_select_raw_pause != nil {
+				__try_select_raw_pause()
+			}
+		}
 
-	select_idx = rand.int_max(count) if count > 0 else 0
+		candidate_idx := rand.int_max(count) if count > 0 else 0
 
-	sel := candidates[select_idx]
-	if sel.is_recv {
-		ok = recv_raw(recvs[sel.idx], recv_out)
-	} else {
-		ok = send_raw(sends[sel.idx], send_msgs[sel.idx])
+		sel := candidates[candidate_idx]
+		if sel.is_recv {
+			status = .Recv
+			if !try_recv_raw(recvs[sel.idx], recv_out) {
+				continue try_loop
+			}
+		} else {
+			status = .Send
+			if !try_send_raw(sends[sel.idx], send_msgs[sel.idx]) {
+				continue try_loop
+			}
+		}
+
+		return sel.idx, status
 	}
-	return
 }
 
+@(require_results, deprecated = "use try_select_raw")
+select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []rawptr, recv_out: rawptr) -> (select_idx: int, status: Select_Status) #no_bounds_check {
+	return try_select_raw(recvs, sends, send_msgs, recv_out)
+}
 
 /*
 `Raw_Queue` is a non-thread-safe queue implementation designed to store messages

+ 177 - 0
tests/core/sync/chan/test_core_sync_chan.odin

@@ -272,3 +272,180 @@ test_accept_message_from_closed_buffered_chan :: proc(t: ^testing.T) {
 	testing.expect_value(t, result, 64)
 	testing.expect(t, ok)
 }
+
+// Ensures that if any input channel is eligible to receive or send, the try_select_raw
+// operation will process it.
+@test
+test_try_select_raw_happy :: proc(t: ^testing.T) {
+	testing.set_fail_timeout(t, FAIL_TIME)
+
+	recv1, recv1_err := chan.create(chan.Chan(int), context.allocator)
+
+	assert(recv1_err == nil, "allocation failed")
+	defer chan.destroy(recv1)
+
+	recv2, recv2_err := chan.create(chan.Chan(int), 1, context.allocator)
+
+	assert(recv2_err == nil, "allocation failed")
+	defer chan.destroy(recv2)
+
+	send1, send1_err := chan.create(chan.Chan(int), 1, context.allocator)
+
+	assert(send1_err == nil, "allocation failed")
+	defer chan.destroy(send1)
+
+	msg := 42
+
+	// Preload recv2 to make it eligible for selection.
+	testing.expect_value(t, chan.send(recv2, msg), true)
+
+	recvs := [?]^chan.Raw_Chan{recv1, recv2}
+	sends := [?]^chan.Raw_Chan{send1}
+	msgs := [?]rawptr{&msg}
+	received_value: int
+
+	iteration_count := 0
+	did_none_count := 0
+	did_send_count := 0
+	did_receive_count := 0
+
+	// This loop is expected to iterate three times. Twice to do the receive and
+	// send operations, and a third time to exit.
+	receive_loop: for {
+
+		iteration_count += 1
+
+		idx, status := chan.try_select_raw(recvs[:], sends[:], msgs[:], &received_value)
+
+		switch status {
+		case .None:
+			did_none_count += 1
+			break receive_loop
+
+		case .Recv:
+			did_receive_count += 1
+			testing.expect_value(t, idx, 1)
+			testing.expect_value(t, received_value, msg)
+			received_value = 0
+
+		case .Send:
+			did_send_count += 1
+			testing.expect_value(t, idx, 0)
+			v, ok := chan.try_recv(send1)
+			testing.expect_value(t, ok, true)
+			testing.expect_value(t, v, msg)
+			msgs[0] = nil // nil out the message to avoid constantly resending the same value.
+		}
+	}
+
+	testing.expect_value(t, iteration_count, 3)
+	testing.expect_value(t, did_none_count, 1)
+	testing.expect_value(t, did_receive_count, 1)
+	testing.expect_value(t, did_send_count, 1)
+}
+
+// Ensures that if no input channels are eligible to receive or send, the
+// try_select_raw operation does not block.
+@test
+test_try_select_raw_default_state :: proc(t: ^testing.T) {
+	testing.set_fail_timeout(t, FAIL_TIME)
+
+	recv1, recv1_err := chan.create(chan.Chan(int), context.allocator)
+
+	assert(recv1_err == nil, "allocation failed")
+	defer chan.destroy(recv1)
+
+	recv2, recv2_err := chan.create(chan.Chan(int), context.allocator)
+
+	assert(recv2_err == nil, "allocation failed")
+	defer chan.destroy(recv2)
+
+	recvs := [?]^chan.Raw_Chan{recv1, recv2}
+	received_value: int
+
+	idx, status := chan.try_select_raw(recvs[:], nil, nil, &received_value)
+
+	testing.expect_value(t, idx, -1)
+	testing.expect_value(t, status, chan.Select_Status.None)
+}
+
+// Ensures that the operation will not block even if the input channels are
+// consumed by a competing thread; that is, a value is received from another
+// thread between calls to can_{send,recv} and try_{send,recv}_raw.
+@test
+test_try_select_raw_no_toctou :: proc(t: ^testing.T) {
+	testing.set_fail_timeout(t, FAIL_TIME)
+
+	// Trigger will be used to coordinate between the thief and the try_select.
+	trigger, trigger_err := chan.create(chan.Chan(any), context.allocator)
+
+	assert(trigger_err == nil, "allocation failed")
+	defer chan.destroy(trigger)
+
+	@(static)
+	__global_context_for_test: rawptr
+
+	__global_context_for_test = &trigger
+	defer __global_context_for_test = nil
+
+	// Setup the pause proc. This will be invoked after the input channels are
+	// checked for eligibility but before any channel operations are attempted.
+	chan.__try_select_raw_pause = proc() {
+		trigger := (cast(^chan.Chan(any))(__global_context_for_test))^
+
+		// Notify the thief that we are paused so that it can steal the value.
+		 _ = chan.send(trigger, "signal")
+
+		// Wait for comfirmation of the burglary.
+		_, _ = chan.recv(trigger)
+	}
+
+	defer chan.__try_select_raw_pause = nil
+
+	recv1, recv1_err := chan.create(chan.Chan(int), 1, context.allocator)
+
+	assert(recv1_err == nil, "allocation failed")
+	defer chan.destroy(recv1)
+
+	Context :: struct {
+		recv1: chan.Chan(int),
+		trigger: chan.Chan(any),
+	}
+
+	ctx := Context{
+		recv1 = recv1,
+		trigger = trigger,
+	}
+
+	// Spin up a thread that will steal the value from the input channel after
+	// try_select has already considered it eligible for selection.
+	thief := thread.create_and_start_with_poly_data(ctx, proc(ctx: Context) {
+		// Wait for eligibility check.
+		_, _ = chan.recv(ctx.trigger)
+
+		// Steal the value.
+		v, ok := chan.recv(ctx.recv1)
+
+		assert(ok, "recv1: expected to receive a value")
+		assert(v == 42, "recv1: unexpected receive value")
+
+		// Notify select that we have stolen the value and that it can proceed.
+		_ = chan.send(ctx.trigger, "signal")
+	})
+
+	recvs := [?]^chan.Raw_Chan{recv1}
+	received_value: int
+
+	// Ensure channel is eligible prior to entering the select.
+	testing.expect_value(t, chan.send(recv1, 42), true)
+
+	// Execute the try_select_raw, assert that we don't block, and that we receive
+	// .None status since the value was stolen by the other thread.
+	idx, status := chan.try_select_raw(recvs[:], nil, nil, &received_value)
+
+	testing.expect_value(t, idx, -1)
+	testing.expect_value(t, status, chan.Select_Status.None)
+
+	thread.join(thief)
+	thread.destroy(thief)
+}