diff --git a/Lib/test/test_context.py b/Lib/test/test_context.py index 6083b532223..59d2320de85 100644 --- a/Lib/test/test_context.py +++ b/Lib/test/test_context.py @@ -217,8 +217,6 @@ def fun(): ctx.run(fun) - # TODO: RUSTPYTHON - @unittest.expectedFailure @isolated_context def test_context_getset_1(self): c = contextvars.ContextVar('c') @@ -317,8 +315,6 @@ def test_context_getset_4(self): with self.assertRaisesRegex(ValueError, 'different Context'): c.reset(tok) - # TODO: RUSTPYTHON - @unittest.expectedFailure @isolated_context def test_context_getset_5(self): c = contextvars.ContextVar('c', default=42) @@ -332,8 +328,6 @@ def fun(): contextvars.copy_context().run(fun) self.assertEqual(c.get(), []) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_context_copy_1(self): ctx1 = contextvars.Context() c = contextvars.ContextVar('c', default=42) diff --git a/Lib/test/test_inspect/test_inspect.py b/Lib/test/test_inspect/test_inspect.py index 7fdcfa7c04d..595966e3405 100644 --- a/Lib/test/test_inspect/test_inspect.py +++ b/Lib/test/test_inspect/test_inspect.py @@ -2797,7 +2797,6 @@ def test_easy_debugging(self): self.assertIn(name, repr(state)) self.assertIn(name, str(state)) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_getgeneratorlocals(self): def each(lst, a=None): b=(1, 2, 3) @@ -2985,7 +2984,6 @@ def test_easy_debugging(self): self.assertIn(name, repr(state)) self.assertIn(name, str(state)) - @unittest.expectedFailure # TODO: RUSTPYTHON async def test_getasyncgenlocals(self): async def each(lst, a=None): b=(1, 2, 3) diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 17113408ee6..71b54e286a3 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -3525,7 +3525,6 @@ def test_starttls(self): else: s.close() - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_socketserver(self): """Using socketserver to create and manage SSL connections.""" server = make_https_server(self, certfile=SIGNED_CERTFILE) diff --git a/Lib/test/test_unittest/test_async_case.py b/Lib/test/test_unittest/test_async_case.py index b31c877e22e..b1ccd644343 100644 --- a/Lib/test/test_unittest/test_async_case.py +++ b/Lib/test/test_unittest/test_async_case.py @@ -13,9 +13,7 @@ class MyException(Exception): def tearDownModule(): - # XXX: RUSTPYTHON; asyncio.events._set_event_loop_policy is not implemented - # asyncio.events._set_event_loop_policy(None) - pass + asyncio.events._set_event_loop_policy(None) class TestCM: @@ -52,7 +50,6 @@ def setUp(self): # starting a new event loop self.addCleanup(support.gc_collect) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_full_cycle(self): expected = ['setUp', 'asyncSetUp', diff --git a/crates/codegen/src/compile.rs b/crates/codegen/src/compile.rs index eaff3c52d5b..2f28c6c9683 100644 --- a/crates/codegen/src/compile.rs +++ b/crates/codegen/src/compile.rs @@ -1528,27 +1528,30 @@ impl Compiler { // Otherwise, if an exception occurs during the finally body, the stack // will be unwound to the wrong depth and the return value will be lost. if preserve_tos { - // Get the handler info from the saved fblock (or current handler) - // and create a new handler with stack_depth + 1 - let (handler, stack_depth, preserve_lasti) = - if let Some(handler) = saved_fblock.fb_handler { - ( - Some(handler), - saved_fblock.fb_stack_depth + 1, // +1 for return value - saved_fblock.fb_preserve_lasti, - ) - } else { - // No handler in saved_fblock, check current handler - if let Some(current_handler) = self.current_except_handler() { - ( - Some(current_handler.handler_block), - current_handler.stack_depth + 1, // +1 for return value - current_handler.preserve_lasti, - ) - } else { - (None, 1, false) // No handler, but still track the return value + // Find the outer handler for exceptions during finally body execution. + // CRITICAL: Only search fblocks with index < fblock_idx (= outer fblocks). + // Inner FinallyTry blocks may have been restored after their unwind + // processing, and we must NOT use their handlers - that would cause + // the inner finally body to execute again on exception. + let (handler, stack_depth, preserve_lasti) = { + let code = self.code_stack.last().unwrap(); + let mut found = None; + // Only search fblocks at indices 0..fblock_idx (outer fblocks) + // After removal, fblock_idx now points to where saved_fblock was, + // so indices 0..fblock_idx are the outer fblocks + for i in (0..fblock_idx).rev() { + let fblock = &code.fblock[i]; + if let Some(handler) = fblock.fb_handler { + found = Some(( + Some(handler), + fblock.fb_stack_depth + 1, // +1 for return value + fblock.fb_preserve_lasti, + )); + break; } - }; + } + found.unwrap_or((None, 1, false)) + }; self.push_fblock_with_handler( FBlockType::PopValue, diff --git a/crates/stdlib/src/contextvars.rs b/crates/stdlib/src/contextvars.rs index 57864bf45fa..700a94692d5 100644 --- a/crates/stdlib/src/contextvars.rs +++ b/crates/stdlib/src/contextvars.rs @@ -168,11 +168,15 @@ mod _contextvars { } #[pymethod] - fn copy(&self) -> Self { + fn copy(&self, vm: &VirtualMachine) -> Self { + // Deep copy the vars - clone the underlying Hamt data, not just the PyRef + let vars_copy = HamtObject { + hamt: RefCell::new(self.inner.vars.hamt.borrow().clone()), + }; Self { inner: ContextInner { idx: Cell::new(usize::MAX), - vars: self.inner.vars.clone(), + vars: vars_copy.into_ref(&vm.ctx), entered: Cell::new(false), }, } @@ -630,7 +634,7 @@ mod _contextvars { #[pyfunction] fn copy_context(vm: &VirtualMachine) -> PyContext { - PyContext::current(vm).copy() + PyContext::current(vm).copy(vm) } // Set Token.MISSING attribute diff --git a/crates/stdlib/src/socket.rs b/crates/stdlib/src/socket.rs index b0ebd681e02..2f6e5f14a19 100644 --- a/crates/stdlib/src/socket.rs +++ b/crates/stdlib/src/socket.rs @@ -15,7 +15,10 @@ mod _socket { }, common::os::ErrorExt, convert::{IntoPyException, ToPyObject, TryFromBorrowedObject, TryFromObject}, - function::{ArgBytesLike, ArgMemoryBuffer, Either, FsPath, OptionalArg, OptionalOption}, + function::{ + ArgBytesLike, ArgMemoryBuffer, ArgStrOrBytesLike, Either, FsPath, OptionalArg, + OptionalOption, + }, types::{Constructor, DefaultConstructor, Initializer, Representable}, utils::ToCString, }; @@ -2783,9 +2786,9 @@ mod _socket { #[derive(FromArgs)] struct GAIOptions { #[pyarg(positional)] - host: Option, + host: Option, #[pyarg(positional)] - port: Option>, + port: Option>, #[pyarg(positional, default = c::AF_UNSPEC)] family: i32, @@ -2809,9 +2812,9 @@ mod _socket { flags: opts.flags, }; - // Encode host using IDNA encoding + // Encode host: str uses IDNA encoding, bytes must be valid UTF-8 let host_encoded: Option = match opts.host.as_ref() { - Some(s) => { + Some(ArgStrOrBytesLike::Str(s)) => { let encoded = vm.state .codec_registry @@ -2820,19 +2823,43 @@ mod _socket { .map_err(|_| vm.new_runtime_error("idna output is not utf8".to_owned()))?; Some(host_str.to_owned()) } + Some(ArgStrOrBytesLike::Buf(b)) => { + let bytes = b.borrow_buf(); + let host_str = core::str::from_utf8(&bytes).map_err(|_| { + vm.new_unicode_decode_error("host bytes is not utf8".to_owned()) + })?; + Some(host_str.to_owned()) + } None => None, }; let host = host_encoded.as_deref(); - // Encode port using UTF-8 - let port: Option> = match opts.port.as_ref() { - Some(Either::A(s)) => Some(alloc::borrow::Cow::Borrowed(s.to_str().ok_or_else( - || vm.new_unicode_encode_error("surrogates not allowed".to_owned()), - )?)), - Some(Either::B(i)) => Some(alloc::borrow::Cow::Owned(i.to_string())), + // Encode port: str/bytes as service name, int as port number + let port_encoded: Option = match opts.port.as_ref() { + Some(Either::A(sb)) => { + let port_str = match sb { + ArgStrOrBytesLike::Str(s) => { + // For str, check for surrogates and raise UnicodeEncodeError if found + s.to_str() + .ok_or_else(|| vm.new_unicode_encode_error("surrogates not allowed"))? + .to_owned() + } + ArgStrOrBytesLike::Buf(b) => { + // For bytes, check if it's valid UTF-8 + let bytes = b.borrow_buf(); + core::str::from_utf8(&bytes) + .map_err(|_| { + vm.new_unicode_decode_error("port is not utf8".to_owned()) + })? + .to_owned() + } + }; + Some(port_str) + } + Some(Either::B(i)) => Some(i.to_string()), None => None, }; - let port = port.as_ref().map(|p| p.as_ref()); + let port = port_encoded.as_deref(); let addrs = dns_lookup::getaddrinfo(host, port, Some(hints)) .map_err(|err| convert_socket_error(vm, err, SocketError::GaiError))?; diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index 7d27e259cae..4b31662cfe8 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -53,6 +53,7 @@ mod _ssl { // Import error types used in this module (others are exposed via pymodule(with(...))) use super::error::{ PySSLError, create_ssl_eof_error, create_ssl_want_read_error, create_ssl_want_write_error, + create_ssl_zero_return_error, }; use alloc::sync::Arc; use core::{ @@ -3593,7 +3594,7 @@ mod _ssl { let mut conn_guard = self.connection.lock(); let conn = match conn_guard.as_mut() { Some(conn) => conn, - None => return return_data(vec![], &buffer, vm), + None => return Err(create_ssl_zero_return_error(vm).upcast()), }; use std::io::BufRead; let mut reader = conn.reader(); @@ -3613,8 +3614,20 @@ mod _ssl { return return_data(buf, &buffer, vm); } } - // Clean closure with close_notify - return empty data - return_data(vec![], &buffer, vm) + // Clean closure with close_notify + // CPython behavior depends on whether we've sent our close_notify: + // - If we've already sent close_notify (unwrap was called): raise SSLZeroReturnError + // - If we haven't sent close_notify yet: return empty bytes + let our_shutdown_state = *self.shutdown_state.lock(); + if our_shutdown_state == ShutdownState::SentCloseNotify + || our_shutdown_state == ShutdownState::Completed + { + // We already sent close_notify, now receiving peer's → SSLZeroReturnError + Err(create_ssl_zero_return_error(vm).upcast()) + } else { + // We haven't sent close_notify yet → return empty bytes + return_data(vec![], &buffer, vm) + } } Err(crate::ssl::compat::SslError::WantRead) => { // Non-blocking mode: would block diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index 322fdde5b9a..5bf2cd8b60f 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -1552,6 +1552,11 @@ pub(super) fn ssl_read( // Try to read plaintext from rustls buffer if let Some(n) = try_read_plaintext(conn, buf)? { + if n == 0 { + // EOF from TLS - close_notify received + // Return ZeroReturn so Python raises SSLZeroReturnError + return Err(SslError::ZeroReturn); + } return Ok(n); } @@ -1740,17 +1745,40 @@ pub(super) fn ssl_write( let already_buffered = *socket.write_buffered_len.lock(); // Only write plaintext if not already buffered + // Track how much we wrote for partial write handling + let mut bytes_written_to_rustls = 0usize; + if already_buffered == 0 { // Write plaintext to rustls (= SSL_write_ex internal buffer write) - { + bytes_written_to_rustls = { let mut writer = conn.writer(); use std::io::Write; - writer - .write_all(data) - .map_err(|e| SslError::Syscall(format!("Write failed: {e}")))?; - } - // Mark data as buffered - *socket.write_buffered_len.lock() = data.len(); + // Use write() instead of write_all() to support partial writes. + // In BIO mode (asyncio), when the internal buffer is full, + // we want to write as much as possible and return that count, + // rather than failing completely. + match writer.write(data) { + Ok(0) if !data.is_empty() => { + // Buffer is full and nothing could be written. + // In BIO mode, return WantWrite so the caller can + // drain the outgoing BIO and retry. + if is_bio { + return Err(SslError::WantWrite); + } + return Err(SslError::Syscall("Write failed: buffer full".to_string())); + } + Ok(n) => n, + Err(e) => { + if is_bio { + // In BIO mode, treat write errors as WantWrite + return Err(SslError::WantWrite); + } + return Err(SslError::Syscall(format!("Write failed: {e}"))); + } + } + }; + // Mark data as buffered (only the portion we actually wrote) + *socket.write_buffered_len.lock() = bytes_written_to_rustls; } else if already_buffered != data.len() { // Caller is retrying with different data - this is a protocol error // Clear the buffer state and return an SSL error (bad write retry) @@ -1790,13 +1818,23 @@ pub(super) fn ssl_write( } Err(SslError::WantWrite) => { // Non-blocking socket would block - return WANT_WRITE + // If we had a partial write to rustls, return partial success + // instead of error to match OpenSSL partial-write semantics + if bytes_written_to_rustls > 0 && bytes_written_to_rustls < data.len() { + *socket.write_buffered_len.lock() = 0; + return Ok(bytes_written_to_rustls); + } // Keep write_buffered_len set so we don't re-buffer on retry return Err(SslError::WantWrite); } Err(SslError::WantRead) => { // Need to read before write can complete (e.g., renegotiation) - // This matches CPython's handling of SSL_ERROR_WANT_READ in write if is_bio { + // If we had a partial write to rustls, return partial success + if bytes_written_to_rustls > 0 && bytes_written_to_rustls < data.len() { + *socket.write_buffered_len.lock() = 0; + return Ok(bytes_written_to_rustls); + } // Keep write_buffered_len set so we don't re-buffer on retry return Err(SslError::WantRead); } @@ -1807,6 +1845,11 @@ pub(super) fn ssl_write( // Continue loop } Err(e @ SslError::Timeout(_)) => { + // If we had a partial write to rustls, return partial success + if bytes_written_to_rustls > 0 && bytes_written_to_rustls < data.len() { + *socket.write_buffered_len.lock() = 0; + return Ok(bytes_written_to_rustls); + } // Preserve buffered state so retry doesn't duplicate data // (send_all_bytes saved unsent TLS bytes to pending_tls_output) return Err(e); @@ -1826,10 +1869,21 @@ pub(super) fn ssl_write( .map_err(SslError::Py)?; } + // Determine how many bytes we actually wrote + let actual_written = if bytes_written_to_rustls > 0 { + // Fresh write: return what we wrote to rustls + bytes_written_to_rustls + } else if already_buffered > 0 { + // Retry of previous write: return the full buffered amount + already_buffered + } else { + data.len() + }; + // Write completed successfully - clear buffer state *socket.write_buffered_len.lock() = 0; - Ok(data.len()) + Ok(actual_written) } // Helper functions (private-ish, used by public SSL functions) diff --git a/crates/vm/src/builtins/asyncgenerator.rs b/crates/vm/src/builtins/asyncgenerator.rs index 8b7c107d4b8..a77aaf518bd 100644 --- a/crates/vm/src/builtins/asyncgenerator.rs +++ b/crates/vm/src/builtins/asyncgenerator.rs @@ -123,8 +123,12 @@ impl PyAsyncGen { self.inner.frame().yield_from_target() } #[pygetset] - fn ag_frame(&self, _vm: &VirtualMachine) -> FrameRef { - self.inner.frame() + fn ag_frame(&self, _vm: &VirtualMachine) -> Option { + if self.inner.closed() { + None + } else { + Some(self.inner.frame()) + } } #[pygetset] fn ag_running(&self, _vm: &VirtualMachine) -> bool { diff --git a/crates/vm/src/builtins/coroutine.rs b/crates/vm/src/builtins/coroutine.rs index 961c352f8df..d2a70e54229 100644 --- a/crates/vm/src/builtins/coroutine.rs +++ b/crates/vm/src/builtins/coroutine.rs @@ -76,8 +76,12 @@ impl PyCoroutine { self.inner.frame().yield_from_target() } #[pygetset] - fn cr_frame(&self, _vm: &VirtualMachine) -> FrameRef { - self.inner.frame() + fn cr_frame(&self, _vm: &VirtualMachine) -> Option { + if self.inner.closed() { + None + } else { + Some(self.inner.frame()) + } } #[pygetset] fn cr_running(&self, _vm: &VirtualMachine) -> bool { diff --git a/crates/vm/src/builtins/generator.rs b/crates/vm/src/builtins/generator.rs index ceae2e61c3b..dec7d82add0 100644 --- a/crates/vm/src/builtins/generator.rs +++ b/crates/vm/src/builtins/generator.rs @@ -66,8 +66,12 @@ impl PyGenerator { } #[pygetset] - fn gi_frame(&self, _vm: &VirtualMachine) -> FrameRef { - self.inner.frame() + fn gi_frame(&self, _vm: &VirtualMachine) -> Option { + if self.inner.closed() { + None + } else { + Some(self.inner.frame()) + } } #[pygetset] diff --git a/crates/vm/src/stdlib/posix.rs b/crates/vm/src/stdlib/posix.rs index 699989be133..84d27f2286e 100644 --- a/crates/vm/src/stdlib/posix.rs +++ b/crates/vm/src/stdlib/posix.rs @@ -2574,7 +2574,14 @@ pub mod module { headers, trailers, ); - res.map_err(|err| err.into_pyexception(vm))?; + // On macOS, sendfile can return EAGAIN even when some bytes were written. + // In that case, we should return the number of bytes written rather than + // raising an exception. Only raise an error if no bytes were written. + if let Err(err) = res + && written == 0 + { + return Err(err.into_pyexception(vm)); + } Ok(vm.ctx.new_int(written as u64).into()) }