Browse Source

Merge pull request #5329 from JackMordaunt/jfm-fix_chan_try_send

chan: fix try_send and send
Laytan 2 months ago
parent
commit
7f648d11d6
2 changed files with 305 additions and 64 deletions
  1. 33 17
      core/sync/chan/chan.odin
  2. 272 47
      tests/core/sync/chan/test_core_sync_chan.odin

+ 33 - 17
core/sync/chan/chan.odin

@@ -83,6 +83,8 @@ Raw_Chan :: struct {
 	r_waiting:       int,  // guarded by `mutex`
 	w_waiting:       int,  // guarded by `mutex`
 
+	did_read: bool, // lets a sender know if the value was read
+
 	// Buffered
 	queue: ^Raw_Queue,
 
@@ -420,8 +422,8 @@ as_recv :: #force_inline proc "contextless" (c: $C/Chan($T, $D)) -> (r: Chan(T,
 Sends the specified message, blocking the current thread if:
 - the channel is unbuffered
 - the channel's buffer is full
-until the channel is being read from. `send` will return
-`false` when attempting to send on an already closed channel.
+until the channel is being read from or the channel is closed. `send` will
+return `false` when attempting to send on an already closed channel.
 
 **Inputs**
 - `c`: The channel
@@ -492,8 +494,9 @@ try_send :: proc "contextless" (c: $C/Chan($T, $D), data: T) -> (ok: bool) where
 Reads a message from the channel, blocking the current thread if:
 - the channel is unbuffered
 - the channel's buffer is empty
-until the channel is being written to. `recv` will return
-`false` when attempting to receive a message on an already closed channel.
+until the channel is being written to or the channel is closed. `recv` will
+return `false` when attempting to receive a message on an already closed
+channel.
 
 **Inputs**
 - `c`: The channel
@@ -566,8 +569,8 @@ try_recv :: proc "contextless" (c: $C/Chan($T)) -> (data: T, ok: bool) where C.D
 Sends the specified message, blocking the current thread if:
 - the channel is unbuffered
 - the channel's buffer is full
-until the channel is being read from. `send_raw` will return
-`false` when attempting to send on an already closed channel.
+until the channel is being read from or the channel is closed. `send_raw` will
+return `false` when attempting to send on an already closed channel.
 
 Note: The message referenced by `msg_out` must match the size
 and alignment used when the `Raw_Chan` was created.
@@ -627,12 +630,23 @@ send_raw :: proc "contextless" (c: ^Raw_Chan, msg_in: rawptr) -> (ok: bool) {
 			return false
 		}
 
+		c.did_read = false
+		defer c.did_read = false
+
 		mem.copy(c.unbuffered_data, msg_in, int(c.msg_size))
+
 		c.w_waiting += 1
+
 		if c.r_waiting > 0 {
 			sync.signal(&c.r_cond)
 		}
+
 		sync.wait(&c.w_cond, &c.mutex)
+
+		if c.closed && !c.did_read {
+			return false
+		}
+
 		ok = true
 	}
 	return
@@ -642,8 +656,9 @@ send_raw :: proc "contextless" (c: ^Raw_Chan, msg_in: rawptr) -> (ok: bool) {
 Reads a message from the channel, blocking the current thread if:
 - the channel is unbuffered
 - the channel's buffer is empty
-until the channel is being written to. `recv_raw` will return
-`false` when attempting to receive a message on an already closed channel.
+until the channel is being written to or the channel is closed. `recv_raw`
+will return `false` when attempting to receive a message on an already closed
+channel.
 
 Note: The location pointed to by `msg_out` must match the size
 and alignment used when the `Raw_Chan` was created.
@@ -706,8 +721,7 @@ recv_raw :: proc "contextless" (c: ^Raw_Chan, msg_out: rawptr) -> (ok: bool) {
 	} else if c.unbuffered_data != nil { // unbuffered
 		sync.guard(&c.mutex)
 
-		for !c.closed &&
-			c.w_waiting == 0 {
+		for !c.closed && c.w_waiting == 0 {
 			c.r_waiting += 1
 			sync.wait(&c.r_cond, &c.mutex)
 			c.r_waiting -= 1
@@ -720,6 +734,7 @@ recv_raw :: proc "contextless" (c: ^Raw_Chan, msg_out: rawptr) -> (ok: bool) {
 		mem.copy(msg_out, c.unbuffered_data, int(c.msg_size))
 		c.w_waiting -= 1
 
+		c.did_read = true
 		sync.signal(&c.w_cond)
 		ok = true
 	}
@@ -779,7 +794,7 @@ try_send_raw :: proc "contextless" (c: ^Raw_Chan, msg_in: rawptr) -> (ok: bool)
 	} else if c.unbuffered_data != nil { // unbuffered
 		sync.guard(&c.mutex)
 
-		if c.closed {
+		if c.closed || c.r_waiting - c.w_waiting <= 0 {
 			return false
 		}
 
@@ -843,7 +858,7 @@ try_recv_raw :: proc "contextless" (c: ^Raw_Chan, msg_out: rawptr) -> bool {
 	} else if c.unbuffered_data != nil { // unbuffered
 		sync.guard(&c.mutex)
 
-		if c.closed || c.w_waiting == 0 {
+		if c.closed || c.w_waiting - c.r_waiting <= 0 {
 			return false
 		}
 
@@ -1046,8 +1061,9 @@ is_closed :: proc "contextless" (c: ^Raw_Chan) -> bool {
 }
 
 /*
-Returns whether a message is ready to be read, i.e.,
-if a call to `recv` or `recv_raw` would block
+Returns whether a message can be read without blocking the current
+thread. Specifically, it checks if the channel is buffered and not full,
+or if there is already a writer attempting to send a message.
 
 **Inputs**
 - `c`: The channel
@@ -1075,7 +1091,7 @@ can_recv :: proc "contextless" (c: ^Raw_Chan) -> bool {
 	if is_buffered(c) {
 		return c.queue.len > 0
 	}
-	return c.w_waiting > 0
+	return c.w_waiting - c.r_waiting > 0
 }
 
 
@@ -1088,7 +1104,7 @@ or if there is already a reader waiting for a message.
 - `c`: The channel
 
 **Returns**
-- `true` if a message can be send, `false` otherwise
+- `true` if a message can be sent, `false` otherwise
 
 Example:
 
@@ -1110,7 +1126,7 @@ can_send :: proc "contextless" (c: ^Raw_Chan) -> bool {
 	if is_buffered(c) {
 		return c.queue.len < c.queue.cap
 	}
-	return c.w_waiting == 0
+	return c.r_waiting - c.w_waiting > 0
 }
 
 /*

+ 272 - 47
tests/core/sync/chan/test_core_sync_chan.odin

@@ -4,6 +4,7 @@ import "base:runtime"
 import "base:intrinsics"
 import "core:log"
 import "core:math/rand"
+import "core:sync"
 import "core:sync/chan"
 import "core:testing"
 import "core:thread"
@@ -33,18 +34,16 @@ Comm :: struct {
 BUFFER_SIZE :: 8
 MAX_RAND    :: 32
 FAIL_TIME   :: 1 * time.Second
-SLEEP_TIME  :: 1 * time.Millisecond
+
+// Synchronizes try_select tests that require access to global state.
+test_lock: sync.Mutex
+__global_context_for_test: rawptr
 
 comm_client :: proc(th: ^thread.Thread) {
 	data := cast(^Comm)th.data
-	manual_buffering := data.manual_buffering
 
 	n: i64
 
-	for manual_buffering && !chan.can_recv(data.host) {
-		thread.yield()
-	}
-
 	recv_loop: for msg in chan.recv(data.host) {
 		#partial switch msg.type {
 		case .Add:      n += msg.i
@@ -56,14 +55,6 @@ comm_client :: proc(th: ^thread.Thread) {
 		case:
 			panic("Unknown message type for client.")
 		}
-
-		for manual_buffering && !chan.can_recv(data.host) {
-			thread.yield()
-		}
-	}
-
-	for manual_buffering && !chan.can_send(data.host) {
-		thread.yield()
 	}
 
 	chan.send(data.client, Message{.Result, n})
@@ -72,9 +63,6 @@ comm_client :: proc(th: ^thread.Thread) {
 
 send_messages :: proc(t: ^testing.T, host: chan.Chan(Message), manual_buffering: bool = false) -> (expected: i64) {
 	expected = 1
-	for manual_buffering && !chan.can_send(host) {
-		thread.yield()
-	}
 	chan.send(host, Message{.Add, 1})
 	log.debug(Message{.Add, 1})
 
@@ -96,9 +84,6 @@ send_messages :: proc(t: ^testing.T, host: chan.Chan(Message), manual_buffering:
 			expected /= msg.i
 		}
 
-		for manual_buffering && !chan.can_send(host) {
-			thread.yield()
-		}
 		if manual_buffering {
 			testing.expect(t, chan.len(host) == 0)
 		}
@@ -107,9 +92,6 @@ send_messages :: proc(t: ^testing.T, host: chan.Chan(Message), manual_buffering:
 		log.debug(msg)
 	}
 
-	for manual_buffering && !chan.can_send(host) {
-		thread.yield()
-	}
 	chan.send(host, Message{.End, 0})
 	log.debug(Message{.End, 0})
 	chan.close(host)
@@ -148,18 +130,15 @@ test_chan_buffered :: proc(t: ^testing.T) {
 
 	expected := send_messages(t, comm.host, manual_buffering = false)
 
-	// Sleep so we can give the other thread enough time to buffer its message.
-	time.sleep(SLEEP_TIME)
-
-	testing.expect_value(t, chan.len(comm.client), 1)
-	result, ok := chan.try_recv(comm.client)
+	result, ok := chan.recv(comm.client)
+	testing.expect_value(t, ok, true)
+	testing.expect_value(t, result.i, expected)
 
-	// One more sleep to ensure it has enough time to close.
-	time.sleep(SLEEP_TIME)
+	// Wait for channel to close.
+	_, ok = chan.recv(comm.client)
+	testing.expect(t, !ok, "channel should have been closed")
 
 	testing.expect_value(t, chan.is_closed(comm.client), true)
-	testing.expect_value(t, ok, true)
-	testing.expect_value(t, result.i, expected)
 	log.debug(result, expected)
 
 	// Make sure sending to closed channels fails.
@@ -171,6 +150,8 @@ test_chan_buffered :: proc(t: ^testing.T) {
 	_, ok = chan.recv(comm.client);     testing.expect_value(t, ok, false)
 	_, ok = chan.try_recv(comm.host);   testing.expect_value(t, ok, false)
 	_, ok = chan.try_recv(comm.client); testing.expect_value(t, ok, false)
+
+	thread.join(reckoner)
 }
 
 @test
@@ -193,6 +174,10 @@ test_chan_unbuffered :: proc(t: ^testing.T) {
 	testing.expect(t, !chan.is_buffered(comm.client))
 	testing.expect(t, chan.is_unbuffered(comm.host))
 	testing.expect(t, chan.is_unbuffered(comm.client))
+	testing.expect(t, !chan.can_send(comm.host))
+	testing.expect(t, !chan.can_send(comm.client))
+	testing.expect(t, !chan.can_recv(comm.host))
+	testing.expect(t, !chan.can_recv(comm.client))
 	testing.expect_value(t, chan.len(comm.host), 0)
 	testing.expect_value(t, chan.len(comm.client), 0)
 	testing.expect_value(t, chan.cap(comm.host), 0)
@@ -203,25 +188,16 @@ test_chan_unbuffered :: proc(t: ^testing.T) {
 	reckoner.data = &comm
 	thread.start(reckoner)
 
-	for !chan.can_send(comm.client) {
-		thread.yield()
-	}
-
 	expected := send_messages(t, comm.host)
 	testing.expect_value(t, chan.is_closed(comm.host), true)
 
-	for !chan.can_recv(comm.client) {
-		thread.yield()
-	}
-
-	result, ok := chan.try_recv(comm.client)
+	result, ok := chan.recv(comm.client)
 	testing.expect_value(t, ok, true)
 	testing.expect_value(t, result.i, expected)
 	log.debug(result, expected)
 
-	// Sleep so we can give the other thread enough time to close its side
-	// after we've received its message.
-	time.sleep(SLEEP_TIME)
+	_, ok2 := chan.recv(comm.client)
+	testing.expect(t, !ok2, "read of closed channel should return false")
 
 	testing.expect_value(t, chan.is_closed(comm.client), true)
 
@@ -234,6 +210,8 @@ test_chan_unbuffered :: proc(t: ^testing.T) {
 	_, ok = chan.recv(comm.client);     testing.expect_value(t, ok, false)
 	_, ok = chan.try_recv(comm.host);   testing.expect_value(t, ok, false)
 	_, ok = chan.try_recv(comm.client); testing.expect_value(t, ok, false)
+
+	thread.join(reckoner)
 }
 
 @test
@@ -250,6 +228,198 @@ test_full_buffered_closed_chan_deadlock :: proc(t: ^testing.T) {
 	testing.expect(t, !chan.send(ch, 32))
 }
 
+// Ensures that if a thread is doing a blocking send and the channel
+// is closed, it will report false to indicate a failure to complete.
+@test
+test_fail_blocking_send_on_close :: proc(t: ^testing.T) {
+	ch, ch_alloc_err := chan.create(chan.Chan(int), context.allocator)
+	assert(ch_alloc_err == nil, "allocation failed")
+	defer chan.destroy(ch)
+
+	sender := thread.create_and_start_with_poly_data(ch, proc(ch: chan.Chan(int)) {
+		assert(!chan.send(ch, 42))
+	})
+
+	for !chan.can_recv(ch) {
+		thread.yield()
+	}
+
+	testing.expect(t, chan.close(ch))
+	thread.join(sender)
+	thread.destroy(sender)
+}
+
+// Ensures that if a thread is doing a blocking read and the channel
+// is closed, it will report false to indicate a failure to complete.
+@test
+test_fail_blocking_recv_on_close :: proc(t: ^testing.T) {
+	ch, ch_alloc_err := chan.create(chan.Chan(int), context.allocator)
+	assert(ch_alloc_err == nil, "allocation failed")
+	defer chan.destroy(ch)
+
+	reader := thread.create_and_start_with_poly_data(ch, proc(ch: chan.Chan(int)) {
+		v, ok := chan.recv(ch)
+		assert(!ok)
+		assert(v == 0)
+	})
+
+	for !chan.can_send(ch) {
+		thread.yield()
+	}
+
+	testing.expect(t, chan.close(ch))
+	thread.join(reader)
+	thread.destroy(reader)
+}
+
+// Ensures that try_send for unbuffered channels works as expected.
+// If 1 reader of a channel, and 3 try_senders, only one of the senders
+// will succeed and none of them will block.
+@test
+test_unbuffered_try_send_chan_contention :: proc(t: ^testing.T) {
+	testing.set_fail_timeout(t, FAIL_TIME)
+
+	start, start_alloc_err := chan.create(chan.Chan(any), context.allocator)
+	assert(start_alloc_err == nil, "allocation failed")
+	defer chan.destroy(start)
+
+	trigger, trigger_alloc_err := chan.create(chan.Chan(any), context.allocator)
+	assert(trigger_alloc_err == nil, "allocation failed")
+	defer chan.destroy(trigger)
+
+	results, results_alloc_err := chan.create(chan.Chan(int), 3, context.allocator)
+	assert(results_alloc_err == nil, "allocation failed")
+	defer chan.destroy(results)
+
+	ch, ch_alloc_err := chan.create(chan.Chan(int), context.allocator)
+	assert(ch_alloc_err == nil, "allocation failed")
+	defer chan.destroy(ch)
+
+	// There are no readers or writers, so calling recv or send would block!
+	testing.expect_value(t, chan.can_send(ch), false)
+	testing.expect_value(t, chan.can_recv(ch), false)
+
+	// Non-blocking operations should not block, and should return false.
+	testing.expect_value(t, chan.try_send(ch, -1), false)
+	if v, ok := chan.try_recv(ch); ok {
+		testing.expect_value(t, ok, false)
+		testing.expect_value(t, v, 0)
+	}
+
+	// Spinup several threads contending to send on an unbuffered channel.
+	contenders: [3]^thread.Thread
+	wait: sync.Wait_Group
+
+	for ii in 0..<len(contenders) {
+		sync.wait_group_add(&wait, 1)
+		Context :: struct {
+			id: int,
+			start: chan.Chan(any),
+			trigger: chan.Chan(any),
+			results: chan.Chan(int),
+			ch: chan.Chan(int),
+			wg: ^sync.Wait_Group,
+		}
+		ctx := Context {
+			id = ii,
+			start = start,
+			trigger = trigger,
+			results = results,
+			ch	 = ch,
+			wg = &wait,
+		}
+		contenders[ii] = thread.create_and_start_with_poly_data(ctx, proc(ctx: Context) {
+			defer sync.wait_group_done(ctx.wg)
+
+			assert(!chan.can_send(ctx.ch), "channel shouldn't be ready for non-blocking send yet")
+			assert(chan.send(ctx.start, "ready"))
+
+			log.debugf("contender %v: ready", ctx.id)
+
+			// Wait for trigger to be closed so that all contenders have the same opportunity.
+			_, _ = chan.recv(ctx.trigger)
+
+			log.debugf("contender %v: racing", ctx.id)
+
+			// Attempt to send a value. We are competing against the other contenders.
+			ok := chan.try_send(ctx.ch, 42)
+			if ok {
+				log.debugf("contender %v: sent!", ctx.id)
+				assert(chan.send(ctx.results, 1))
+			} else {
+				log.debugf("contender %v: too-slow", ctx.id)
+				assert(chan.send(ctx.results, -1))
+			}
+		}, init_context = context)
+	}
+
+	// Spinup a closer thread that will close the results channel once all
+	// contenders are done. This lets the test thread check for spurious results by
+	// draining the results until closed.
+	results_closer := thread.create_and_start_with_poly_data2(&wait, results, proc(wg: ^sync.Wait_Group, results: chan.Chan(int)) {
+		sync.wait_group_wait(wg)
+		assert(chan.close(results))
+	})
+
+	// Wait for contenders to be ready.
+	for _ in 0..<len(contenders) {
+		if data, ok := chan.recv(start); !ok {
+			testing.expect_value(t, ok, true)
+			testing.expect_value(t, data.(string), "ready")
+		}
+	}
+
+	// Fire the trigger when the test thread is ready to receive.
+	trigger_closer := thread.create_and_start_with_poly_data2(trigger, ch, proc(trigger: chan.Chan(any), ch: chan.Chan(int)) {
+		for !chan.can_send(ch) {
+			thread.yield()
+		}
+		assert(chan.close(trigger))
+	})
+
+	// Blocking read, wait for a sender.
+	if v, ok := chan.recv(ch); !ok {
+		testing.expect_value(t, ok, true)
+		testing.expect_value(t, v, 42)
+	}
+
+	did_send_count: int
+	did_not_send_count: int
+
+	// Let the contenders fight to send a value.
+	for {
+		data, ok := chan.recv(results)
+		if !ok {
+			break
+		}
+
+		log.debugf("data: %v, ok: %v", data, ok)
+
+		switch data {
+		case 1:
+			did_send_count += 1
+		case -1:
+			did_not_send_count += 1
+		case:
+			testing.fail_now(t, "got spurious result")
+		}
+	}
+
+	thread.join(trigger_closer)
+	thread.join(results_closer)
+	thread.join_multiple(..contenders[:])
+
+	defer for tr in contenders {
+		thread.destroy(tr)
+	}
+	defer thread.destroy(trigger_closer)
+	defer thread.destroy(results_closer)
+
+	// Expect that one got to send and the others did not.
+	testing.expect_value(t, did_send_count, 1)
+	testing.expect_value(t, did_not_send_count, len(contenders)-1)
+}
+
 // This test guarantees a buffered channel's messages can still be received
 // even after closing. This is currently how the API works. If that changes,
 // this test will need to change.
@@ -279,6 +449,7 @@ test_accept_message_from_closed_buffered_chan :: proc(t: ^testing.T) {
 /*
 @test
 test_try_select_raw_happy :: proc(t: ^testing.T) {
+	sync.guard(&test_lock)
 	testing.set_fail_timeout(t, FAIL_TIME)
 
 	recv1, recv1_err := chan.create(chan.Chan(int), context.allocator)
@@ -351,6 +522,7 @@ test_try_select_raw_happy :: proc(t: ^testing.T) {
 // try_select_raw operation does not block.
 @test
 test_try_select_raw_default_state :: proc(t: ^testing.T) {
+	sync.guard(&test_lock)
 	testing.set_fail_timeout(t, FAIL_TIME)
 
 	recv1, recv1_err := chan.create(chan.Chan(int), context.allocator)
@@ -377,6 +549,7 @@ test_try_select_raw_default_state :: proc(t: ^testing.T) {
 // thread between calls to can_{send,recv} and try_{send,recv}_raw.
 @test
 test_try_select_raw_no_toctou :: proc(t: ^testing.T) {
+	sync.guard(&test_lock)
 	testing.set_fail_timeout(t, FAIL_TIME)
 
 	// Trigger will be used to coordinate between the thief and the try_select.
@@ -385,9 +558,6 @@ test_try_select_raw_no_toctou :: proc(t: ^testing.T) {
 	assert(trigger_err == nil, "allocation failed")
 	defer chan.destroy(trigger)
 
-	@(static)
-	__global_context_for_test: rawptr
-
 	__global_context_for_test = &trigger
 	defer __global_context_for_test = nil
 
@@ -452,3 +622,58 @@ test_try_select_raw_no_toctou :: proc(t: ^testing.T) {
 	thread.join(thief)
 	thread.destroy(thief)
 }
+
+// Ensures that a sender will always report correctly whether the value was received
+// or not in the event of channel closure.
+//
+// 1. send thread does a blocking send
+// 2. recv and close threads race
+// 3. send returns false if close won and reports true if recv won
+//
+// We know if recv won by whether it sends us the original value on the results channel.
+// This test is non-deterministic.
+@test
+test_send_close_read :: proc(t: ^testing.T) {
+	trigger, trigger_err := chan.create(chan.Chan(int), context.allocator)
+	assert(trigger_err == nil, "allocation failed")
+	defer chan.destroy(trigger)
+
+	ch, alloc_err := chan.create(chan.Chan(int), context.allocator)
+	assert(alloc_err == nil, "allocation failed")
+	defer chan.destroy(ch)
+
+	results, results_err := chan.create(chan.Chan(int), 1, context.allocator)
+	assert(results_err == nil, "allocation failed")
+	defer chan.destroy(results)
+
+	receiver := thread.create_and_start_with_poly_data3(trigger, results, ch, proc(trigger, results, ch: chan.Chan(int)) {
+		_, _ = chan.recv(trigger)
+		v, _ := chan.recv(ch)
+		assert(chan.send(results, v))
+	})
+
+	closer := thread.create_and_start_with_poly_data2(trigger, ch, proc(trigger, ch: chan.Chan(int)) {
+		_, _ = chan.recv(trigger)
+		ok := chan.close(ch)
+		assert(ok)
+	})
+
+	testing.expect(t, chan.close(trigger))
+
+	did_send := chan.send(ch, 42)
+
+	v, ok := chan.recv(results)
+	testing.expect(t, ok)
+
+	if v == 42 {
+		testing.expect(t, did_send)
+	} else {
+		testing.expect(t, !did_send)
+	}
+
+	thread.join_multiple(receiver, closer)
+	thread.destroy(receiver)
+	thread.destroy(closer)
+}
+
+