Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Lib/test/test_descr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4318,7 +4318,6 @@ class C:
C.__name__ = Nasty("abc")
C.__name__ = "normal"

@unittest.expectedFailure # TODO: RUSTPYTHON
def test_subclass_right_op(self):
# Testing correct dispatch of subclass overloading __r<op>__...

Expand Down
51 changes: 51 additions & 0 deletions crates/vm/src/protocol/number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,63 @@ pub enum PyNumberBinaryOp {
InplaceMatrixMultiply,
}

impl PyNumberBinaryOp {
/// Returns `None` for in-place ops which don't have right-side variants.
pub fn right_method_name(
self,
vm: &VirtualMachine,
) -> Option<&'static crate::builtins::PyStrInterned> {
use PyNumberBinaryOp::*;
Some(match self {
Add => identifier!(vm, __radd__),
Subtract => identifier!(vm, __rsub__),
Multiply => identifier!(vm, __rmul__),
Remainder => identifier!(vm, __rmod__),
Divmod => identifier!(vm, __rdivmod__),
Lshift => identifier!(vm, __rlshift__),
Rshift => identifier!(vm, __rrshift__),
And => identifier!(vm, __rand__),
Xor => identifier!(vm, __rxor__),
Or => identifier!(vm, __ror__),
FloorDivide => identifier!(vm, __rfloordiv__),
TrueDivide => identifier!(vm, __rtruediv__),
MatrixMultiply => identifier!(vm, __rmatmul__),
// In-place ops don't have right-side variants
InplaceAdd
| InplaceSubtract
| InplaceMultiply
| InplaceRemainder
| InplaceLshift
| InplaceRshift
| InplaceAnd
| InplaceXor
| InplaceOr
| InplaceFloorDivide
| InplaceTrueDivide
| InplaceMatrixMultiply => return None,
})
}
}

#[derive(Copy, Clone)]
pub enum PyNumberTernaryOp {
Power,
InplacePower,
}

impl PyNumberTernaryOp {
/// Returns `None` for in-place ops which don't have right-side variants.
pub fn right_method_name(
self,
vm: &VirtualMachine,
) -> Option<&'static crate::builtins::PyStrInterned> {
Some(match self {
PyNumberTernaryOp::Power => identifier!(vm, __rpow__),
PyNumberTernaryOp::InplacePower => return None,
})
}
}

#[derive(Default)]
pub struct PyNumberSlots {
pub add: AtomicCell<Option<PyNumberBinaryFunc>>,
Expand Down
58 changes: 54 additions & 4 deletions crates/vm/src/vm/vm_ops.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,32 @@
use super::VirtualMachine;
use crate::stdlib::_warnings;
use crate::{
PyRef,
builtins::{PyInt, PyStr, PyStrRef, PyUtf8Str},
Py, PyRef,
builtins::{PyInt, PyStr, PyStrInterned, PyStrRef, PyType, PyUtf8Str},
object::{AsObject, PyObject, PyObjectRef, PyResult},
protocol::{PyNumberBinaryOp, PyNumberTernaryOp},
types::PyComparisonOp,
};
use num_traits::ToPrimitive;

/// Similar to `method_is_overloaded` in CPython typeobject.c
Comment thread
rlaisqls marked this conversation as resolved.
Outdated
fn method_is_overloaded(
class_a: &Py<PyType>,
class_b: &Py<PyType>,
rop_name: Option<&'static PyStrInterned>,
vm: &VirtualMachine,
) -> PyResult<bool> {
let Some(rop_name) = rop_name else {
return Ok(false);
};
let Some(method_b) = class_b.get_attr(rop_name) else {
return Ok(false);
};
class_a.get_attr(rop_name).map_or(Ok(true), |method_a| {
vm.identical_or_equal(&method_a, &method_b).map(|eq| !eq)
})
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

macro_rules! binary_func {
($fn:ident, $op_slot:ident, $op:expr) => {
pub fn $fn(&self, a: &PyObject, b: &PyObject) -> PyResult {
Expand Down Expand Up @@ -162,18 +180,34 @@ impl VirtualMachine {

// Number slots are inherited, direct access is O(1)
let slot_a = class_a.slots.as_number.left_binary_op(op_slot);
let slot_a_addr = slot_a.map(|x| x as usize);
let mut slot_b = None;
let left_b_addr;

if !class_a.is(class_b) {
let slot_bb = class_b.slots.as_number.right_binary_op(op_slot);
if slot_bb.map(|x| x as usize) != slot_a.map(|x| x as usize) {
if slot_bb.map(|x| x as usize) != slot_a_addr {
slot_b = slot_bb;
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
left_b_addr = class_b
.slots
.as_number
.left_binary_op(op_slot)
.map(|x| x as usize);
} else {
left_b_addr = slot_a_addr;
}

if let Some(slot_a) = slot_a {
if let Some(slot_bb) = slot_b
&& class_b.fast_issubclass(class_a)
&& (slot_a_addr != left_b_addr
|| method_is_overloaded(
class_a,
class_b,
op_slot.right_method_name(self),
self,
)?)
{
let ret = slot_bb(a, b, self)?;
if !ret.is(&self.ctx.not_implemented) {
Expand Down Expand Up @@ -269,18 +303,34 @@ impl VirtualMachine {

// Number slots are inherited, direct access is O(1)
let slot_a = class_a.slots.as_number.left_ternary_op(op_slot);
let slot_a_addr = slot_a.map(|x| x as usize);
let mut slot_b = None;
let left_b_addr;

if !class_a.is(class_b) {
let slot_bb = class_b.slots.as_number.right_ternary_op(op_slot);
if slot_bb.map(|x| x as usize) != slot_a.map(|x| x as usize) {
if slot_bb.map(|x| x as usize) != slot_a_addr {
slot_b = slot_bb;
}
left_b_addr = class_b
.slots
.as_number
.left_ternary_op(op_slot)
.map(|x| x as usize);
} else {
left_b_addr = slot_a_addr;
}

if let Some(slot_a) = slot_a {
if let Some(slot_bb) = slot_b
&& class_b.fast_issubclass(class_a)
&& (slot_a_addr != left_b_addr
|| method_is_overloaded(
class_a,
class_b,
op_slot.right_method_name(self),
self,
)?)
{
let ret = slot_bb(a, b, c, self)?;
if !ret.is(&self.ctx.not_implemented) {
Expand Down
Loading