diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index e25aebc1af8..5bb6c259e66 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -1,4 +1,4 @@ -// spell-checker: ignore ssleof aesccm aesgcm getblocking setblocking ENDTLS TLSEXT +// spell-checker: ignore ssleof aesccm aesgcm capath getblocking setblocking ENDTLS TLSEXT //! Pure Rust SSL/TLS implementation using rustls //! @@ -2786,6 +2786,16 @@ mod _ssl { recv_method.call((self.sock.clone(), vm.ctx.new_int(size)), vm) } + /// Peek at socket data without consuming it (MSG_PEEK). + /// Used during TLS shutdown to avoid consuming post-TLS cleartext data. + pub(crate) fn sock_peek(&self, size: usize, vm: &VirtualMachine) -> PyResult { + let socket_mod = vm.import("socket", 0)?; + let socket_class = socket_mod.get_attr("socket", vm)?; + let recv_method = socket_class.get_attr("recv", vm)?; + let msg_peek = socket_mod.get_attr("MSG_PEEK", vm)?; + recv_method.call((self.sock.clone(), vm.ctx.new_int(size), msg_peek), vm) + } + /// Socket send - just sends data, caller must handle pending flush /// Use flush_pending_tls_output before this if ordering is important pub(crate) fn sock_send(&self, data: &[u8], vm: &VirtualMachine) -> PyResult { @@ -4287,45 +4297,118 @@ mod _ssl { conn: &mut TlsConnection, vm: &VirtualMachine, ) -> PyResult { - // Try to read incoming data + // In socket mode, peek first to avoid consuming post-TLS cleartext + // data. During STARTTLS, after close_notify exchange, the socket + // transitions to cleartext. Without peeking, sock_recv may consume + // cleartext data meant for the application after unwrap(). + if self.incoming_bio.is_none() { + return self.try_read_close_notify_socket(conn, vm); + } + + // BIO mode: read from incoming BIO match self.sock_recv(SSL3_RT_MAX_PLAIN_LENGTH, vm) { Ok(bytes_obj) => { let bytes = ArgBytesLike::try_from_object(vm, bytes_obj)?; let data = bytes.borrow_buf(); if data.is_empty() { - // Empty read could mean EOF or just "no data yet" in BIO mode if let Some(ref bio) = self.incoming_bio { // BIO mode: check if EOF was signaled via write_eof() let bio_obj: PyObjectRef = bio.clone().into(); let eof_attr = bio_obj.get_attr("eof", vm)?; let is_eof = eof_attr.try_to_bool(vm)?; if !is_eof { - // No EOF signaled, just no data available yet return Ok(false); } } - // Socket mode or BIO with EOF: peer closed connection - // This is "ragged EOF" - peer closed without close_notify return Ok(true); } - // Feed data to TLS connection let data_slice: &[u8] = data.as_ref(); let mut cursor = std::io::Cursor::new(data_slice); let _ = conn.read_tls(&mut cursor); + let _ = conn.process_new_packets(); + Ok(false) + } + Err(e) => { + if is_blocking_io_error(&e, vm) { + return Ok(false); + } + Ok(true) + } + } + } - // Process packets + /// Socket-mode close_notify reader that respects TLS record boundaries. + /// Uses MSG_PEEK to inspect data before consuming, preventing accidental + /// consumption of post-TLS cleartext data during STARTTLS transitions. + /// + /// Equivalent to OpenSSL's `SSL_set_read_ahead(ssl, 0)` — rustls has no + /// such knob, so we enforce record-level reads manually via peek. + fn try_read_close_notify_socket( + &self, + conn: &mut TlsConnection, + vm: &VirtualMachine, + ) -> PyResult { + // Peek at the first 5 bytes (TLS record header size) + let peeked_obj = match self.sock_peek(5, vm) { + Ok(obj) => obj, + Err(e) => { + if is_blocking_io_error(&e, vm) { + return Ok(false); + } + return Ok(true); + } + }; + + let peeked = ArgBytesLike::try_from_object(vm, peeked_obj)?; + let peek_data = peeked.borrow_buf(); + + if peek_data.is_empty() { + return Ok(true); // EOF + } + + // TLS record content types: ChangeCipherSpec(20), Alert(21), + // Handshake(22), ApplicationData(23) + let content_type = peek_data[0]; + if !(20..=23).contains(&content_type) { + // Not a TLS record - post-TLS cleartext data. + // Peer has completed TLS shutdown; don't consume this data. + return Ok(true); + } + + // Determine how many bytes to read for exactly one TLS record + let recv_size = if peek_data.len() >= 5 { + let record_length = u16::from_be_bytes([peek_data[3], peek_data[4]]) as usize; + 5 + record_length + } else { + // Partial header available - read just these bytes for now + peek_data.len() + }; + + drop(peek_data); + drop(peeked); + + // Now consume exactly one TLS record from the socket + match self.sock_recv(recv_size, vm) { + Ok(bytes_obj) => { + let bytes = ArgBytesLike::try_from_object(vm, bytes_obj)?; + let data = bytes.borrow_buf(); + + if data.is_empty() { + return Ok(true); + } + + let data_slice: &[u8] = data.as_ref(); + let mut cursor = std::io::Cursor::new(data_slice); + let _ = conn.read_tls(&mut cursor); let _ = conn.process_new_packets(); Ok(false) } Err(e) => { - // BlockingIOError means no data yet if is_blocking_io_error(&e, vm) { return Ok(false); } - // Connection reset, EOF, or other error means peer closed - // ECONNRESET, EPIPE, broken pipe, etc. Ok(true) } } diff --git a/extra_tests/snippets/builtin_list.py b/extra_tests/snippets/builtin_list.py index 711bf6acc6c..d4afbffa1cb 100644 --- a/extra_tests/snippets/builtin_list.py +++ b/extra_tests/snippets/builtin_list.py @@ -270,6 +270,7 @@ def __gt__(self, other): lst.sort(key=C) assert lst == [1, 2, 3, 4, 5] + # Test that sorted() uses __lt__ (not __gt__) for comparisons. # Track which comparison method is actually called during sort. class TrackComparison: @@ -287,13 +288,16 @@ def __gt__(self, other): TrackComparison.gt_calls += 1 return self.value > other.value + # Reset and test sorted() TrackComparison.lt_calls = 0 TrackComparison.gt_calls = 0 items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)] sorted(items) assert TrackComparison.lt_calls > 0, "sorted() should call __lt__" -assert TrackComparison.gt_calls == 0, f"sorted() should not call __gt__, but it was called {TrackComparison.gt_calls} times" +assert TrackComparison.gt_calls == 0, ( + f"sorted() should not call __gt__, but it was called {TrackComparison.gt_calls} times" +) # Reset and test list.sort() TrackComparison.lt_calls = 0 @@ -301,7 +305,9 @@ def __gt__(self, other): items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)] items.sort() assert TrackComparison.lt_calls > 0, "list.sort() should call __lt__" -assert TrackComparison.gt_calls == 0, f"list.sort() should not call __gt__, but it was called {TrackComparison.gt_calls} times" +assert TrackComparison.gt_calls == 0, ( + f"list.sort() should not call __gt__, but it was called {TrackComparison.gt_calls} times" +) # Reset and test sorted(reverse=True) - should still use __lt__, not __gt__ TrackComparison.lt_calls = 0 @@ -309,7 +315,9 @@ def __gt__(self, other): items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)] sorted(items, reverse=True) assert TrackComparison.lt_calls > 0, "sorted(reverse=True) should call __lt__" -assert TrackComparison.gt_calls == 0, f"sorted(reverse=True) should not call __gt__, but it was called {TrackComparison.gt_calls} times" +assert TrackComparison.gt_calls == 0, ( + f"sorted(reverse=True) should not call __gt__, but it was called {TrackComparison.gt_calls} times" +) lst = [5, 1, 2, 3, 4]