diff --git a/crates/vm/src/builtins/set.rs b/crates/vm/src/builtins/set.rs index 1c0268e2ed7..c3473809734 100644 --- a/crates/vm/src/builtins/set.rs +++ b/crates/vm/src/builtins/set.rs @@ -24,6 +24,7 @@ use crate::{ vm::VirtualMachine, }; use alloc::fmt; +use core::borrow::Borrow; use core::ops::Deref; use rustpython_common::{ atomic::{Ordering, PyAtomic, Radium}, @@ -719,8 +720,10 @@ impl PySet { } fn __iand__(zelf: PyRef, set: AnySet, vm: &VirtualMachine) -> PyResult> { - zelf.inner - .intersection_update(core::iter::once(set.into_iterable(vm)?), vm)?; + if !set.is(zelf.as_object()) { + zelf.inner + .intersection_update(core::iter::once(set.into_iterable(vm)?), vm)?; + } Ok(zelf) } @@ -731,8 +734,12 @@ impl PySet { } fn __isub__(zelf: PyRef, set: AnySet, vm: &VirtualMachine) -> PyResult> { - zelf.inner - .difference_update(set.into_iterable_iter(vm)?, vm)?; + if set.is(zelf.as_object()) { + zelf.inner.clear(); + } else { + zelf.inner + .difference_update(set.into_iterable_iter(vm)?, vm)?; + } Ok(zelf) } @@ -748,8 +755,12 @@ impl PySet { } fn __ixor__(zelf: PyRef, set: AnySet, vm: &VirtualMachine) -> PyResult> { - zelf.inner - .symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?; + if set.is(zelf.as_object()) { + zelf.inner.clear(); + } else { + zelf.inner + .symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?; + } Ok(zelf) } @@ -1297,6 +1308,13 @@ struct AnySet { object: PyObjectRef, } +impl Borrow for AnySet { + #[inline(always)] + fn borrow(&self) -> &PyObject { + &self.object + } +} + impl AnySet { /// Check if object is a set or frozenset (including subclasses) /// Equivalent to CPython's PyAnySet_Check diff --git a/extra_tests/snippets/builtin_set.py b/extra_tests/snippets/builtin_set.py index 1b2f6ff0968..950875ea09a 100644 --- a/extra_tests/snippets/builtin_set.py +++ b/extra_tests/snippets/builtin_set.py @@ -200,6 +200,18 @@ class S(set): with assert_raises(TypeError): a &= [1, 2, 3] +a = set([1, 2, 3]) +a &= a +assert a == set([1, 2, 3]) + +a = set([1, 2, 3]) +a -= a +assert a == set() + +a = set([1, 2, 3]) +a ^= a +assert a == set() + a = set([1, 2, 3]) a.difference_update([3, 4, 5]) assert a == set([1, 2])