diff --git a/crates/derive-impl/src/pyclass.rs b/crates/derive-impl/src/pyclass.rs index e54c1a867fd..31a7fcc7569 100644 --- a/crates/derive-impl/src/pyclass.rs +++ b/crates/derive-impl/src/pyclass.rs @@ -667,7 +667,7 @@ pub(crate) fn impl_pyclass(attr: PunctuatedNestedMeta, item: Item) -> Result::PAYLOAD_TYPE_ID; #[inline] - fn validate_downcastable_from(obj: &::rustpython_vm::PyObject) -> bool { + unsafe fn validate_downcastable_from(obj: &::rustpython_vm::PyObject) -> bool { ::BASICSIZE <= obj.class().slots.basicsize && obj.class().fast_issubclass(::static_type()) } diff --git a/crates/derive-impl/src/pystructseq.rs b/crates/derive-impl/src/pystructseq.rs index 75617d300be..ccff85fae79 100644 --- a/crates/derive-impl/src/pystructseq.rs +++ b/crates/derive-impl/src/pystructseq.rs @@ -595,7 +595,7 @@ pub(crate) fn impl_pystruct_sequence( const PAYLOAD_TYPE_ID: ::core::any::TypeId = <::rustpython_vm::builtins::PyTuple as ::rustpython_vm::PyPayload>::PAYLOAD_TYPE_ID; #[inline] - fn validate_downcastable_from(obj: &::rustpython_vm::PyObject) -> bool { + unsafe fn validate_downcastable_from(obj: &::rustpython_vm::PyObject) -> bool { obj.class().fast_issubclass(::static_type()) } diff --git a/crates/vm/src/builtins/str.rs b/crates/vm/src/builtins/str.rs index d765847c1ab..e8ba6eb915e 100644 --- a/crates/vm/src/builtins/str.rs +++ b/crates/vm/src/builtins/str.rs @@ -1944,7 +1944,7 @@ impl PyPayload for PyUtf8Str { const PAYLOAD_TYPE_ID: core::any::TypeId = core::any::TypeId::of::(); - fn validate_downcastable_from(obj: &PyObject) -> bool { + unsafe fn validate_downcastable_from(obj: &PyObject) -> bool { // SAFETY: we know the object is a PyStr in this context let wtf8 = unsafe { obj.downcast_unchecked_ref::() }; wtf8.is_utf8() diff --git a/crates/vm/src/object/core.rs b/crates/vm/src/object/core.rs index c949cae9053..43b2f7dc61a 100644 --- a/crates/vm/src/object/core.rs +++ b/crates/vm/src/object/core.rs @@ -753,7 +753,7 @@ impl PyObject { /// Check if this object can be downcast to T. #[inline(always)] pub fn downcastable(&self) -> bool { - T::downcastable_from(self) + self.typeid() == T::PAYLOAD_TYPE_ID && unsafe { T::validate_downcastable_from(self) } } /// Attempt to downcast this reference to a subclass. diff --git a/crates/vm/src/object/payload.rs b/crates/vm/src/object/payload.rs index 143033ee642..98c61817568 100644 --- a/crates/vm/src/object/payload.rs +++ b/crates/vm/src/object/payload.rs @@ -28,19 +28,15 @@ pub(crate) fn cold_downcast_type_error( pub trait PyPayload: MaybeTraverse + PyThreadingConstraint + Sized + 'static { const PAYLOAD_TYPE_ID: core::any::TypeId = core::any::TypeId::of::(); - /// # Safety: this function should only be called if `payload_type_id` matches the type of `obj`. + /// # Safety + /// This function should only be called if `payload_type_id` matches the type of `obj`. #[inline] - fn downcastable_from(obj: &PyObject) -> bool { - obj.typeid() == Self::PAYLOAD_TYPE_ID && Self::validate_downcastable_from(obj) - } - - #[inline] - fn validate_downcastable_from(_obj: &PyObject) -> bool { + unsafe fn validate_downcastable_from(_obj: &PyObject) -> bool { true } fn try_downcast_from(obj: &PyObject, vm: &VirtualMachine) -> PyResult<()> { - if Self::downcastable_from(obj) { + if obj.downcastable::() { return Ok(()); }