diff --git a/crates/vm/src/builtins/function.rs b/crates/vm/src/builtins/function.rs index 489482d1933..140cb6701b6 100644 --- a/crates/vm/src/builtins/function.rs +++ b/crates/vm/src/builtins/function.rs @@ -9,7 +9,7 @@ use super::{ use crate::common::lock::OnceCell; use crate::common::lock::PyMutex; use crate::function::ArgMapping; -use crate::object::{Traverse, TraverseFn}; +use crate::object::{PyAtomicRef, Traverse, TraverseFn}; use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, bytecode, @@ -61,7 +61,7 @@ fn format_missing_args( #[pyclass(module = false, name = "function", traverse = "manual")] #[derive(Debug)] pub struct PyFunction { - code: PyMutex>, + code: PyAtomicRef, globals: PyDictRef, builtins: PyObjectRef, closure: Option>>, @@ -192,7 +192,7 @@ impl PyFunction { let qualname = vm.ctx.new_str(code.qualname.as_str()); let func = Self { - code: PyMutex::new(code.clone()), + code: PyAtomicRef::from(code.clone()), globals, builtins, closure: None, @@ -217,7 +217,7 @@ impl PyFunction { func_args: FuncArgs, vm: &VirtualMachine, ) -> PyResult<()> { - let code = &*self.code.lock(); + let code: &Py = &self.code; let nargs = func_args.args.len(); let n_expected_args = code.arg_count as usize; let total_args = code.arg_count as usize + code.kwonlyarg_count as usize; @@ -539,13 +539,12 @@ impl Py { Err(err) => info!( "jit: function `{}` is falling back to being interpreted because of the \ error: {}", - self.code.lock().obj_name, - err + self.code.obj_name, err ), } } - let code = self.code.lock().clone(); + let code: PyRef = (*self.code).to_owned(); let locals = if code.flags.contains(bytecode::CodeFlags::NEWLOCALS) { ArgMapping::from_dict_exact(vm.ctx.new_dict()) @@ -609,7 +608,7 @@ impl Py { /// Returns true if: no VARARGS, no VARKEYWORDS, no kwonly args, not generator/coroutine, /// and effective_nargs matches co_argcount. pub(crate) fn can_specialize_call(&self, effective_nargs: u32) -> bool { - let code = self.code.lock(); + let code: &Py = &self.code; let flags = code.flags; flags.contains(bytecode::CodeFlags::NEWLOCALS) && !flags.intersects( @@ -627,7 +626,7 @@ impl Py { /// Only valid when: no VARARGS, no VARKEYWORDS, no kwonlyargs, not generator/coroutine, /// and nargs == co_argcount. pub fn invoke_exact_args(&self, args: &[PyObjectRef], vm: &VirtualMachine) -> PyResult { - let code = self.code.lock().clone(); + let code: PyRef = (*self.code).to_owned(); let locals = ArgMapping::from_dict_exact(vm.ctx.new_dict()); @@ -676,12 +675,12 @@ impl PyPayload for PyFunction { impl PyFunction { #[pygetset] fn __code__(&self) -> PyRef { - self.code.lock().clone() + (*self.code).to_owned() } #[pygetset(setter)] - fn set___code__(&self, code: PyRef) { - *self.code.lock() = code; + fn set___code__(&self, code: PyRef, vm: &VirtualMachine) { + self.code.swap_to_temporary_refs(code, vm); self.func_version.store(0, Relaxed); } @@ -923,7 +922,7 @@ impl PyFunction { } let arg_types = jit::get_jit_arg_types(&zelf, vm)?; let ret_type = jit::jit_ret_type(&zelf, vm)?; - let code = zelf.code.lock(); + let code: &Py = &zelf.code; let compiled = rustpython_jit::compile(&code.code, &arg_types, ret_type) .map_err(|err| jit::new_jit_error(err.to_string(), vm))?; let _ = zelf.jitted_code.set(compiled); diff --git a/crates/vm/src/builtins/function/jit.rs b/crates/vm/src/builtins/function/jit.rs index 9d3803759cf..56594a0f462 100644 --- a/crates/vm/src/builtins/function/jit.rs +++ b/crates/vm/src/builtins/function/jit.rs @@ -1,7 +1,7 @@ use crate::{ AsObject, Py, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::{ - PyBaseExceptionRef, PyDict, PyDictRef, PyFunction, PyStrInterned, bool_, float, int, + PyBaseExceptionRef, PyCode, PyDict, PyDictRef, PyFunction, PyStrInterned, bool_, float, int, }, bytecode::CodeFlags, convert::ToPyObject, @@ -67,7 +67,7 @@ fn get_jit_arg_type(dict: &Py, name: &str, vm: &VirtualMachine) -> PyRes } pub fn get_jit_arg_types(func: &Py, vm: &VirtualMachine) -> PyResult> { - let code = func.code.lock(); + let code: &Py = &func.code; let arg_names = code.arg_names(); if code @@ -160,7 +160,7 @@ pub(crate) fn get_jit_args<'a>( let mut jit_args = jitted_code.args_builder(); let nargs = func_args.args.len(); - let code = func.code.lock(); + let code: &Py = &func.code; let arg_names = code.arg_names(); let arg_count = code.arg_count; let posonlyarg_count = code.posonlyarg_count; @@ -220,7 +220,5 @@ pub(crate) fn get_jit_args<'a>( } } - drop(code); - jit_args.into_args().ok_or(ArgsError::NotAllArgsPassed) }