diff --git a/crates/vm/src/builtins/object.rs b/crates/vm/src/builtins/object.rs index 77a1eff6c7d..a61fb1e2971 100644 --- a/crates/vm/src/builtins/object.rs +++ b/crates/vm/src/builtins/object.rs @@ -497,16 +497,45 @@ impl PyBaseObject { ) -> PyResult<()> { match value.downcast::() { Ok(cls) => { - let both_module = instance.class().fast_issubclass(vm.ctx.types.module_type) + let current_cls = instance.class(); + let both_module = current_cls.fast_issubclass(vm.ctx.types.module_type) && cls.fast_issubclass(vm.ctx.types.module_type); - let both_mutable = !instance - .class() + let both_mutable = !current_cls .slots .flags .has_feature(PyTypeFlags::IMMUTABLETYPE) && !cls.slots.flags.has_feature(PyTypeFlags::IMMUTABLETYPE); // FIXME(#1979) cls instances might have a payload if both_mutable || both_module { + let has_dict = + |typ: &Py| typ.slots.flags.has_feature(PyTypeFlags::HAS_DICT); + // Compare slots tuples + let slots_equal = match ( + current_cls + .heaptype_ext + .as_ref() + .and_then(|e| e.slots.as_ref()), + cls.heaptype_ext.as_ref().and_then(|e| e.slots.as_ref()), + ) { + (Some(a), Some(b)) => { + a.len() == b.len() + && a.iter() + .zip(b.iter()) + .all(|(x, y)| x.as_str() == y.as_str()) + } + (None, None) => true, + _ => false, + }; + if current_cls.slots.basicsize != cls.slots.basicsize + || !slots_equal + || has_dict(current_cls) != has_dict(&cls) + { + return Err(vm.new_type_error(format!( + "__class__ assignment: '{}' object layout differs from '{}'", + cls.name(), + current_cls.name() + ))); + } instance.set_class(cls, vm); Ok(()) } else { diff --git a/extra_tests/snippets/builtin_type.py b/extra_tests/snippets/builtin_type.py index 7a8e4840e13..67269e694c0 100644 --- a/extra_tests/snippets/builtin_type.py +++ b/extra_tests/snippets/builtin_type.py @@ -240,6 +240,56 @@ class C(B, BB): assert C.mro() == [C, B, A, BB, AA, object] +class TypeA: + def __init__(self): + self.a = 1 + + +class TypeB: + __slots__ = "b" + + def __init__(self): + self.b = 2 + + +obj = TypeA() +with assert_raises(TypeError) as cm: + obj.__class__ = TypeB +assert "__class__ assignment: 'TypeB' object layout differs from 'TypeA'" in str( + cm.exception +) + + +# Test: same slot count but different slot names should fail +class SlotX: + __slots__ = ("x",) + + +class SlotY: + __slots__ = ("y",) + + +slot_obj = SlotX() +with assert_raises(TypeError) as cm: + slot_obj.__class__ = SlotY +assert "__class__ assignment: 'SlotY' object layout differs from 'SlotX'" in str( + cm.exception +) + + +# Test: same slots should succeed +class SlotA: + __slots__ = ("a",) + + +class SlotA2: + __slots__ = ("a",) + + +slot_a = SlotA() +slot_a.__class__ = SlotA2 # Should work + + assert type(Exception.args).__name__ == "getset_descriptor" assert type(None).__bool__(None) is False