diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py index 3618fb60d8f..7039bd7054c 100644 --- a/Lib/test/test_asyncgen.py +++ b/Lib/test/test_asyncgen.py @@ -559,8 +559,6 @@ async def call_with_kwarg(): with self.assertRaises(TypeError): self.loop.run_until_complete(call_with_kwarg()) - # TODO: RUSTPYTHON, error message mismatch - @unittest.expectedFailure def test_anext_bad_await(self): async def bad_awaitable(): class BadAwaitable: @@ -630,8 +628,6 @@ async def do_test(): result = self.loop.run_until_complete(do_test()) self.assertEqual(result, "completed") - # TODO: RUSTPYTHON, anext coroutine iteration issue - @unittest.expectedFailure def test_anext_iter(self): @types.coroutine def _async_yield(v): @@ -1489,8 +1485,6 @@ async def main(): self.assertEqual(messages, []) - # TODO: RUSTPYTHON, ValueError: not enough values to unpack (expected 1, got 0) - @unittest.expectedFailure def test_async_gen_asyncio_shutdown_exception_01(self): messages = [] diff --git a/crates/vm/src/builtins/asyncgenerator.rs b/crates/vm/src/builtins/asyncgenerator.rs index b41a49f931e..073513184ff 100644 --- a/crates/vm/src/builtins/asyncgenerator.rs +++ b/crates/vm/src/builtins/asyncgenerator.rs @@ -3,6 +3,7 @@ use crate::{ AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::PyBaseExceptionRef, class::PyClassImpl, + common::lock::PyMutex, coroutine::Coro, frame::FrameRef, function::OptionalArg, @@ -17,6 +18,10 @@ use crossbeam_utils::atomic::AtomicCell; pub struct PyAsyncGen { inner: Coro, running_async: AtomicCell, + // whether hooks have been initialized + ag_hooks_inited: AtomicCell, + // ag_origin_or_finalizer - stores the finalizer callback + ag_finalizer: PyMutex>, } type PyAsyncGenRef = PyRef; @@ -37,6 +42,48 @@ impl PyAsyncGen { Self { inner: Coro::new(frame, name, qualname), running_async: AtomicCell::new(false), + ag_hooks_inited: AtomicCell::new(false), + ag_finalizer: PyMutex::new(None), + } + } + + /// Initialize async generator hooks. + /// Returns Ok(()) if successful, Err if firstiter hook raised an exception. + fn init_hooks(zelf: &Py, vm: &VirtualMachine) -> PyResult<()> { + // = async_gen_init_hooks + if zelf.ag_hooks_inited.load() { + return Ok(()); + } + + zelf.ag_hooks_inited.store(true); + + // Get and store finalizer from thread-local storage + let finalizer = crate::vm::thread::ASYNC_GEN_FINALIZER.with_borrow(|f| f.as_ref().cloned()); + if let Some(finalizer) = finalizer { + *zelf.ag_finalizer.lock() = Some(finalizer); + } + + // Call firstiter hook + let firstiter = crate::vm::thread::ASYNC_GEN_FIRSTITER.with_borrow(|f| f.as_ref().cloned()); + if let Some(firstiter) = firstiter { + let obj: PyObjectRef = zelf.to_owned().into(); + firstiter.call((obj,), vm)?; + } + + Ok(()) + } + + /// Call finalizer hook if set + #[allow(dead_code)] + fn call_finalizer(zelf: &Py, vm: &VirtualMachine) { + // = gen_dealloc + let finalizer = zelf.ag_finalizer.lock().clone(); + if let Some(finalizer) = finalizer + && !zelf.inner.closed.load() + { + // Call finalizer, ignore any errors (PyErr_WriteUnraisable) + let obj: PyObjectRef = zelf.to_owned().into(); + let _ = finalizer.call((obj,), vm); } } @@ -91,17 +138,23 @@ impl PyRef { } #[pymethod] - fn __anext__(self, vm: &VirtualMachine) -> PyAsyncGenASend { - Self::asend(self, vm.ctx.none(), vm) + fn __anext__(self, vm: &VirtualMachine) -> PyResult { + PyAsyncGen::init_hooks(&self, vm)?; + Ok(PyAsyncGenASend { + ag: self, + state: AtomicCell::new(AwaitableState::Init), + value: vm.ctx.none(), + }) } #[pymethod] - const fn asend(self, value: PyObjectRef, _vm: &VirtualMachine) -> PyAsyncGenASend { - PyAsyncGenASend { + fn asend(self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + PyAsyncGen::init_hooks(&self, vm)?; + Ok(PyAsyncGenASend { ag: self, state: AtomicCell::new(AwaitableState::Init), value, - } + }) } #[pymethod] @@ -111,8 +164,9 @@ impl PyRef { exc_val: OptionalArg, exc_tb: OptionalArg, vm: &VirtualMachine, - ) -> PyAsyncGenAThrow { - PyAsyncGenAThrow { + ) -> PyResult { + PyAsyncGen::init_hooks(&self, vm)?; + Ok(PyAsyncGenAThrow { ag: self, aclose: false, state: AtomicCell::new(AwaitableState::Init), @@ -121,12 +175,13 @@ impl PyRef { exc_val.unwrap_or_none(vm), exc_tb.unwrap_or_none(vm), ), - } + }) } #[pymethod] - fn aclose(self, vm: &VirtualMachine) -> PyAsyncGenAThrow { - PyAsyncGenAThrow { + fn aclose(self, vm: &VirtualMachine) -> PyResult { + PyAsyncGen::init_hooks(&self, vm)?; + Ok(PyAsyncGenAThrow { ag: self, aclose: true, state: AtomicCell::new(AwaitableState::Init), @@ -135,7 +190,7 @@ impl PyRef { vm.ctx.none(), vm.ctx.none(), ), - } + }) } } @@ -441,6 +496,7 @@ impl IterNext for PyAsyncGenAThrow { pub struct PyAnextAwaitable { wrapped: PyObjectRef, default_value: PyObjectRef, + state: AtomicCell, } impl PyPayload for PyAnextAwaitable { @@ -456,6 +512,7 @@ impl PyAnextAwaitable { Self { wrapped, default_value, + state: AtomicCell::new(AwaitableState::Init), } } @@ -464,6 +521,13 @@ impl PyAnextAwaitable { zelf } + fn check_closed(&self, vm: &VirtualMachine) -> PyResult<()> { + if let AwaitableState::Closed = self.state.load() { + return Err(vm.new_runtime_error("cannot reuse already awaited __anext__()/asend()")); + } + Ok(()) + } + /// Get the awaitable iterator from wrapped object. // = anextawaitable_getiter. fn get_awaitable_iter(&self, vm: &VirtualMachine) -> PyResult { @@ -523,6 +587,8 @@ impl PyAnextAwaitable { #[pymethod] fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.check_closed(vm)?; + self.state.store(AwaitableState::Iter); let awaitable = self.get_awaitable_iter(vm)?; let result = vm.call_method(&awaitable, "send", (val,)); self.handle_result(result, vm) @@ -536,6 +602,8 @@ impl PyAnextAwaitable { exc_tb: OptionalArg, vm: &VirtualMachine, ) -> PyResult { + self.check_closed(vm)?; + self.state.store(AwaitableState::Iter); let awaitable = self.get_awaitable_iter(vm)?; let result = vm.call_method( &awaitable, @@ -551,6 +619,7 @@ impl PyAnextAwaitable { #[pymethod] fn close(&self, vm: &VirtualMachine) -> PyResult<()> { + self.state.store(AwaitableState::Closed); if let Ok(awaitable) = self.get_awaitable_iter(vm) { let _ = vm.call_method(&awaitable, "close", ()); } diff --git a/crates/vm/src/builtins/coroutine.rs b/crates/vm/src/builtins/coroutine.rs index 8f57059e085..0909cdfb444 100644 --- a/crates/vm/src/builtins/coroutine.rs +++ b/crates/vm/src/builtins/coroutine.rs @@ -8,6 +8,7 @@ use crate::{ protocol::PyIterReturn, types::{IterNext, Iterable, Representable, SelfIter, Unconstructible}, }; +use crossbeam_utils::atomic::AtomicCell; #[pyclass(module = false, name = "coroutine")] #[derive(Debug)] @@ -56,8 +57,11 @@ impl PyCoroutine { } #[pymethod(name = "__await__")] - const fn r#await(zelf: PyRef) -> PyCoroutineWrapper { - PyCoroutineWrapper { coro: zelf } + fn r#await(zelf: PyRef) -> PyCoroutineWrapper { + PyCoroutineWrapper { + coro: zelf, + closed: AtomicCell::new(false), + } } #[pygetset] @@ -140,6 +144,7 @@ impl IterNext for PyCoroutine { // PyCoroWrapper_Type in CPython pub struct PyCoroutineWrapper { coro: PyRef, + closed: AtomicCell, } impl PyPayload for PyCoroutineWrapper { @@ -151,9 +156,22 @@ impl PyPayload for PyCoroutineWrapper { #[pyclass(with(IterNext, Iterable))] impl PyCoroutineWrapper { + fn check_closed(&self, vm: &VirtualMachine) -> PyResult<()> { + if self.closed.load() { + return Err(vm.new_runtime_error("cannot reuse already awaited coroutine")); + } + Ok(()) + } + #[pymethod] fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.coro.send(val, vm) + self.check_closed(vm)?; + let result = self.coro.send(val, vm); + // Mark as closed if exhausted + if let Ok(PyIterReturn::StopIteration(_)) = &result { + self.closed.store(true); + } + result } #[pymethod] @@ -164,11 +182,18 @@ impl PyCoroutineWrapper { exc_tb: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - self.coro.throw(exc_type, exc_val, exc_tb, vm) + self.check_closed(vm)?; + let result = self.coro.throw(exc_type, exc_val, exc_tb, vm); + // Mark as closed if exhausted + if let Ok(PyIterReturn::StopIteration(_)) = &result { + self.closed.store(true); + } + result } #[pymethod] fn close(&self, vm: &VirtualMachine) -> PyResult<()> { + self.closed.store(true); self.coro.close(vm) } } diff --git a/crates/vm/src/frame.rs b/crates/vm/src/frame.rs index 77b034ade9a..4a460a95884 100644 --- a/crates/vm/src/frame.rs +++ b/crates/vm/src/frame.rs @@ -918,6 +918,8 @@ impl ExecutingFrame<'_> { Ok(None) } bytecode::Instruction::GetAwaitable => { + use crate::protocol::PyIter; + let awaited_obj = self.pop_value(); let awaitable = if awaited_obj.downcastable::() { awaited_obj @@ -932,7 +934,15 @@ impl ExecutingFrame<'_> { ) }, )?; - await_method.call((), vm)? + let result = await_method.call((), vm)?; + // Check that __await__ returned an iterator + if !PyIter::check(&result) { + return Err(vm.new_type_error(format!( + "__await__() returned non-iterator of type '{}'", + result.class().name() + ))); + } + result }; self.push_value(awaitable); Ok(None)