Ver Fonte

tests/core/sync/chan: test harness for chan.try_select_raw

This test harness ensures consistent non-blocking semantics and
validates that we have solved the toctou condition.

The __global_context_for_test is a bit of a hack to fuse together the
test supplied proc and the executing logic in packaage chan.
Jack Mordaunt há 3 meses atrás
pai
commit
4d7c182f7d
1 ficheiros alterados com 176 adições e 0 exclusões
  1. 176 0
      tests/core/sync/chan/test_core_sync_chan.odin

+ 176 - 0
tests/core/sync/chan/test_core_sync_chan.odin

@@ -35,6 +35,8 @@ MAX_RAND    :: 32
 FAIL_TIME   :: 1 * time.Second
 SLEEP_TIME  :: 1 * time.Millisecond
 
+__global_context_for_test: rawptr
+
 comm_client :: proc(th: ^thread.Thread) {
 	data := cast(^Comm)th.data
 	manual_buffering := data.manual_buffering
@@ -272,3 +274,177 @@ test_accept_message_from_closed_buffered_chan :: proc(t: ^testing.T) {
 	testing.expect_value(t, result, 64)
 	testing.expect(t, ok)
 }
+
+// Ensures that if any input channel is eligible to receive or send, the try_select_raw
+// operation will process it.
+@test
+test_try_select_raw_happy :: proc(t: ^testing.T) {
+	testing.set_fail_timeout(t, FAIL_TIME)
+
+	recv1, recv1_err := chan.create(chan.Chan(int), context.allocator)
+
+	assert(recv1_err == nil, "allocation failed")
+	defer chan.destroy(recv1)
+
+	recv2, recv2_err := chan.create(chan.Chan(int), 1, context.allocator)
+
+	assert(recv2_err == nil, "allocation failed")
+	defer chan.destroy(recv2)
+
+	send1, send1_err := chan.create(chan.Chan(int), 1, context.allocator)
+
+	assert(send1_err == nil, "allocation failed")
+	defer chan.destroy(send1)
+
+	msg := 42
+
+	// Preload recv2 to make it eligible for selection.
+	testing.expect_value(t, chan.send(recv2, msg), true)
+
+	recvs := [?]^chan.Raw_Chan{recv1, recv2}
+	sends := [?]^chan.Raw_Chan{send1}
+	msgs := [?]rawptr{&msg}
+	received_value: int
+
+	iteration_count := 0
+	did_none_count := 0
+	did_send_count := 0
+	did_receive_count := 0
+
+	// This loop is expected to iterate three times. Twice to do the receive and
+	// send operations, and a third time to exit.
+	receive_loop: for {
+
+		iteration_count += 1
+
+		idx, status := chan.try_select_raw(recvs[:], sends[:], msgs[:], &received_value)
+
+		switch status {
+		case .None:
+			did_none_count += 1
+			break receive_loop
+
+		case .Recv:
+			did_receive_count += 1
+			testing.expect_value(t, idx, 1)
+			testing.expect_value(t, received_value, msg)
+			received_value = 0
+
+		case .Send:
+			did_send_count += 1
+			testing.expect_value(t, idx, 0)
+			v, ok := chan.try_recv(send1)
+			testing.expect_value(t, ok, true)
+			testing.expect_value(t, v, msg)
+			msgs[0] = nil // nil out the message to avoid constantly resending the same value.
+		}
+	}
+
+	testing.expect_value(t, iteration_count, 3)
+	testing.expect_value(t, did_none_count, 1)
+	testing.expect_value(t, did_receive_count, 1)
+	testing.expect_value(t, did_send_count, 1)
+}
+
+// Ensures that if no input channels are eligible to receive or send, the
+// try_select_raw operation does not block.
+@test
+test_try_select_raw_default_state :: proc(t: ^testing.T) {
+	testing.set_fail_timeout(t, FAIL_TIME)
+
+	recv1, recv1_err := chan.create(chan.Chan(int), context.allocator)
+
+	assert(recv1_err == nil, "allocation failed")
+	defer chan.destroy(recv1)
+
+	recv2, recv2_err := chan.create(chan.Chan(int), context.allocator)
+
+	assert(recv2_err == nil, "allocation failed")
+	defer chan.destroy(recv2)
+
+	recvs := [?]^chan.Raw_Chan{recv1, recv2}
+	received_value: int
+
+	idx, status := chan.try_select_raw(recvs[:], nil, nil, &received_value)
+
+	testing.expect_value(t, idx, -1)
+	testing.expect_value(t, status, chan.Select_Status.None)
+}
+
+// Ensures that the operation will not block even if the input channels are
+// consumed by a competing thread; that is, a value is received from another
+// thread between calls to can_{send,recv} and try_{send,recv}_raw.
+@test
+test_try_select_raw_no_toctou :: proc(t: ^testing.T) {
+	testing.set_fail_timeout(t, FAIL_TIME)
+
+	// Trigger will be used to coordinate between the thief and the try_select.
+	trigger, trigger_err := chan.create(chan.Chan(any), context.allocator)
+
+	assert(trigger_err == nil, "allocation failed")
+	defer chan.destroy(trigger)
+
+	__global_context_for_test = &trigger
+	defer __global_context_for_test = nil
+
+	// Setup the pause proc. This will be invoked after the input channels are
+	// checked for eligibility but before any channel operations are attempted.
+	chan.__try_select_raw_pause = proc() {
+		trigger := (cast(^chan.Chan(any))(__global_context_for_test))^
+
+		// Notify the thief that we are paused so that it can steal the value.
+		 _ = chan.send(trigger, "signal")
+
+		// Wait for comfirmation of the burglary.
+		_, _ = chan.recv(trigger)
+	}
+
+	defer chan.__try_select_raw_pause = nil
+
+	recv1, recv1_err := chan.create(chan.Chan(int), 1, context.allocator)
+
+	assert(recv1_err == nil, "allocation failed")
+	defer chan.destroy(recv1)
+
+	Context :: struct {
+		recv1: chan.Chan(int),
+		trigger: chan.Chan(any),
+	}
+
+	ctx := Context{
+		recv1 = recv1,
+		trigger = trigger,
+	}
+
+	// Spin up a thread that will steal the value from the input channel after
+	// try_select has already considered it eligible for selection.
+	thief := thread.create_and_start_with_poly_data(ctx, proc(ctx: Context) {
+		// Wait for eligibility check.
+		_, _ = chan.recv(ctx.trigger)
+
+		// Steal the value.
+		v, ok := chan.recv(ctx.recv1)
+
+		assert(ok, "recv1: expected to receive a value")
+		assert(v == 42, "recv1: unexpected receive value")
+
+		// Notify select that we have stolen the value and that it can proceed.
+		_ = chan.send(ctx.trigger, "signal")
+	})
+
+	recvs := [?]^chan.Raw_Chan{recv1}
+	received_value: int
+
+	// Ensure channel is eligible prior to entering the select.
+	testing.expect_value(t, chan.send(recv1, 42), true)
+
+	// Execute the try_select_raw, assert that we don't block, and that we receive
+	// .None status since the value was stolen by the other thread.
+	idx, status := chan.try_select_raw(recvs[:], nil, nil, &received_value)
+
+	testing.expect_value(t, idx, -1)
+	testing.expect_value(t, status, chan.Select_Status.None)
+
+	thread.join(thief)
+	thread.destroy(thief)
+}