diff --git a/crates/vm/src/types/structseq.rs b/crates/vm/src/types/structseq.rs index 27315749e06..adf5f5658b2 100644 --- a/crates/vm/src/types/structseq.rs +++ b/crates/vm/src/types/structseq.rs @@ -2,7 +2,7 @@ use crate::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func, builtins::{PyBaseExceptionRef, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef}, class::{PyClassImpl, StaticType}, - function::{Either, PyComparisonValue}, + function::{Either, FuncArgs, PyComparisonValue, PyMethodDef, PyMethodFlags}, iter::PyExactSizeIterator, protocol::{PyMappingMethods, PySequenceMethods}, sliceable::{SequenceIndex, SliceableSequenceOp}, @@ -11,6 +11,15 @@ use crate::{ }; use std::sync::LazyLock; +const DEFAULT_STRUCTSEQ_REDUCE: PyMethodDef = PyMethodDef::new_const( + "__reduce__", + |zelf: PyRef, vm: &VirtualMachine| -> PyTupleRef { + vm.new_tuple((zelf.class().to_owned(), (vm.ctx.new_tuple(zelf.to_vec()),))) + }, + PyMethodFlags::METHOD, + None, +); + /// Create a new struct sequence instance from a sequence. /// /// The class must have `n_sequence_fields` and `n_fields` attributes set @@ -206,19 +215,13 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static { }; let (body, suffix) = if let Some(_guard) = rustpython_vm::recursion::ReprGuard::enter(vm, zelf.as_ref()) { - if field_names.len() == 1 { - let value = zelf.first().unwrap(); - let formatted = format_field((value, field_names[0]))?; - (formatted, ",") - } else { - let fields: PyResult> = zelf - .iter() - .map(|value| value.as_ref()) - .zip(field_names.iter().copied()) - .map(format_field) - .collect(); - (fields?.join(", "), "") - } + let fields: PyResult> = zelf + .iter() + .map(|value| value.as_ref()) + .zip(field_names.iter().copied()) + .map(format_field) + .collect(); + (fields?.join(", "), "") } else { (String::new(), "...") }; @@ -232,8 +235,45 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static { } #[pymethod] - fn __reduce__(zelf: PyRef, vm: &VirtualMachine) -> PyTupleRef { - vm.new_tuple((zelf.class().to_owned(), (vm.ctx.new_tuple(zelf.to_vec()),))) + fn __replace__(zelf: PyRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + if !args.args.is_empty() { + return Err(vm.new_type_error("__replace__() takes no positional arguments".to_owned())); + } + + if Self::Data::UNNAMED_FIELDS_LEN > 0 { + return Err(vm.new_type_error(format!( + "__replace__() is not supported for {} because it has unnamed field(s)", + zelf.class().slot_name() + ))); + } + + let n_fields = + Self::Data::REQUIRED_FIELD_NAMES.len() + Self::Data::OPTIONAL_FIELD_NAMES.len(); + let mut items: Vec = zelf.as_slice()[..n_fields].to_vec(); + + let mut kwargs = args.kwargs.clone(); + + // Replace fields from kwargs + let all_field_names: Vec<&str> = Self::Data::REQUIRED_FIELD_NAMES + .iter() + .chain(Self::Data::OPTIONAL_FIELD_NAMES.iter()) + .copied() + .collect(); + for (i, &name) in all_field_names.iter().enumerate() { + if let Some(val) = kwargs.shift_remove(name) { + items[i] = val; + } + } + + // Check for unexpected keyword arguments + if !kwargs.is_empty() { + let names: Vec<&str> = kwargs.keys().map(|k| k.as_str()).collect(); + return Err(vm.new_type_error(format!("Got unexpected field name(s): {:?}", names))); + } + + PyTuple::new_unchecked(items.into_boxed_slice()) + .into_ref_with_type(vm, zelf.class().to_owned()) + .map(Into::into) } #[pymethod] @@ -327,6 +367,20 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static { .slots .richcompare .store(Some(struct_sequence_richcompare)); + + // Default __reduce__: only set if not already overridden by the impl's extend_class. + // This allows struct sequences like sched_param to provide a custom __reduce__ + // (equivalent to METH_COEXIST in structseq.c). + if !class + .attributes + .read() + .contains_key(ctx.intern_str("__reduce__")) + { + class.set_attr( + ctx.intern_str("__reduce__"), + DEFAULT_STRUCTSEQ_REDUCE.to_proper_method(class, ctx), + ); + } } }