diff --git a/core/sync/chan/chan.odin b/core/sync/chan/chan.odin index c5a4cf317..05312e5a2 100644 --- a/core/sync/chan/chan.odin +++ b/core/sync/chan/chan.odin @@ -83,6 +83,8 @@ Raw_Chan :: struct { r_waiting: int, // guarded by `mutex` w_waiting: int, // guarded by `mutex` + did_read: bool, // lets a sender know if the value was read + // Buffered queue: ^Raw_Queue, @@ -420,8 +422,8 @@ as_recv :: #force_inline proc "contextless" (c: $C/Chan($T, $D)) -> (r: Chan(T, Sends the specified message, blocking the current thread if: - the channel is unbuffered - the channel's buffer is full -until the channel is being read from. `send` will return -`false` when attempting to send on an already closed channel. +until the channel is being read from or the channel is closed. `send` will +return `false` when attempting to send on an already closed channel. **Inputs** - `c`: The channel @@ -492,8 +494,9 @@ try_send :: proc "contextless" (c: $C/Chan($T, $D), data: T) -> (ok: bool) where Reads a message from the channel, blocking the current thread if: - the channel is unbuffered - the channel's buffer is empty -until the channel is being written to. `recv` will return -`false` when attempting to receive a message on an already closed channel. +until the channel is being written to or the channel is closed. `recv` will +return `false` when attempting to receive a message on an already closed +channel. **Inputs** - `c`: The channel @@ -566,8 +569,8 @@ try_recv :: proc "contextless" (c: $C/Chan($T)) -> (data: T, ok: bool) where C.D Sends the specified message, blocking the current thread if: - the channel is unbuffered - the channel's buffer is full -until the channel is being read from. `send_raw` will return -`false` when attempting to send on an already closed channel. +until the channel is being read from or the channel is closed. `send_raw` will +return `false` when attempting to send on an already closed channel. Note: The message referenced by `msg_out` must match the size and alignment used when the `Raw_Chan` was created. @@ -627,12 +630,23 @@ send_raw :: proc "contextless" (c: ^Raw_Chan, msg_in: rawptr) -> (ok: bool) { return false } + c.did_read = false + defer c.did_read = false + mem.copy(c.unbuffered_data, msg_in, int(c.msg_size)) + c.w_waiting += 1 + if c.r_waiting > 0 { sync.signal(&c.r_cond) } + sync.wait(&c.w_cond, &c.mutex) + + if c.closed && !c.did_read { + return false + } + ok = true } return @@ -642,8 +656,9 @@ send_raw :: proc "contextless" (c: ^Raw_Chan, msg_in: rawptr) -> (ok: bool) { Reads a message from the channel, blocking the current thread if: - the channel is unbuffered - the channel's buffer is empty -until the channel is being written to. `recv_raw` will return -`false` when attempting to receive a message on an already closed channel. +until the channel is being written to or the channel is closed. `recv_raw` +will return `false` when attempting to receive a message on an already closed +channel. Note: The location pointed to by `msg_out` must match the size and alignment used when the `Raw_Chan` was created. @@ -706,8 +721,7 @@ recv_raw :: proc "contextless" (c: ^Raw_Chan, msg_out: rawptr) -> (ok: bool) { } else if c.unbuffered_data != nil { // unbuffered sync.guard(&c.mutex) - for !c.closed && - c.w_waiting == 0 { + for !c.closed && c.w_waiting == 0 { c.r_waiting += 1 sync.wait(&c.r_cond, &c.mutex) c.r_waiting -= 1 @@ -720,6 +734,7 @@ recv_raw :: proc "contextless" (c: ^Raw_Chan, msg_out: rawptr) -> (ok: bool) { mem.copy(msg_out, c.unbuffered_data, int(c.msg_size)) c.w_waiting -= 1 + c.did_read = true sync.signal(&c.w_cond) ok = true } @@ -779,7 +794,7 @@ try_send_raw :: proc "contextless" (c: ^Raw_Chan, msg_in: rawptr) -> (ok: bool) } else if c.unbuffered_data != nil { // unbuffered sync.guard(&c.mutex) - if c.closed { + if c.closed || c.r_waiting - c.w_waiting <= 0 { return false } @@ -843,7 +858,7 @@ try_recv_raw :: proc "contextless" (c: ^Raw_Chan, msg_out: rawptr) -> bool { } else if c.unbuffered_data != nil { // unbuffered sync.guard(&c.mutex) - if c.closed || c.w_waiting == 0 { + if c.closed || c.w_waiting - c.r_waiting <= 0 { return false } @@ -1046,8 +1061,9 @@ is_closed :: proc "contextless" (c: ^Raw_Chan) -> bool { } /* -Returns whether a message is ready to be read, i.e., -if a call to `recv` or `recv_raw` would block +Returns whether a message can be read without blocking the current +thread. Specifically, it checks if the channel is buffered and not full, +or if there is already a writer attempting to send a message. **Inputs** - `c`: The channel @@ -1075,7 +1091,7 @@ can_recv :: proc "contextless" (c: ^Raw_Chan) -> bool { if is_buffered(c) { return c.queue.len > 0 } - return c.w_waiting > 0 + return c.w_waiting - c.r_waiting > 0 } @@ -1088,7 +1104,7 @@ or if there is already a reader waiting for a message. - `c`: The channel **Returns** -- `true` if a message can be send, `false` otherwise +- `true` if a message can be sent, `false` otherwise Example: @@ -1110,7 +1126,7 @@ can_send :: proc "contextless" (c: ^Raw_Chan) -> bool { if is_buffered(c) { return c.queue.len < c.queue.cap } - return c.w_waiting == 0 + return c.r_waiting - c.w_waiting > 0 } /* diff --git a/tests/core/sync/chan/test_core_sync_chan.odin b/tests/core/sync/chan/test_core_sync_chan.odin index a87452eb0..304986ae7 100644 --- a/tests/core/sync/chan/test_core_sync_chan.odin +++ b/tests/core/sync/chan/test_core_sync_chan.odin @@ -4,6 +4,7 @@ import "base:runtime" import "base:intrinsics" import "core:log" import "core:math/rand" +import "core:sync" import "core:sync/chan" import "core:testing" import "core:thread" @@ -33,18 +34,16 @@ Comm :: struct { BUFFER_SIZE :: 8 MAX_RAND :: 32 FAIL_TIME :: 1 * time.Second -SLEEP_TIME :: 1 * time.Millisecond + +// Synchronizes try_select tests that require access to global state. +test_lock: sync.Mutex +__global_context_for_test: rawptr comm_client :: proc(th: ^thread.Thread) { data := cast(^Comm)th.data - manual_buffering := data.manual_buffering n: i64 - for manual_buffering && !chan.can_recv(data.host) { - thread.yield() - } - recv_loop: for msg in chan.recv(data.host) { #partial switch msg.type { case .Add: n += msg.i @@ -56,14 +55,6 @@ comm_client :: proc(th: ^thread.Thread) { case: panic("Unknown message type for client.") } - - for manual_buffering && !chan.can_recv(data.host) { - thread.yield() - } - } - - for manual_buffering && !chan.can_send(data.host) { - thread.yield() } chan.send(data.client, Message{.Result, n}) @@ -72,9 +63,6 @@ comm_client :: proc(th: ^thread.Thread) { send_messages :: proc(t: ^testing.T, host: chan.Chan(Message), manual_buffering: bool = false) -> (expected: i64) { expected = 1 - for manual_buffering && !chan.can_send(host) { - thread.yield() - } chan.send(host, Message{.Add, 1}) log.debug(Message{.Add, 1}) @@ -96,9 +84,6 @@ send_messages :: proc(t: ^testing.T, host: chan.Chan(Message), manual_buffering: expected /= msg.i } - for manual_buffering && !chan.can_send(host) { - thread.yield() - } if manual_buffering { testing.expect(t, chan.len(host) == 0) } @@ -107,9 +92,6 @@ send_messages :: proc(t: ^testing.T, host: chan.Chan(Message), manual_buffering: log.debug(msg) } - for manual_buffering && !chan.can_send(host) { - thread.yield() - } chan.send(host, Message{.End, 0}) log.debug(Message{.End, 0}) chan.close(host) @@ -148,18 +130,15 @@ test_chan_buffered :: proc(t: ^testing.T) { expected := send_messages(t, comm.host, manual_buffering = false) - // Sleep so we can give the other thread enough time to buffer its message. - time.sleep(SLEEP_TIME) - - testing.expect_value(t, chan.len(comm.client), 1) - result, ok := chan.try_recv(comm.client) - - // One more sleep to ensure it has enough time to close. - time.sleep(SLEEP_TIME) - - testing.expect_value(t, chan.is_closed(comm.client), true) + result, ok := chan.recv(comm.client) testing.expect_value(t, ok, true) testing.expect_value(t, result.i, expected) + + // Wait for channel to close. + _, ok = chan.recv(comm.client) + testing.expect(t, !ok, "channel should have been closed") + + testing.expect_value(t, chan.is_closed(comm.client), true) log.debug(result, expected) // Make sure sending to closed channels fails. @@ -171,6 +150,8 @@ test_chan_buffered :: proc(t: ^testing.T) { _, ok = chan.recv(comm.client); testing.expect_value(t, ok, false) _, ok = chan.try_recv(comm.host); testing.expect_value(t, ok, false) _, ok = chan.try_recv(comm.client); testing.expect_value(t, ok, false) + + thread.join(reckoner) } @test @@ -193,6 +174,10 @@ test_chan_unbuffered :: proc(t: ^testing.T) { testing.expect(t, !chan.is_buffered(comm.client)) testing.expect(t, chan.is_unbuffered(comm.host)) testing.expect(t, chan.is_unbuffered(comm.client)) + testing.expect(t, !chan.can_send(comm.host)) + testing.expect(t, !chan.can_send(comm.client)) + testing.expect(t, !chan.can_recv(comm.host)) + testing.expect(t, !chan.can_recv(comm.client)) testing.expect_value(t, chan.len(comm.host), 0) testing.expect_value(t, chan.len(comm.client), 0) testing.expect_value(t, chan.cap(comm.host), 0) @@ -203,25 +188,16 @@ test_chan_unbuffered :: proc(t: ^testing.T) { reckoner.data = &comm thread.start(reckoner) - for !chan.can_send(comm.client) { - thread.yield() - } - expected := send_messages(t, comm.host) testing.expect_value(t, chan.is_closed(comm.host), true) - for !chan.can_recv(comm.client) { - thread.yield() - } - - result, ok := chan.try_recv(comm.client) + result, ok := chan.recv(comm.client) testing.expect_value(t, ok, true) testing.expect_value(t, result.i, expected) log.debug(result, expected) - // Sleep so we can give the other thread enough time to close its side - // after we've received its message. - time.sleep(SLEEP_TIME) + _, ok2 := chan.recv(comm.client) + testing.expect(t, !ok2, "read of closed channel should return false") testing.expect_value(t, chan.is_closed(comm.client), true) @@ -234,6 +210,8 @@ test_chan_unbuffered :: proc(t: ^testing.T) { _, ok = chan.recv(comm.client); testing.expect_value(t, ok, false) _, ok = chan.try_recv(comm.host); testing.expect_value(t, ok, false) _, ok = chan.try_recv(comm.client); testing.expect_value(t, ok, false) + + thread.join(reckoner) } @test @@ -250,6 +228,198 @@ test_full_buffered_closed_chan_deadlock :: proc(t: ^testing.T) { testing.expect(t, !chan.send(ch, 32)) } +// Ensures that if a thread is doing a blocking send and the channel +// is closed, it will report false to indicate a failure to complete. +@test +test_fail_blocking_send_on_close :: proc(t: ^testing.T) { + ch, ch_alloc_err := chan.create(chan.Chan(int), context.allocator) + assert(ch_alloc_err == nil, "allocation failed") + defer chan.destroy(ch) + + sender := thread.create_and_start_with_poly_data(ch, proc(ch: chan.Chan(int)) { + assert(!chan.send(ch, 42)) + }) + + for !chan.can_recv(ch) { + thread.yield() + } + + testing.expect(t, chan.close(ch)) + thread.join(sender) + thread.destroy(sender) +} + +// Ensures that if a thread is doing a blocking read and the channel +// is closed, it will report false to indicate a failure to complete. +@test +test_fail_blocking_recv_on_close :: proc(t: ^testing.T) { + ch, ch_alloc_err := chan.create(chan.Chan(int), context.allocator) + assert(ch_alloc_err == nil, "allocation failed") + defer chan.destroy(ch) + + reader := thread.create_and_start_with_poly_data(ch, proc(ch: chan.Chan(int)) { + v, ok := chan.recv(ch) + assert(!ok) + assert(v == 0) + }) + + for !chan.can_send(ch) { + thread.yield() + } + + testing.expect(t, chan.close(ch)) + thread.join(reader) + thread.destroy(reader) +} + +// Ensures that try_send for unbuffered channels works as expected. +// If 1 reader of a channel, and 3 try_senders, only one of the senders +// will succeed and none of them will block. +@test +test_unbuffered_try_send_chan_contention :: proc(t: ^testing.T) { + testing.set_fail_timeout(t, FAIL_TIME) + + start, start_alloc_err := chan.create(chan.Chan(any), context.allocator) + assert(start_alloc_err == nil, "allocation failed") + defer chan.destroy(start) + + trigger, trigger_alloc_err := chan.create(chan.Chan(any), context.allocator) + assert(trigger_alloc_err == nil, "allocation failed") + defer chan.destroy(trigger) + + results, results_alloc_err := chan.create(chan.Chan(int), 3, context.allocator) + assert(results_alloc_err == nil, "allocation failed") + defer chan.destroy(results) + + ch, ch_alloc_err := chan.create(chan.Chan(int), context.allocator) + assert(ch_alloc_err == nil, "allocation failed") + defer chan.destroy(ch) + + // There are no readers or writers, so calling recv or send would block! + testing.expect_value(t, chan.can_send(ch), false) + testing.expect_value(t, chan.can_recv(ch), false) + + // Non-blocking operations should not block, and should return false. + testing.expect_value(t, chan.try_send(ch, -1), false) + if v, ok := chan.try_recv(ch); ok { + testing.expect_value(t, ok, false) + testing.expect_value(t, v, 0) + } + + // Spinup several threads contending to send on an unbuffered channel. + contenders: [3]^thread.Thread + wait: sync.Wait_Group + + for ii in 0..