diff --git a/crates/vm/src/builtins/bool.rs b/crates/vm/src/builtins/bool.rs index 3dabdbae717..24ded08ab10 100644 --- a/crates/vm/src/builtins/bool.rs +++ b/crates/vm/src/builtins/bool.rs @@ -9,7 +9,6 @@ use crate::{ types::{AsNumber, Constructor, Representable}, }; use core::fmt::{Debug, Formatter}; -use malachite_bigint::Sign; use num_traits::Zero; impl ToPyObject for bool { @@ -42,46 +41,28 @@ impl PyObjectRef { if self.is(&vm.ctx.false_value) { return Ok(false); } - let rs_bool = if let Some(nb_bool) = self.class().slots.as_number.boolean.load() { - nb_bool(self.as_object().number(), vm)? - } else { - // TODO: Fully implement AsNumber and remove this block - match vm.get_method(self.clone(), identifier!(vm, __bool__)) { - Some(method_or_err) => { - // If descriptor returns Error, propagate it further - let method = method_or_err?; - let bool_obj = method.call((), vm)?; - if !bool_obj.fast_isinstance(vm.ctx.types.bool_type) { - return Err(vm.new_type_error(format!( - "__bool__ should return bool, returned type {}", - bool_obj.class().name() - ))); - } - - get_value(&bool_obj) - } - None => match vm.get_method(self, identifier!(vm, __len__)) { - Some(method_or_err) => { - let method = method_or_err?; - let bool_obj = method.call((), vm)?; - let int_obj = bool_obj.downcast_ref::().ok_or_else(|| { - vm.new_type_error(format!( - "'{}' object cannot be interpreted as an integer", - bool_obj.class().name() - )) - })?; - - let len_val = int_obj.as_bigint(); - if len_val.sign() == Sign::Minus { - return Err(vm.new_value_error("__len__() should return >= 0")); - } - !len_val.is_zero() - } - None => true, - }, - } - }; - Ok(rs_bool) + + let slots = &self.class().slots; + + // 1. Try nb_bool slot first + if let Some(nb_bool) = slots.as_number.boolean.load() { + return nb_bool(self.as_object().number(), vm); + } + + // 2. Try mp_length slot (mapping protocol) + if let Some(mp_length) = slots.as_mapping.length.load() { + let len = mp_length(self.as_object().mapping_unchecked(), vm)?; + return Ok(len != 0); + } + + // 3. Try sq_length slot (sequence protocol) + if let Some(sq_length) = slots.as_sequence.length.load() { + let len = sq_length(self.as_object().sequence_unchecked(), vm)?; + return Ok(len != 0); + } + + // 4. Default: objects without __bool__ or __len__ are truthy + Ok(true) } } diff --git a/crates/vm/src/builtins/bytearray.rs b/crates/vm/src/builtins/bytearray.rs index 212e4604ec9..dc5ee100acf 100644 --- a/crates/vm/src/builtins/bytearray.rs +++ b/crates/vm/src/builtins/bytearray.rs @@ -206,7 +206,6 @@ impl PyByteArray { self.inner().capacity() } - #[pymethod] fn __len__(&self) -> usize { self.borrow_buf().len() } diff --git a/crates/vm/src/builtins/bytes.rs b/crates/vm/src/builtins/bytes.rs index b3feac8ac97..7e5b98b1ece 100644 --- a/crates/vm/src/builtins/bytes.rs +++ b/crates/vm/src/builtins/bytes.rs @@ -205,7 +205,6 @@ impl PyRef { )] impl PyBytes { #[inline] - #[pymethod] pub const fn __len__(&self) -> usize { self.inner.len() } diff --git a/crates/vm/src/builtins/dict.rs b/crates/vm/src/builtins/dict.rs index 358685fcdc4..693505b1f82 100644 --- a/crates/vm/src/builtins/dict.rs +++ b/crates/vm/src/builtins/dict.rs @@ -212,7 +212,6 @@ impl PyDict { } } - #[pymethod] pub fn __len__(&self) -> usize { self.entries.len() } @@ -764,7 +763,6 @@ trait DictView: PyPayload + PyClassDef + Iterable + Representable { fn dict(&self) -> &Py; fn item(vm: &VirtualMachine, key: PyObjectRef, value: PyObjectRef) -> PyObjectRef; - #[pymethod] fn __len__(&self) -> usize { self.dict().__len__() } diff --git a/crates/vm/src/builtins/list.rs b/crates/vm/src/builtins/list.rs index 46145b339cf..52df0756498 100644 --- a/crates/vm/src/builtins/list.rs +++ b/crates/vm/src/builtins/list.rs @@ -182,7 +182,6 @@ impl PyList { } #[allow(clippy::len_without_is_empty)] - #[pymethod] pub fn __len__(&self) -> usize { self.borrow_vec().len() } diff --git a/crates/vm/src/builtins/mappingproxy.rs b/crates/vm/src/builtins/mappingproxy.rs index 475f36cb5a9..34a598b03e0 100644 --- a/crates/vm/src/builtins/mappingproxy.rs +++ b/crates/vm/src/builtins/mappingproxy.rs @@ -173,7 +173,6 @@ impl PyMappingProxy { PyGenericAlias::from_args(cls, args, vm) } - #[pymethod] fn __len__(&self, vm: &VirtualMachine) -> PyResult { let obj = self.to_object(vm)?; obj.length(vm) @@ -235,6 +234,7 @@ impl AsMapping for PyMappingProxy { impl AsSequence for PyMappingProxy { fn as_sequence() -> &'static PySequenceMethods { static AS_SEQUENCE: LazyLock = LazyLock::new(|| PySequenceMethods { + length: atomic_func!(|seq, vm| PyMappingProxy::sequence_downcast(seq).__len__(vm)), contains: atomic_func!( |seq, target, vm| PyMappingProxy::sequence_downcast(seq)._contains(target, vm) ), diff --git a/crates/vm/src/builtins/memory.rs b/crates/vm/src/builtins/memory.rs index ff5df031c42..5ca4257cded 100644 --- a/crates/vm/src/builtins/memory.rs +++ b/crates/vm/src/builtins/memory.rs @@ -690,7 +690,6 @@ impl PyMemoryView { Err(vm.new_type_error("cannot delete memory")) } - #[pymethod] fn __len__(&self, vm: &VirtualMachine) -> PyResult { self.try_not_released(vm)?; if self.desc.ndim() == 0 { diff --git a/crates/vm/src/builtins/set.rs b/crates/vm/src/builtins/set.rs index b1236e44e93..1c0268e2ed7 100644 --- a/crates/vm/src/builtins/set.rs +++ b/crates/vm/src/builtins/set.rs @@ -533,11 +533,14 @@ fn reduce_set( flags(BASETYPE, _MATCH_SELF) )] impl PySet { - #[pymethod] fn __len__(&self) -> usize { self.inner.len() } + fn __contains__(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult { + self.inner.contains(needle, vm) + } + #[pymethod] fn __sizeof__(&self) -> usize { core::mem::size_of::() + self.inner.sizeof() @@ -550,11 +553,6 @@ impl PySet { } } - #[pymethod] - fn __contains__(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.contains(&needle, vm) - } - #[pymethod] fn union(&self, others: PosArgs, vm: &VirtualMachine) -> PyResult { self.fold_op(others.into_iter(), PySetInner::union, vm) @@ -594,8 +592,6 @@ impl PySet { self.inner.isdisjoint(other, vm) } - #[pymethod(name = "__ror__")] - #[pymethod] fn __or__(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult> { if let Ok(other) = AnySet::try_from_object(vm, other) { Ok(PyArithmeticValue::Implemented(self.op( @@ -608,8 +604,6 @@ impl PySet { } } - #[pymethod(name = "__rand__")] - #[pymethod] fn __and__( &self, other: PyObjectRef, @@ -626,7 +620,6 @@ impl PySet { } } - #[pymethod] fn __sub__( &self, other: PyObjectRef, @@ -643,7 +636,6 @@ impl PySet { } } - #[pymethod] fn __rsub__( zelf: PyRef, other: PyObjectRef, @@ -660,8 +652,6 @@ impl PySet { } } - #[pymethod(name = "__rxor__")] - #[pymethod] fn __xor__( &self, other: PyObjectRef, @@ -705,7 +695,6 @@ impl PySet { self.inner.pop(vm) } - #[pymethod] fn __ior__(zelf: PyRef, set: AnySet, vm: &VirtualMachine) -> PyResult> { zelf.inner.update(set.into_iterable_iter(vm)?, vm)?; Ok(zelf) @@ -729,7 +718,6 @@ impl PySet { Ok(()) } - #[pymethod] fn __iand__(zelf: PyRef, set: AnySet, vm: &VirtualMachine) -> PyResult> { zelf.inner .intersection_update(core::iter::once(set.into_iterable(vm)?), vm)?; @@ -742,7 +730,6 @@ impl PySet { Ok(()) } - #[pymethod] fn __isub__(zelf: PyRef, set: AnySet, vm: &VirtualMachine) -> PyResult> { zelf.inner .difference_update(set.into_iterable_iter(vm)?, vm)?; @@ -760,7 +747,6 @@ impl PySet { Ok(()) } - #[pymethod] fn __ixor__(zelf: PyRef, set: AnySet, vm: &VirtualMachine) -> PyResult> { zelf.inner .symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?; @@ -799,9 +785,9 @@ impl AsSequence for PySet { fn as_sequence() -> &'static PySequenceMethods { static AS_SEQUENCE: LazyLock = LazyLock::new(|| PySequenceMethods { length: atomic_func!(|seq, _vm| Ok(PySet::sequence_downcast(seq).__len__())), - contains: atomic_func!(|seq, needle, vm| PySet::sequence_downcast(seq) - .inner - .contains(needle, vm)), + contains: atomic_func!( + |seq, needle, vm| PySet::sequence_downcast(seq).__contains__(needle, vm) + ), ..PySequenceMethods::NOT_IMPLEMENTED }); &AS_SEQUENCE @@ -830,30 +816,77 @@ impl Iterable for PySet { impl AsNumber for PySet { fn as_number() -> &'static PyNumberMethods { static AS_NUMBER: PyNumberMethods = PyNumberMethods { + // Binary ops check both operands are sets (like CPython's set_sub, etc.) + // This is needed because __rsub__ swaps operands: a.__rsub__(b) calls subtract(b, a) subtract: Some(|a, b, vm| { + if !AnySet::check(a, vm) || !AnySet::check(b, vm) { + return Ok(vm.ctx.not_implemented()); + } if let Some(a) = a.downcast_ref::() { a.__sub__(b.to_owned(), vm).to_pyresult(vm) + } else if let Some(a) = a.downcast_ref::() { + // When called via __rsub__, a might be PyFrozenSet + a.__sub__(b.to_owned(), vm) + .map(|r| { + r.map(|s| PySet { + inner: s.inner.clone(), + }) + }) + .to_pyresult(vm) } else { Ok(vm.ctx.not_implemented()) } }), and: Some(|a, b, vm| { + if !AnySet::check(a, vm) || !AnySet::check(b, vm) { + return Ok(vm.ctx.not_implemented()); + } if let Some(a) = a.downcast_ref::() { a.__and__(b.to_owned(), vm).to_pyresult(vm) + } else if let Some(a) = a.downcast_ref::() { + a.__and__(b.to_owned(), vm) + .map(|r| { + r.map(|s| PySet { + inner: s.inner.clone(), + }) + }) + .to_pyresult(vm) } else { Ok(vm.ctx.not_implemented()) } }), xor: Some(|a, b, vm| { + if !AnySet::check(a, vm) || !AnySet::check(b, vm) { + return Ok(vm.ctx.not_implemented()); + } if let Some(a) = a.downcast_ref::() { a.__xor__(b.to_owned(), vm).to_pyresult(vm) + } else if let Some(a) = a.downcast_ref::() { + a.__xor__(b.to_owned(), vm) + .map(|r| { + r.map(|s| PySet { + inner: s.inner.clone(), + }) + }) + .to_pyresult(vm) } else { Ok(vm.ctx.not_implemented()) } }), or: Some(|a, b, vm| { + if !AnySet::check(a, vm) || !AnySet::check(b, vm) { + return Ok(vm.ctx.not_implemented()); + } if let Some(a) = a.downcast_ref::() { a.__or__(b.to_owned(), vm).to_pyresult(vm) + } else if let Some(a) = a.downcast_ref::() { + a.__or__(b.to_owned(), vm) + .map(|r| { + r.map(|s| PySet { + inner: s.inner.clone(), + }) + }) + .to_pyresult(vm) } else { Ok(vm.ctx.not_implemented()) } @@ -972,11 +1005,14 @@ impl Constructor for PyFrozenSet { ) )] impl PyFrozenSet { - #[pymethod] fn __len__(&self) -> usize { self.inner.len() } + fn __contains__(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult { + self.inner.contains(needle, vm) + } + #[pymethod] fn __sizeof__(&self) -> usize { core::mem::size_of::() + self.inner.sizeof() @@ -995,11 +1031,6 @@ impl PyFrozenSet { } } - #[pymethod] - fn __contains__(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.contains(&needle, vm) - } - #[pymethod] fn union(&self, others: PosArgs, vm: &VirtualMachine) -> PyResult { self.fold_op(others.into_iter(), PySetInner::union, vm) @@ -1039,8 +1070,6 @@ impl PyFrozenSet { self.inner.isdisjoint(other, vm) } - #[pymethod(name = "__ror__")] - #[pymethod] fn __or__(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult> { if let Ok(set) = AnySet::try_from_object(vm, other) { Ok(PyArithmeticValue::Implemented(self.op( @@ -1053,8 +1082,6 @@ impl PyFrozenSet { } } - #[pymethod(name = "__rand__")] - #[pymethod] fn __and__( &self, other: PyObjectRef, @@ -1071,7 +1098,6 @@ impl PyFrozenSet { } } - #[pymethod] fn __sub__( &self, other: PyObjectRef, @@ -1088,7 +1114,6 @@ impl PyFrozenSet { } } - #[pymethod] fn __rsub__( zelf: PyRef, other: PyObjectRef, @@ -1106,8 +1131,6 @@ impl PyFrozenSet { } } - #[pymethod(name = "__rxor__")] - #[pymethod] fn __xor__( &self, other: PyObjectRef, @@ -1142,9 +1165,9 @@ impl AsSequence for PyFrozenSet { fn as_sequence() -> &'static PySequenceMethods { static AS_SEQUENCE: LazyLock = LazyLock::new(|| PySequenceMethods { length: atomic_func!(|seq, _vm| Ok(PyFrozenSet::sequence_downcast(seq).__len__())), - contains: atomic_func!(|seq, needle, vm| PyFrozenSet::sequence_downcast(seq) - .inner - .contains(needle, vm)), + contains: atomic_func!( + |seq, needle, vm| PyFrozenSet::sequence_downcast(seq).__contains__(needle, vm) + ), ..PySequenceMethods::NOT_IMPLEMENTED }); &AS_SEQUENCE @@ -1196,30 +1219,53 @@ impl Iterable for PyFrozenSet { impl AsNumber for PyFrozenSet { fn as_number() -> &'static PyNumberMethods { static AS_NUMBER: PyNumberMethods = PyNumberMethods { + // Binary ops check both operands are sets (like CPython's set_sub, etc.) + // __rsub__ swaps operands. Result type follows first operand's type. subtract: Some(|a, b, vm| { + if !AnySet::check(a, vm) || !AnySet::check(b, vm) { + return Ok(vm.ctx.not_implemented()); + } if let Some(a) = a.downcast_ref::() { a.__sub__(b.to_owned(), vm).to_pyresult(vm) + } else if let Some(a) = a.downcast_ref::() { + // When called via __rsub__, a might be PySet - return set (not frozenset) + a.__sub__(b.to_owned(), vm).to_pyresult(vm) } else { Ok(vm.ctx.not_implemented()) } }), and: Some(|a, b, vm| { + if !AnySet::check(a, vm) || !AnySet::check(b, vm) { + return Ok(vm.ctx.not_implemented()); + } if let Some(a) = a.downcast_ref::() { a.__and__(b.to_owned(), vm).to_pyresult(vm) + } else if let Some(a) = a.downcast_ref::() { + a.__and__(b.to_owned(), vm).to_pyresult(vm) } else { Ok(vm.ctx.not_implemented()) } }), xor: Some(|a, b, vm| { + if !AnySet::check(a, vm) || !AnySet::check(b, vm) { + return Ok(vm.ctx.not_implemented()); + } if let Some(a) = a.downcast_ref::() { a.__xor__(b.to_owned(), vm).to_pyresult(vm) + } else if let Some(a) = a.downcast_ref::() { + a.__xor__(b.to_owned(), vm).to_pyresult(vm) } else { Ok(vm.ctx.not_implemented()) } }), or: Some(|a, b, vm| { + if !AnySet::check(a, vm) || !AnySet::check(b, vm) { + return Ok(vm.ctx.not_implemented()); + } if let Some(a) = a.downcast_ref::() { a.__or__(b.to_owned(), vm).to_pyresult(vm) + } else if let Some(a) = a.downcast_ref::() { + a.__or__(b.to_owned(), vm).to_pyresult(vm) } else { Ok(vm.ctx.not_implemented()) } @@ -1252,6 +1298,13 @@ struct AnySet { } impl AnySet { + /// Check if object is a set or frozenset (including subclasses) + /// Equivalent to CPython's PyAnySet_Check + fn check(obj: &PyObject, vm: &VirtualMachine) -> bool { + let ctx = &vm.ctx; + obj.fast_isinstance(ctx.types.set_type) || obj.fast_isinstance(ctx.types.frozenset_type) + } + fn into_iterable(self, vm: &VirtualMachine) -> PyResult { self.object.try_into_value(vm) } diff --git a/crates/vm/src/builtins/tuple.rs b/crates/vm/src/builtins/tuple.rs index e0b8009e892..910b0c8a204 100644 --- a/crates/vm/src/builtins/tuple.rs +++ b/crates/vm/src/builtins/tuple.rs @@ -298,7 +298,6 @@ impl PyTuple { } #[inline] - #[pymethod] pub const fn __len__(&self) -> usize { self.elements.len() } diff --git a/crates/vm/src/protocol/number.rs b/crates/vm/src/protocol/number.rs index 58891d1d710..542afce2c6c 100644 --- a/crates/vm/src/protocol/number.rs +++ b/crates/vm/src/protocol/number.rs @@ -356,21 +356,27 @@ impl PyNumberSlots { pub fn copy_from(&self, methods: &PyNumberMethods) { if let Some(f) = methods.add { self.add.store(Some(f)); + self.right_add.store(Some(f)); } if let Some(f) = methods.subtract { self.subtract.store(Some(f)); + self.right_subtract.store(Some(f)); } if let Some(f) = methods.multiply { self.multiply.store(Some(f)); + self.right_multiply.store(Some(f)); } if let Some(f) = methods.remainder { self.remainder.store(Some(f)); + self.right_remainder.store(Some(f)); } if let Some(f) = methods.divmod { self.divmod.store(Some(f)); + self.right_divmod.store(Some(f)); } if let Some(f) = methods.power { self.power.store(Some(f)); + self.right_power.store(Some(f)); } if let Some(f) = methods.negative { self.negative.store(Some(f)); @@ -389,18 +395,23 @@ impl PyNumberSlots { } if let Some(f) = methods.lshift { self.lshift.store(Some(f)); + self.right_lshift.store(Some(f)); } if let Some(f) = methods.rshift { self.rshift.store(Some(f)); + self.right_rshift.store(Some(f)); } if let Some(f) = methods.and { self.and.store(Some(f)); + self.right_and.store(Some(f)); } if let Some(f) = methods.xor { self.xor.store(Some(f)); + self.right_xor.store(Some(f)); } if let Some(f) = methods.or { self.or.store(Some(f)); + self.right_or.store(Some(f)); } if let Some(f) = methods.int { self.int.store(Some(f)); @@ -440,9 +451,11 @@ impl PyNumberSlots { } if let Some(f) = methods.floor_divide { self.floor_divide.store(Some(f)); + self.right_floor_divide.store(Some(f)); } if let Some(f) = methods.true_divide { self.true_divide.store(Some(f)); + self.right_true_divide.store(Some(f)); } if let Some(f) = methods.inplace_floor_divide { self.inplace_floor_divide.store(Some(f)); @@ -455,6 +468,7 @@ impl PyNumberSlots { } if let Some(f) = methods.matrix_multiply { self.matrix_multiply.store(Some(f)); + self.right_matrix_multiply.store(Some(f)); } if let Some(f) = methods.inplace_matrix_multiply { self.inplace_matrix_multiply.store(Some(f)); diff --git a/crates/vm/src/stdlib/collections.rs b/crates/vm/src/stdlib/collections.rs index 67e8e1f2734..9b7a78f7237 100644 --- a/crates/vm/src/stdlib/collections.rs +++ b/crates/vm/src/stdlib/collections.rs @@ -350,16 +350,10 @@ mod _collections { Ok(zelf) } - #[pymethod] fn __len__(&self) -> usize { self.borrow_deque().len() } - #[pymethod] - fn __add__(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.concat(&other, vm) - } - fn concat(&self, other: &PyObject, vm: &VirtualMachine) -> PyResult { if let Some(o) = other.downcast_ref::() { let mut deque = self.borrow_deque().clone(); diff --git a/crates/vm/src/stdlib/ctypes/array.rs b/crates/vm/src/stdlib/ctypes/array.rs index 5c298d26a56..63594879878 100644 --- a/crates/vm/src/stdlib/ctypes/array.rs +++ b/crates/vm/src/stdlib/ctypes/array.rs @@ -1087,7 +1087,6 @@ impl PyCArray { Ok(()) } - #[pymethod] fn __len__(zelf: &Py, _vm: &VirtualMachine) -> usize { zelf.class().stg_info_opt().map_or(0, |i| i.length) } diff --git a/crates/vm/src/stdlib/winreg.rs b/crates/vm/src/stdlib/winreg.rs index 400cab44210..28fa2e9b74c 100644 --- a/crates/vm/src/stdlib/winreg.rs +++ b/crates/vm/src/stdlib/winreg.rs @@ -222,7 +222,6 @@ mod winreg { zelf.Close(vm) } - #[pymethod] fn __int__(&self) -> usize { self.hkey.load() as usize } diff --git a/crates/vm/src/types/slot.rs b/crates/vm/src/types/slot.rs index d7ead562063..17b2b4e2f8a 100644 --- a/crates/vm/src/types/slot.rs +++ b/crates/vm/src/types/slot.rs @@ -607,7 +607,9 @@ impl PyType { debug_assert!(name.as_str().ends_with("__")); // Find all slot_defs matching this name and update each - for def in find_slot_defs_by_name(name.as_str()) { + // NOTE: Collect into Vec first to avoid issues during iteration + let defs: Vec<_> = find_slot_defs_by_name(name.as_str()).collect(); + for def in defs { self.update_one_slot::(&def.accessor, name, ctx); } @@ -676,7 +678,29 @@ impl PyType { macro_rules! update_sub_slot { ($group:ident, $slot:ident, $wrapper:expr, $variant:ident) => {{ if ADD { - if let Some(func) = self.lookup_slot_in_mro(name, ctx, |sf| { + // Check if this type defines any method that maps to this slot. + // Some slots like SqAssItem/MpAssSubscript are shared by multiple + // methods (__setitem__ and __delitem__). If any of those methods + // is defined, we must use the wrapper to ensure Python method calls. + let has_own = { + let guard = self.attributes.read(); + // Check the current method name + let mut result = guard.contains_key(name); + // For ass_item/ass_subscript slots, also check the paired method + // (__setitem__ and __delitem__ share the same slot) + if !result + && (stringify!($slot) == "ass_item" + || stringify!($slot) == "ass_subscript") + { + let setitem = ctx.intern_str("__setitem__"); + let delitem = ctx.intern_str("__delitem__"); + result = guard.contains_key(setitem) || guard.contains_key(delitem); + } + result + }; + if has_own { + self.slots.$group.$slot.store(Some($wrapper)); + } else if let Some(func) = self.lookup_slot_in_mro(name, ctx, |sf| { if let SlotFunc::$variant(f) = sf { Some(*f) } else {