diff --git a/.cspell.dict/python-more.txt b/.cspell.dict/python-more.txt index c4a419c5ffe..73b2d620ca6 100644 --- a/.cspell.dict/python-more.txt +++ b/.cspell.dict/python-more.txt @@ -67,6 +67,7 @@ fnctl frombytes fromhex fromunicode +frozensets fset fspath fstring diff --git a/Lib/copyreg.py b/Lib/copyreg.py index 578392409b4..c9da81a6882 100644 --- a/Lib/copyreg.py +++ b/Lib/copyreg.py @@ -31,8 +31,8 @@ def pickle_complex(c): pickle(complex, pickle_complex, complex) def pickle_union(obj): - import functools, operator - return functools.reduce, (operator.or_, obj.__args__) + import typing, operator + return operator.getitem, (typing.Union, obj.__args__) pickle(type(int | str), pickle_union) diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 3f1e2331bc2..4db73a64e73 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -2311,6 +2311,7 @@ class C: self.assertDocStrEqual(C.__doc__, "C(x:int=3)") + @unittest.expectedFailure # TODO: RUSTPYTHON def test_docstring_one_field_with_default_none(self): @dataclass class C: diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 047916caf07..bf3e24481ca 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -2922,6 +2922,7 @@ def decorated_classmethod(cls, arg: int) -> str: 'decorated_classmethod' ) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_invalid_registrations(self): msg_prefix = "Invalid first argument to `register()`: " msg_suffix = ( diff --git a/Lib/test/test_inspect/test_inspect.py b/Lib/test/test_inspect/test_inspect.py index 7d037d0554e..655587e7bb4 100644 --- a/Lib/test/test_inspect/test_inspect.py +++ b/Lib/test/test_inspect/test_inspect.py @@ -2028,6 +2028,7 @@ def test_pep_695_generics_with_future_annotations_nested_in_function(self): class TestFormatAnnotation(unittest.TestCase): + @unittest.expectedFailure # TODO: RUSTPYTHON def test_typing_replacement(self): from test.typinganndata.ann_module9 import A, ann, ann1 self.assertEqual(inspect.formatannotation(ann), 'Union[List[str], int]') @@ -2040,6 +2041,7 @@ def test_typing_replacement(self): 'Union[List[testModule.typing.A], int]', ) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_formatannotationrelativeto(self): from test.typinganndata.ann_module9 import A, ann1 diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index b57fdf35fb6..073599e1dd2 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -1,19 +1,35 @@ # Python test set -- part 6, built-in types -from test.support import run_with_locale, cpython_only +from test.support import ( + run_with_locale, cpython_only, no_rerun, + MISSING_C_DOCSTRINGS, EqualToForwardRef, check_disallow_instantiation, +) +from test.support.script_helper import assert_python_ok +from test.support.import_helper import import_fresh_module + import collections.abc -from collections import namedtuple +from collections import namedtuple, UserDict import copy +# XXX: RUSTPYTHON +try: + import _datetime +except ImportError: + _datetime = None import gc import inspect import pickle import locale import sys +import textwrap import types import unittest.mock import weakref import typing +import unittest # XXX: RUSTPYTHON; importing to be able to skip tests + +c_types = import_fresh_module('types', fresh=['_types']) +py_types = import_fresh_module('types', blocked=['_types']) T = typing.TypeVar("T") @@ -29,6 +45,29 @@ def clear_typing_caches(): class TypesTests(unittest.TestCase): + @unittest.skipUnless(c_types, "TODO: RUSTPYTHON; requires _types module") + def test_names(self): + c_only_names = {'CapsuleType'} + ignored = {'new_class', 'resolve_bases', 'prepare_class', + 'get_original_bases', 'DynamicClassAttribute', 'coroutine'} + + for name in c_types.__all__: + if name not in c_only_names | ignored: + self.assertIs(getattr(c_types, name), getattr(py_types, name)) + + all_names = ignored | { + 'AsyncGeneratorType', 'BuiltinFunctionType', 'BuiltinMethodType', + 'CapsuleType', 'CellType', 'ClassMethodDescriptorType', 'CodeType', + 'CoroutineType', 'EllipsisType', 'FrameType', 'FunctionType', + 'GeneratorType', 'GenericAlias', 'GetSetDescriptorType', + 'LambdaType', 'MappingProxyType', 'MemberDescriptorType', + 'MethodDescriptorType', 'MethodType', 'MethodWrapperType', + 'ModuleType', 'NoneType', 'NotImplementedType', 'SimpleNamespace', + 'TracebackType', 'UnionType', 'WrapperDescriptorType', + } + self.assertEqual(all_names, set(c_types.__all__)) + self.assertEqual(all_names - c_only_names, set(py_types.__all__)) + def test_truth_values(self): if None: self.fail('None is true instead of false') if 0: self.fail('0 is true instead of false') @@ -226,8 +265,8 @@ def test_type_function(self): def test_int__format__(self): def test(i, format_spec, result): # just make sure we have the unified type for integers - assert type(i) == int - assert type(format_spec) == str + self.assertIs(type(i), int) + self.assertIs(type(format_spec), str) self.assertEqual(i.__format__(format_spec), result) test(123456789, 'd', '123456789') @@ -392,8 +431,8 @@ def test(i, format_spec, result): test(123456, "1=20", '11111111111111123456') test(123456, "*=20", '**************123456') - @unittest.expectedFailure - @run_with_locale('LC_NUMERIC', 'en_US.UTF8') + @unittest.expectedFailure # TODO: RUSTPYTHON + @run_with_locale('LC_NUMERIC', 'en_US.UTF8', '') def test_float__format__locale(self): # test locale support for __format__ code 'n' @@ -402,7 +441,8 @@ def test_float__format__locale(self): self.assertEqual(locale.format_string('%g', x, grouping=True), format(x, 'n')) self.assertEqual(locale.format_string('%.10g', x, grouping=True), format(x, '.10n')) - @run_with_locale('LC_NUMERIC', 'en_US.UTF8') + @unittest.expectedFailure # TODO: RUSTPYTHON + @run_with_locale('LC_NUMERIC', 'en_US.UTF8', '') def test_int__format__locale(self): # test locale support for __format__ code 'n' for integers @@ -420,9 +460,6 @@ def test_int__format__locale(self): self.assertEqual(len(format(0, rfmt)), len(format(x, rfmt))) self.assertEqual(len(format(0, lfmt)), len(format(x, lfmt))) self.assertEqual(len(format(0, cfmt)), len(format(x, cfmt))) - - if sys.platform != "darwin": - test_int__format__locale = unittest.expectedFailure(test_int__format__locale) def test_float__format__(self): def test(f, format_spec, result): @@ -489,8 +526,8 @@ def test(f, format_spec, result): # and a number after the decimal. This is tricky, because # a totally empty format specifier means something else. # So, just use a sign flag - test(1e200, '+g', '+1e+200') - test(1e200, '+', '+1e+200') + test(1.25e200, '+g', '+1.25e+200') + test(1.25e200, '+', '+1.25e+200') test(1.1e200, '+g', '+1.1e+200') test(1.1e200, '+', '+1.1e+200') @@ -602,8 +639,9 @@ def test_slot_wrapper_types(self): self.assertIsInstance(object.__lt__, types.WrapperDescriptorType) self.assertIsInstance(int.__lt__, types.WrapperDescriptorType) - # TODO: RUSTPYTHON No signature found in builtin method __get__ of 'method_descriptor' objects. - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; No signature found in builtin method __get__ of 'method_descriptor' objects. + @unittest.skipIf(MISSING_C_DOCSTRINGS, + "Signature information for builtins requires docstrings") def test_dunder_get_signature(self): sig = inspect.signature(object.__init__.__get__) self.assertEqual(list(sig.parameters), ["instance", "owner"]) @@ -627,6 +665,26 @@ def test_method_descriptor_types(self): self.assertIsInstance(int.from_bytes, types.BuiltinMethodType) self.assertIsInstance(int.__new__, types.BuiltinMethodType) + @unittest.expectedFailure # TODO: RUSTPYTHON; ModuleNotFoundError: No module named '_queue' + def test_method_descriptor_crash(self): + # gh-132747: The default __get__() implementation in C was unable + # to handle a second argument of None when called from Python + import _io + import io + import _queue + + to_check = [ + # (method, instance) + (_io._TextIOBase.read, io.StringIO()), + (_queue.SimpleQueue.put, _queue.SimpleQueue()), + (str.capitalize, "nobody expects the spanish inquisition") + ] + + for method, instance in to_check: + with self.subTest(method=method, instance=instance): + bound = method.__get__(instance) + self.assertIsInstance(bound, types.BuiltinMethodType) + def test_ellipsis_type(self): self.assertIsInstance(Ellipsis, types.EllipsisType) @@ -644,6 +702,29 @@ def test_traceback_and_frame_types(self): self.assertIsInstance(exc.__traceback__, types.TracebackType) self.assertIsInstance(exc.__traceback__.tb_frame, types.FrameType) + # XXX: RUSTPYTHON + @unittest.skipUnless(_datetime, "requires _datetime module") + def test_capsule_type(self): + self.assertIsInstance(_datetime.datetime_CAPI, types.CapsuleType) + + def test_call_unbound_crash(self): + # GH-131998: The specialized instruction would get tricked into dereferencing + # a bound "self" that didn't exist if subsequently called unbound. + code = """if True: + + def call(part): + [] + ([] + []) + part.pop() + + for _ in range(3): + call(['a']) + try: + call(list) + except TypeError: + pass + """ + assert_python_ok("-c", code) + class UnionTests(unittest.TestCase): @@ -706,15 +787,54 @@ def test_or_types_operator(self): y = int | bool with self.assertRaises(TypeError): x < y - # Check that we don't crash if typing.Union does not have a tuple in __args__ - y = typing.Union[str, int] - y.__args__ = [str, int] - self.assertEqual(x, y) def test_hash(self): self.assertEqual(hash(int | str), hash(str | int)) self.assertEqual(hash(int | str), hash(typing.Union[int, str])) + def test_union_of_unhashable(self): + class UnhashableMeta(type): + __hash__ = None + + class A(metaclass=UnhashableMeta): ... + class B(metaclass=UnhashableMeta): ... + + self.assertEqual((A | B).__args__, (A, B)) + union1 = A | B + with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"): + hash(union1) + + union2 = int | B + with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"): + hash(union2) + + union3 = A | int + with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"): + hash(union3) + + def test_unhashable_becomes_hashable(self): + is_hashable = False + class UnhashableMeta(type): + def __hash__(self): + if is_hashable: + return 1 + else: + raise TypeError("not hashable") + + class A(metaclass=UnhashableMeta): ... + class B(metaclass=UnhashableMeta): ... + + union = A | B + self.assertEqual(union.__args__, (A, B)) + + with self.assertRaisesRegex(TypeError, "not hashable"): + hash(union) + + is_hashable = True + + with self.assertRaisesRegex(TypeError, "union contains 2 unhashable elements"): + hash(union) + def test_instancecheck_and_subclasscheck(self): for x in (int | str, typing.Union[int, str]): with self.subTest(x=x): @@ -722,15 +842,15 @@ def test_instancecheck_and_subclasscheck(self): self.assertIsInstance(True, x) self.assertIsInstance('a', x) self.assertNotIsInstance(None, x) - self.assertTrue(issubclass(int, x)) - self.assertTrue(issubclass(bool, x)) - self.assertTrue(issubclass(str, x)) - self.assertFalse(issubclass(type(None), x)) + self.assertIsSubclass(int, x) + self.assertIsSubclass(bool, x) + self.assertIsSubclass(str, x) + self.assertNotIsSubclass(type(None), x) for x in (int | None, typing.Union[int, None]): with self.subTest(x=x): self.assertIsInstance(None, x) - self.assertTrue(issubclass(type(None), x)) + self.assertIsSubclass(type(None), x) for x in ( int | collections.abc.Mapping, @@ -739,8 +859,8 @@ def test_instancecheck_and_subclasscheck(self): with self.subTest(x=x): self.assertIsInstance({}, x) self.assertNotIsInstance((), x) - self.assertTrue(issubclass(dict, x)) - self.assertFalse(issubclass(list, x)) + self.assertIsSubclass(dict, x) + self.assertNotIsSubclass(list, x) def test_instancecheck_and_subclasscheck_order(self): T = typing.TypeVar('T') @@ -752,7 +872,7 @@ def test_instancecheck_and_subclasscheck_order(self): for x in will_resolve: with self.subTest(x=x): self.assertIsInstance(1, x) - self.assertTrue(issubclass(int, x)) + self.assertIsSubclass(int, x) wont_resolve = ( T | int, @@ -785,13 +905,13 @@ class BadMeta(type): def __subclasscheck__(cls, sub): 1/0 x = int | BadMeta('A', (), {}) - self.assertTrue(issubclass(int, x)) + self.assertIsSubclass(int, x) self.assertRaises(ZeroDivisionError, issubclass, list, x) def test_or_type_operator_with_TypeVar(self): TV = typing.TypeVar('T') - assert TV | str == typing.Union[TV, str] - assert str | TV == typing.Union[str, TV] + self.assertEqual(TV | str, typing.Union[TV, str]) + self.assertEqual(str | TV, typing.Union[str, TV]) self.assertIs((int | TV)[int], int) self.assertIs((TV | int)[int], int) @@ -895,54 +1015,83 @@ def test_or_type_operator_with_forward(self): ForwardBefore = 'Forward' | T def forward_after(x: ForwardAfter[int]) -> None: ... def forward_before(x: ForwardBefore[int]) -> None: ... - assert typing.get_args(typing.get_type_hints(forward_after)['x']) == (int, Forward) - assert typing.get_args(typing.get_type_hints(forward_before)['x']) == (int, Forward) + self.assertEqual(typing.get_args(typing.get_type_hints(forward_after)['x']), + (int, Forward)) + self.assertEqual(typing.get_args(typing.get_type_hints(forward_before)['x']), + (Forward, int)) def test_or_type_operator_with_Protocol(self): class Proto(typing.Protocol): def meth(self) -> int: ... - assert Proto | str == typing.Union[Proto, str] + self.assertEqual(Proto | str, typing.Union[Proto, str]) def test_or_type_operator_with_Alias(self): - assert list | str == typing.Union[list, str] - assert typing.List | str == typing.Union[typing.List, str] + self.assertEqual(list | str, typing.Union[list, str]) + self.assertEqual(typing.List | str, typing.Union[typing.List, str]) def test_or_type_operator_with_NamedTuple(self): - NT=namedtuple('A', ['B', 'C', 'D']) - assert NT | str == typing.Union[NT,str] + NT = namedtuple('A', ['B', 'C', 'D']) + self.assertEqual(NT | str, typing.Union[NT, str]) def test_or_type_operator_with_TypedDict(self): class Point2D(typing.TypedDict): x: int y: int label: str - assert Point2D | str == typing.Union[Point2D, str] + self.assertEqual(Point2D | str, typing.Union[Point2D, str]) def test_or_type_operator_with_NewType(self): UserId = typing.NewType('UserId', int) - assert UserId | str == typing.Union[UserId, str] + self.assertEqual(UserId | str, typing.Union[UserId, str]) def test_or_type_operator_with_IO(self): - assert typing.IO | str == typing.Union[typing.IO, str] + self.assertEqual(typing.IO | str, typing.Union[typing.IO, str]) def test_or_type_operator_with_SpecialForm(self): - assert typing.Any | str == typing.Union[typing.Any, str] - assert typing.NoReturn | str == typing.Union[typing.NoReturn, str] - assert typing.Optional[int] | str == typing.Union[typing.Optional[int], str] - assert typing.Optional[int] | str == typing.Union[int, str, None] - assert typing.Union[int, bool] | str == typing.Union[int, bool, str] + self.assertEqual(typing.Any | str, typing.Union[typing.Any, str]) + self.assertEqual(typing.NoReturn | str, typing.Union[typing.NoReturn, str]) + self.assertEqual(typing.Optional[int] | str, typing.Union[typing.Optional[int], str]) + self.assertEqual(typing.Optional[int] | str, typing.Union[int, str, None]) + self.assertEqual(typing.Union[int, bool] | str, typing.Union[int, bool, str]) + + def test_or_type_operator_with_Literal(self): + Literal = typing.Literal + self.assertEqual((Literal[1] | Literal[2]).__args__, + (Literal[1], Literal[2])) + + self.assertEqual((Literal[0] | Literal[False]).__args__, + (Literal[0], Literal[False])) + self.assertEqual((Literal[1] | Literal[True]).__args__, + (Literal[1], Literal[True])) + + self.assertEqual(Literal[1] | Literal[1], Literal[1]) + self.assertEqual(Literal['a'] | Literal['a'], Literal['a']) + + import enum + class Ints(enum.IntEnum): + A = 0 + B = 1 + + self.assertEqual(Literal[Ints.A] | Literal[Ints.A], Literal[Ints.A]) + self.assertEqual(Literal[Ints.B] | Literal[Ints.B], Literal[Ints.B]) + + self.assertEqual((Literal[Ints.B] | Literal[Ints.A]).__args__, + (Literal[Ints.B], Literal[Ints.A])) + + self.assertEqual((Literal[0] | Literal[Ints.A]).__args__, + (Literal[0], Literal[Ints.A])) + self.assertEqual((Literal[1] | Literal[Ints.B]).__args__, + (Literal[1], Literal[Ints.B])) def test_or_type_repr(self): - assert repr(int | str) == "int | str" - assert repr((int | str) | list) == "int | str | list" - assert repr(int | (str | list)) == "int | str | list" - assert repr(int | None) == "int | None" - assert repr(int | type(None)) == "int | None" - assert repr(int | typing.GenericAlias(list, int)) == "int | list[int]" - - # TODO: RUSTPYTHON - @unittest.expectedFailure + self.assertEqual(repr(int | str), "int | str") + self.assertEqual(repr((int | str) | list), "int | str | list") + self.assertEqual(repr(int | (str | list)), "int | str | list") + self.assertEqual(repr(int | None), "int | None") + self.assertEqual(repr(int | type(None)), "int | None") + self.assertEqual(repr(int | typing.GenericAlias(list, int)), "int | list[int]") + def test_or_type_operator_with_genericalias(self): a = list[int] b = list[str] @@ -963,9 +1112,14 @@ def __eq__(self, other): return 1 / 0 bt = BadType('bt', (), {}) + bt2 = BadType('bt2', (), {}) # Comparison should fail and errors should propagate out for bad types. + union1 = int | bt + union2 = int | bt2 + with self.assertRaises(ZeroDivisionError): + union1 == union2 with self.assertRaises(ZeroDivisionError): - list[int] | list[bt] + bt | bt2 union_ga = (list[str] | int, collections.abc.Callable[..., str] | int, d | int) @@ -1008,6 +1162,19 @@ def test_or_type_operator_reference_cycle(self): self.assertLessEqual(sys.gettotalrefcount() - before, leeway, msg='Check for union reference leak.') + def test_instantiation(self): + check_disallow_instantiation(self, types.UnionType) + self.assertIs(int, types.UnionType[int]) + self.assertIs(int, types.UnionType[int, int]) + self.assertEqual(int | str, types.UnionType[int, str]) + + for obj in ( + int | typing.ForwardRef("str"), + typing.Union[int, "str"], + ): + self.assertIsInstance(obj, types.UnionType) + self.assertEqual(obj.__args__, (int, EqualToForwardRef("str"))) + class MappingProxyTests(unittest.TestCase): mappingproxy = types.MappingProxyType @@ -1197,8 +1364,7 @@ def test_copy(self): self.assertEqual(view['key1'], 70) self.assertEqual(copy['key1'], 27) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_union(self): mapping = {'a': 0, 'b': 1, 'c': 2} view = self.mappingproxy(mapping) @@ -1215,6 +1381,16 @@ def test_union(self): self.assertDictEqual(mapping, {'a': 0, 'b': 1, 'c': 2}) self.assertDictEqual(other, {'c': 3, 'p': 0}) + def test_hash(self): + class HashableDict(dict): + def __hash__(self): + return 3844817361 + view = self.mappingproxy({'a': 1, 'b': 2}) + self.assertRaises(TypeError, hash, view) + mapping = HashableDict({'a': 1, 'b': 2}) + view = self.mappingproxy(mapping) + self.assertEqual(hash(view), hash(mapping)) + class ClassCreationTests(unittest.TestCase): @@ -1238,7 +1414,7 @@ def test_new_class_basics(self): def test_new_class_subclass(self): C = types.new_class("C", (int,)) - self.assertTrue(issubclass(C, int)) + self.assertIsSubclass(C, int) def test_new_class_meta(self): Meta = self.Meta @@ -1283,7 +1459,7 @@ def func(ns): bases=(int,), kwds=dict(metaclass=Meta, z=2), exec_body=func) - self.assertTrue(issubclass(C, int)) + self.assertIsSubclass(C, int) self.assertIsInstance(C, Meta) self.assertEqual(C.x, 0) self.assertEqual(C.y, 1) @@ -1362,6 +1538,80 @@ class C: pass D = types.new_class('D', (A(), C, B()), {}) self.assertEqual(D.__bases__, (A1, A2, A3, C, B1, B2)) + def test_get_original_bases(self): + T = typing.TypeVar('T') + class A: pass + class B(typing.Generic[T]): pass + class C(B[int]): pass + class D(B[str], float): pass + + self.assertEqual(types.get_original_bases(A), (object,)) + self.assertEqual(types.get_original_bases(B), (typing.Generic[T],)) + self.assertEqual(types.get_original_bases(C), (B[int],)) + self.assertEqual(types.get_original_bases(int), (object,)) + self.assertEqual(types.get_original_bases(D), (B[str], float)) + + class E(list[T]): pass + class F(list[int]): pass + + self.assertEqual(types.get_original_bases(E), (list[T],)) + self.assertEqual(types.get_original_bases(F), (list[int],)) + + class FirstBase(typing.Generic[T]): pass + class SecondBase(typing.Generic[T]): pass + class First(FirstBase[int]): pass + class Second(SecondBase[int]): pass + class G(First, Second): pass + self.assertEqual(types.get_original_bases(G), (First, Second)) + + class First_(typing.Generic[T]): pass + class Second_(typing.Generic[T]): pass + class H(First_, Second_): pass + self.assertEqual(types.get_original_bases(H), (First_, Second_)) + + class ClassBasedNamedTuple(typing.NamedTuple): + x: int + + class GenericNamedTuple(typing.NamedTuple, typing.Generic[T]): + x: T + + CallBasedNamedTuple = typing.NamedTuple("CallBasedNamedTuple", [("x", int)]) + + self.assertIs( + types.get_original_bases(ClassBasedNamedTuple)[0], typing.NamedTuple + ) + self.assertEqual( + types.get_original_bases(GenericNamedTuple), + (typing.NamedTuple, typing.Generic[T]) + ) + self.assertIs( + types.get_original_bases(CallBasedNamedTuple)[0], typing.NamedTuple + ) + + class ClassBasedTypedDict(typing.TypedDict): + x: int + + class GenericTypedDict(typing.TypedDict, typing.Generic[T]): + x: T + + CallBasedTypedDict = typing.TypedDict("CallBasedTypedDict", {"x": int}) + + self.assertIs( + types.get_original_bases(ClassBasedTypedDict)[0], + typing.TypedDict + ) + self.assertEqual( + types.get_original_bases(GenericTypedDict), + (typing.TypedDict, typing.Generic[T]) + ) + self.assertIs( + types.get_original_bases(CallBasedTypedDict)[0], + typing.TypedDict + ) + + with self.assertRaisesRegex(TypeError, "Expected an instance of type"): + types.get_original_bases(object()) + # Many of the following tests are derived from test_descr.py def test_prepare_class(self): # Basic test of metaclass derivation @@ -1622,25 +1872,81 @@ class Model(metaclass=ModelBase): with self.assertRaises(RuntimeWarning): type("SouthPonies", (Model,), {}) + def test_subclass_inherited_slot_update(self): + # gh-132284: Make sure slot update still works after fix. + # Note that after assignment to D.__getitem__ the actual C slot will + # never go back to dict_subscript as it was on class type creation but + # rather be set to slot_mp_subscript, unfortunately there is no way to + # check that here. + + class D(dict): + pass + + d = D({None: None}) + self.assertIs(d[None], None) + D.__getitem__ = lambda self, item: 42 + self.assertEqual(d[None], 42) + D.__getitem__ = dict.__getitem__ + self.assertIs(d[None], None) + + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: != + def test_tuple_subclass_as_bases(self): + # gh-132176: it used to crash on using + # tuple subclass for as base classes. + class TupleSubclass(tuple): pass + + typ = type("typ", TupleSubclass((int, object)), {}) + self.assertEqual(typ.__bases__, (int, object)) + self.assertEqual(type(typ.__bases__), TupleSubclass) + class SimpleNamespaceTests(unittest.TestCase): def test_constructor(self): - ns1 = types.SimpleNamespace() - ns2 = types.SimpleNamespace(x=1, y=2) - ns3 = types.SimpleNamespace(**dict(x=1, y=2)) + def check(ns, expected): + self.assertEqual(len(ns.__dict__), len(expected)) + self.assertEqual(vars(ns), expected) + # check order + self.assertEqual(list(vars(ns).items()), list(expected.items())) + for name in expected: + self.assertEqual(getattr(ns, name), expected[name]) + + check(types.SimpleNamespace(), {}) + check(types.SimpleNamespace(x=1, y=2), {'x': 1, 'y': 2}) + check(types.SimpleNamespace(**dict(x=1, y=2)), {'x': 1, 'y': 2}) + check(types.SimpleNamespace({'x': 1, 'y': 2}, x=4, z=3), + {'x': 4, 'y': 2, 'z': 3}) + check(types.SimpleNamespace([['x', 1], ['y', 2]], x=4, z=3), + {'x': 4, 'y': 2, 'z': 3}) + check(types.SimpleNamespace(UserDict({'x': 1, 'y': 2}), x=4, z=3), + {'x': 4, 'y': 2, 'z': 3}) + check(types.SimpleNamespace({'x': 1, 'y': 2}), {'x': 1, 'y': 2}) + check(types.SimpleNamespace([['x', 1], ['y', 2]]), {'x': 1, 'y': 2}) + check(types.SimpleNamespace([], x=4, z=3), {'x': 4, 'z': 3}) + check(types.SimpleNamespace({}, x=4, z=3), {'x': 4, 'z': 3}) + check(types.SimpleNamespace([]), {}) + check(types.SimpleNamespace({}), {}) with self.assertRaises(TypeError): - types.SimpleNamespace(1, 2, 3) + types.SimpleNamespace([], []) # too many positional arguments with self.assertRaises(TypeError): - types.SimpleNamespace(**{1: 2}) - - self.assertEqual(len(ns1.__dict__), 0) - self.assertEqual(vars(ns1), {}) - self.assertEqual(len(ns2.__dict__), 2) - self.assertEqual(vars(ns2), {'y': 2, 'x': 1}) - self.assertEqual(len(ns3.__dict__), 2) - self.assertEqual(vars(ns3), {'y': 2, 'x': 1}) + types.SimpleNamespace(1) # not a mapping or iterable + with self.assertRaises(TypeError): + types.SimpleNamespace([1]) # non-iterable + with self.assertRaises(ValueError): + types.SimpleNamespace([['x']]) # not a pair + with self.assertRaises(ValueError): + types.SimpleNamespace([['x', 'y', 'z']]) + with self.assertRaises(TypeError): + types.SimpleNamespace(**{1: 2}) # non-string key + with self.assertRaises(TypeError): + types.SimpleNamespace({1: 2}) + with self.assertRaises(TypeError): + types.SimpleNamespace([[1, 2]]) + with self.assertRaises(TypeError): + types.SimpleNamespace(UserDict({1: 2})) + with self.assertRaises(TypeError): + types.SimpleNamespace([[[], 2]]) # non-hashable key def test_unbound(self): ns1 = vars(types.SimpleNamespace()) @@ -1797,6 +2103,33 @@ def test_pickle(self): self.assertEqual(ns, ns_roundtrip, pname) + def test_replace(self): + ns = types.SimpleNamespace(x=11, y=22) + + ns2 = copy.replace(ns) + self.assertEqual(ns2, ns) + self.assertIsNot(ns2, ns) + self.assertIs(type(ns2), types.SimpleNamespace) + self.assertEqual(vars(ns2), {'x': 11, 'y': 22}) + ns2.x = 3 + self.assertEqual(ns.x, 11) + ns.x = 4 + self.assertEqual(ns2.x, 3) + + self.assertEqual(vars(copy.replace(ns, x=1)), {'x': 1, 'y': 22}) + self.assertEqual(vars(copy.replace(ns, y=2)), {'x': 4, 'y': 2}) + self.assertEqual(vars(copy.replace(ns, x=1, y=2)), {'x': 1, 'y': 2}) + + def test_replace_subclass(self): + class Spam(types.SimpleNamespace): + pass + + spam = Spam(ham=8, eggs=9) + spam2 = copy.replace(spam, ham=5) + + self.assertIs(type(spam2), Spam) + self.assertEqual(vars(spam2), {'ham': 5, 'eggs': 9}) + def test_fake_namespace_compare(self): # Issue #24257: Incorrect use of PyObject_IsInstance() caused # SystemError. @@ -1841,8 +2174,7 @@ def foo(): foo = types.coroutine(foo) self.assertIs(aw, foo()) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_async_def(self): # Test that types.coroutine passes 'async def' coroutines # without modification @@ -2076,7 +2408,7 @@ def foo(): return gen wrapper = foo() wrapper.send(None) with self.assertRaisesRegex(Exception, 'ham'): - wrapper.throw(Exception, Exception('ham')) + wrapper.throw(Exception('ham')) # decorate foo second time foo = types.coroutine(foo) @@ -2099,8 +2431,7 @@ def foo(): foo = types.coroutine(foo) self.assertIs(foo(), gencoro) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_genfunc(self): def gen(): yield self.assertIs(types.coroutine(gen), gen) @@ -2131,5 +2462,125 @@ def coro(): 'close', 'throw'})) +class FunctionTests(unittest.TestCase): + def test_function_type_defaults(self): + def ex(a, /, b, *, c): + return a + b + c + + func = types.FunctionType( + ex.__code__, {}, "func", (1, 2), None, {'c': 3}, + ) + + self.assertEqual(func(), 6) + self.assertEqual(func.__defaults__, (1, 2)) + self.assertEqual(func.__kwdefaults__, {'c': 3}) + + func = types.FunctionType( + ex.__code__, {}, "func", None, None, None, + ) + self.assertEqual(func.__defaults__, None) + self.assertEqual(func.__kwdefaults__, None) + + def test_function_type_wrong_defaults(self): + def ex(a, /, b, *, c): + return a + b + c + + with self.assertRaisesRegex(TypeError, 'arg 4'): + types.FunctionType( + ex.__code__, {}, "func", 1, None, {'c': 3}, + ) + with self.assertRaisesRegex(TypeError, 'arg 6'): + types.FunctionType( + ex.__code__, {}, "func", None, None, 3, + ) + + +@unittest.skip("TODO: RUSTPYTHON; no subinterpreters yet") +class SubinterpreterTests(unittest.TestCase): + + NUMERIC_METHODS = { + '__abs__', + '__add__', + '__bool__', + '__divmod__', + '__float__', + '__floordiv__', + '__index__', + '__int__', + '__lshift__', + '__mod__', + '__mul__', + '__neg__', + '__pos__', + '__pow__', + '__radd__', + '__rdivmod__', + '__rfloordiv__', + '__rlshift__', + '__rmod__', + '__rmul__', + '__rpow__', + '__rrshift__', + '__rshift__', + '__rsub__', + '__rtruediv__', + '__sub__', + '__truediv__', + } + + @classmethod + def setUpClass(cls): + global interpreters + try: + from concurrent import interpreters + except ModuleNotFoundError: + raise unittest.SkipTest('subinterpreters required') + from test.support import channels # noqa: F401 + cls.create_channel = staticmethod(channels.create) + + @cpython_only + @no_rerun('channels (and queues) might have a refleak; see gh-122199') + def test_static_types_inherited_slots(self): + rch, sch = self.create_channel() + + script = textwrap.dedent(""" + import test.support + results = [] + for cls in test.support.iter_builtin_types(): + for attr, _ in test.support.iter_slot_wrappers(cls): + wrapper = getattr(cls, attr) + res = (cls, attr, wrapper) + results.append(res) + results = tuple((repr(c), a, repr(w)) for c, a, w in results) + sch.send_nowait(results) + """) + def collate_results(raw): + results = {} + for cls, attr, wrapper in raw: + key = cls, attr + assert key not in results, (results, key, wrapper) + results[key] = wrapper + return results + + exec(script) + raw = rch.recv_nowait() + main_results = collate_results(raw) + + interp = interpreters.create() + interp.exec('from concurrent import interpreters') + interp.prepare_main(sch=sch) + interp.exec(script) + raw = rch.recv_nowait() + interp_results = collate_results(raw) + + for key, expected in main_results.items(): + cls, attr = key + with self.subTest(cls=cls, slotattr=attr): + actual = interp_results.pop(key) + self.assertEqual(actual, expected) + self.maxDiff = None + self.assertEqual(interp_results, {}) + + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index d96a3a6a5d4..3d101c62e12 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -1,3 +1,4 @@ +import annotationlib import contextlib import collections import collections.abc @@ -5,8 +6,10 @@ from functools import lru_cache, wraps, reduce import gc import inspect +import io import itertools import operator +import os import pickle import re import sys @@ -43,13 +46,19 @@ import textwrap import typing import weakref +import warnings import types -from test.support import captured_stderr, cpython_only, infinite_recursion, requires_docstrings, import_helper -from test.support.testcase import ExtraAssertions -from test.typinganndata import ann_module695, mod_generics_cache, _typed_dict_helper +from test.support import ( + captured_stderr, cpython_only, infinite_recursion, requires_docstrings, import_helper, run_code, + EqualToForwardRef, +) +from test.typinganndata import ( + ann_module695, mod_generics_cache, _typed_dict_helper, + ann_module, ann_module2, ann_module3, ann_module5, ann_module6, ann_module8 +) -import unittest # XXX: RUSTPYTHON +import unittest # XXX: RUSTPYTHON; importing to be able to skip tests CANNOT_SUBCLASS_TYPE = 'Cannot subclass special typing classes' @@ -57,7 +66,7 @@ CANNOT_SUBCLASS_INSTANCE = 'Cannot subclass an instance of %s' -class BaseTestCase(TestCase, ExtraAssertions): +class BaseTestCase(TestCase): def clear_caches(self): for f in typing._cleanups: @@ -115,18 +124,18 @@ def test_errors(self): def test_can_subclass(self): class Mock(Any): pass - self.assertTrue(issubclass(Mock, Any)) + self.assertIsSubclass(Mock, Any) self.assertIsInstance(Mock(), Mock) class Something: pass - self.assertFalse(issubclass(Something, Any)) + self.assertNotIsSubclass(Something, Any) self.assertNotIsInstance(Something(), Mock) class MockSomething(Something, Mock): pass - self.assertTrue(issubclass(MockSomething, Any)) - self.assertTrue(issubclass(MockSomething, MockSomething)) - self.assertTrue(issubclass(MockSomething, Something)) - self.assertTrue(issubclass(MockSomething, Mock)) + self.assertIsSubclass(MockSomething, Any) + self.assertIsSubclass(MockSomething, MockSomething) + self.assertIsSubclass(MockSomething, Something) + self.assertIsSubclass(MockSomething, Mock) ms = MockSomething() self.assertIsInstance(ms, MockSomething) self.assertIsInstance(ms, Something) @@ -373,6 +382,7 @@ def test_alias(self): self.assertEqual(get_args(alias_2), (LiteralString,)) self.assertEqual(get_args(alias_3), (LiteralString,)) + class TypeVarTests(BaseTestCase): def test_basic_plain(self): T = TypeVar('T') @@ -467,8 +477,8 @@ def test_or(self): self.assertEqual(X | "x", Union[X, "x"]) self.assertEqual("x" | X, Union["x", X]) # make sure the order is correct - self.assertEqual(get_args(X | "x"), (X, ForwardRef("x"))) - self.assertEqual(get_args("x" | X), (ForwardRef("x"), X)) + self.assertEqual(get_args(X | "x"), (X, EqualToForwardRef("x"))) + self.assertEqual(get_args("x" | X), (EqualToForwardRef("x"), X)) def test_union_constrained(self): A = TypeVar('A', str, bytes) @@ -502,7 +512,7 @@ def test_cannot_instantiate_vars(self): def test_bound_errors(self): with self.assertRaises(TypeError): - TypeVar('X', bound=Union) + TypeVar('X', bound=Optional) with self.assertRaises(TypeError): TypeVar('X', str, float, bound=Employee) with self.assertRaisesRegex(TypeError, @@ -542,7 +552,7 @@ def test_var_substitution(self): def test_bad_var_substitution(self): T = TypeVar('T') bad_args = ( - (), (int, str), Union, + (), (int, str), Optional, Generic, Generic[T], Protocol, Protocol[T], Final, Final[int], ClassVar, ClassVar[int], ) @@ -625,7 +635,7 @@ class TypeParameterDefaultsTests(BaseTestCase): def test_typevar(self): T = TypeVar('T', default=int) self.assertEqual(T.__default__, int) - self.assertTrue(T.has_default()) + self.assertIs(T.has_default(), True) self.assertIsInstance(T, TypeVar) class A(Generic[T]): ... @@ -635,19 +645,19 @@ def test_typevar_none(self): U = TypeVar('U') U_None = TypeVar('U_None', default=None) self.assertIs(U.__default__, NoDefault) - self.assertFalse(U.has_default()) + self.assertIs(U.has_default(), False) self.assertIs(U_None.__default__, None) - self.assertTrue(U_None.has_default()) + self.assertIs(U_None.has_default(), True) class X[T]: ... T, = X.__type_params__ self.assertIs(T.__default__, NoDefault) - self.assertFalse(T.has_default()) + self.assertIs(T.has_default(), False) def test_paramspec(self): P = ParamSpec('P', default=(str, int)) self.assertEqual(P.__default__, (str, int)) - self.assertTrue(P.has_default()) + self.assertIs(P.has_default(), True) self.assertIsInstance(P, ParamSpec) class A(Generic[P]): ... @@ -660,19 +670,19 @@ def test_paramspec_none(self): U = ParamSpec('U') U_None = ParamSpec('U_None', default=None) self.assertIs(U.__default__, NoDefault) - self.assertFalse(U.has_default()) + self.assertIs(U.has_default(), False) self.assertIs(U_None.__default__, None) - self.assertTrue(U_None.has_default()) + self.assertIs(U_None.has_default(), True) class X[**P]: ... P, = X.__type_params__ self.assertIs(P.__default__, NoDefault) - self.assertFalse(P.has_default()) + self.assertIs(P.has_default(), False) def test_typevartuple(self): Ts = TypeVarTuple('Ts', default=Unpack[Tuple[str, int]]) self.assertEqual(Ts.__default__, Unpack[Tuple[str, int]]) - self.assertTrue(Ts.has_default()) + self.assertIs(Ts.has_default(), True) self.assertIsInstance(Ts, TypeVarTuple) class A(Generic[Unpack[Ts]]): ... @@ -754,18 +764,28 @@ class A(Generic[T, P, U]): ... self.assertEqual(A[float, [range]].__args__, (float, (range,), float)) self.assertEqual(A[float, [range], int].__args__, (float, (range,), int)) + def test_paramspec_and_typevar_specialization_2(self): + T = TypeVar("T") + P = ParamSpec('P', default=...) + U = TypeVar("U", default=float) + self.assertEqual(P.__default__, ...) + class A(Generic[T, P, U]): ... + self.assertEqual(A[float].__args__, (float, ..., float)) + self.assertEqual(A[float, [range]].__args__, (float, (range,), float)) + self.assertEqual(A[float, [range], int].__args__, (float, (range,), int)) + def test_typevartuple_none(self): U = TypeVarTuple('U') U_None = TypeVarTuple('U_None', default=None) self.assertIs(U.__default__, NoDefault) - self.assertFalse(U.has_default()) + self.assertIs(U.has_default(), False) self.assertIs(U_None.__default__, None) - self.assertTrue(U_None.has_default()) + self.assertIs(U_None.has_default(), True) class X[**Ts]: ... Ts, = X.__type_params__ self.assertIs(Ts.__default__, NoDefault) - self.assertFalse(Ts.has_default()) + self.assertIs(Ts.has_default(), False) def test_no_default_after_non_default(self): DefaultStrT = TypeVar('DefaultStrT', default=str) @@ -966,7 +986,7 @@ class C(Generic[T]): pass ) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_two_parameters(self): T1 = TypeVar('T1') T2 = TypeVar('T2') @@ -1064,7 +1084,7 @@ class C(Generic[T1, T2, T3]): pass eval(expected_str) ) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_variadic_parameters(self): T1 = TypeVar('T1') T2 = TypeVar('T2') @@ -1168,7 +1188,6 @@ class C(Generic[*Ts]): pass ) - class UnpackTests(BaseTestCase): def test_accepts_single_type(self): @@ -1999,11 +2018,11 @@ def test_basics(self): self.assertNotEqual(u, Union) def test_union_isinstance(self): - self.assertTrue(isinstance(42, Union[int, str])) - self.assertTrue(isinstance('abc', Union[int, str])) - self.assertFalse(isinstance(3.14, Union[int, str])) - self.assertTrue(isinstance(42, Union[int, list[int]])) - self.assertTrue(isinstance(42, Union[int, Any])) + self.assertIsInstance(42, Union[int, str]) + self.assertIsInstance('abc', Union[int, str]) + self.assertNotIsInstance(3.14, Union[int, str]) + self.assertIsInstance(42, Union[int, list[int]]) + self.assertIsInstance(42, Union[int, Any]) def test_union_isinstance_type_error(self): with self.assertRaises(TypeError): @@ -2020,9 +2039,9 @@ def test_union_isinstance_type_error(self): isinstance(42, Union[Any, str]) def test_optional_isinstance(self): - self.assertTrue(isinstance(42, Optional[int])) - self.assertTrue(isinstance(None, Optional[int])) - self.assertFalse(isinstance('abc', Optional[int])) + self.assertIsInstance(42, Optional[int]) + self.assertIsInstance(None, Optional[int]) + self.assertNotIsInstance('abc', Optional[int]) def test_optional_isinstance_type_error(self): with self.assertRaises(TypeError): @@ -2035,20 +2054,16 @@ def test_optional_isinstance_type_error(self): isinstance(None, Optional[Any]) def test_union_issubclass(self): - self.assertTrue(issubclass(int, Union[int, str])) - self.assertTrue(issubclass(str, Union[int, str])) - self.assertFalse(issubclass(float, Union[int, str])) - self.assertTrue(issubclass(int, Union[int, list[int]])) - self.assertTrue(issubclass(int, Union[int, Any])) - self.assertFalse(issubclass(int, Union[str, Any])) - self.assertTrue(issubclass(int, Union[Any, int])) - self.assertFalse(issubclass(int, Union[Any, str])) + self.assertIsSubclass(int, Union[int, str]) + self.assertIsSubclass(str, Union[int, str]) + self.assertNotIsSubclass(float, Union[int, str]) + self.assertIsSubclass(int, Union[int, list[int]]) + self.assertIsSubclass(int, Union[int, Any]) + self.assertNotIsSubclass(int, Union[str, Any]) + self.assertIsSubclass(int, Union[Any, int]) + self.assertNotIsSubclass(int, Union[Any, str]) def test_union_issubclass_type_error(self): - with self.assertRaises(TypeError): - issubclass(int, Union) - with self.assertRaises(TypeError): - issubclass(Union, int) with self.assertRaises(TypeError): issubclass(Union[int, str], int) with self.assertRaises(TypeError): @@ -2059,12 +2074,12 @@ def test_union_issubclass_type_error(self): issubclass(int, Union[list[int], str]) def test_optional_issubclass(self): - self.assertTrue(issubclass(int, Optional[int])) - self.assertTrue(issubclass(type(None), Optional[int])) - self.assertFalse(issubclass(str, Optional[int])) - self.assertTrue(issubclass(Any, Optional[Any])) - self.assertTrue(issubclass(type(None), Optional[Any])) - self.assertFalse(issubclass(int, Optional[Any])) + self.assertIsSubclass(int, Optional[int]) + self.assertIsSubclass(type(None), Optional[int]) + self.assertNotIsSubclass(str, Optional[int]) + self.assertIsSubclass(Any, Optional[Any]) + self.assertIsSubclass(type(None), Optional[Any]) + self.assertNotIsSubclass(int, Optional[Any]) def test_optional_issubclass_type_error(self): with self.assertRaises(TypeError): @@ -2123,41 +2138,40 @@ class B(metaclass=UnhashableMeta): ... self.assertEqual(Union[A, B].__args__, (A, B)) union1 = Union[A, B] - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"): hash(union1) union2 = Union[int, B] - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"): hash(union2) union3 = Union[A, int] - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"): hash(union3) def test_repr(self): - self.assertEqual(repr(Union), 'typing.Union') u = Union[Employee, int] - self.assertEqual(repr(u), 'typing.Union[%s.Employee, int]' % __name__) + self.assertEqual(repr(u), f'{__name__}.Employee | int') u = Union[int, Employee] - self.assertEqual(repr(u), 'typing.Union[int, %s.Employee]' % __name__) + self.assertEqual(repr(u), f'int | {__name__}.Employee') T = TypeVar('T') u = Union[T, int][int] self.assertEqual(repr(u), repr(int)) u = Union[List[int], int] - self.assertEqual(repr(u), 'typing.Union[typing.List[int], int]') + self.assertEqual(repr(u), 'typing.List[int] | int') u = Union[list[int], dict[str, float]] - self.assertEqual(repr(u), 'typing.Union[list[int], dict[str, float]]') + self.assertEqual(repr(u), 'list[int] | dict[str, float]') u = Union[int | float] - self.assertEqual(repr(u), 'typing.Union[int, float]') + self.assertEqual(repr(u), 'int | float') u = Union[None, str] - self.assertEqual(repr(u), 'typing.Optional[str]') + self.assertEqual(repr(u), 'None | str') u = Union[str, None] - self.assertEqual(repr(u), 'typing.Optional[str]') + self.assertEqual(repr(u), 'str | None') u = Union[None, str, int] - self.assertEqual(repr(u), 'typing.Union[NoneType, str, int]') + self.assertEqual(repr(u), 'None | str | int') u = Optional[str] - self.assertEqual(repr(u), 'typing.Optional[str]') + self.assertEqual(repr(u), 'str | None') def test_dir(self): dir_items = set(dir(Union[str, int])) @@ -2169,14 +2183,11 @@ def test_dir(self): def test_cannot_subclass(self): with self.assertRaisesRegex(TypeError, - r'Cannot subclass typing\.Union'): + r"type 'typing\.Union' is not an acceptable base type"): class C(Union): pass - with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): - class D(type(Union)): - pass with self.assertRaisesRegex(TypeError, - r'Cannot subclass typing\.Union\[int, str\]'): + r'Cannot subclass int \| str'): class E(Union[int, str]): pass @@ -2192,8 +2203,8 @@ def test_cannot_instantiate(self): type(u)() def test_union_generalization(self): - self.assertFalse(Union[str, typing.Iterable[int]] == str) - self.assertFalse(Union[str, typing.Iterable[int]] == typing.Iterable[int]) + self.assertNotEqual(Union[str, typing.Iterable[int]], str) + self.assertNotEqual(Union[str, typing.Iterable[int]], typing.Iterable[int]) self.assertIn(str, Union[str, typing.Iterable[int]].__args__) self.assertIn(typing.Iterable[int], Union[str, typing.Iterable[int]].__args__) @@ -2222,7 +2233,7 @@ def f(x: u): ... def test_function_repr_union(self): def fun() -> int: ... - self.assertEqual(repr(Union[fun, int]), 'typing.Union[fun, int]') + self.assertEqual(repr(Union[fun, int]), f'{__name__}.{fun.__qualname__} | int') def test_union_str_pattern(self): # Shouldn't crash; see http://bugs.python.org/issue25390 @@ -2270,6 +2281,16 @@ class Ints(enum.IntEnum): self.assertEqual(Union[Literal[1], Literal[Ints.B], Literal[True]].__args__, (Literal[1], Literal[Ints.B], Literal[True])) + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: types.UnionType[int, str] | float != types.UnionType[int, str, float] + def test_allow_non_types_in_or(self): + # gh-140348: Test that using | with a Union object allows things that are + # not allowed by is_unionable(). + U1 = Union[int, str] + self.assertEqual(U1 | float, Union[int, str, float]) + self.assertEqual(U1 | "float", Union[int, str, "float"]) + self.assertEqual(float | U1, Union[float, int, str]) + self.assertEqual("float" | U1, Union["float", int, str]) + class TupleTests(BaseTestCase): @@ -2557,7 +2578,7 @@ def test_concatenate(self): def test_nested_paramspec(self): # Since Callable has some special treatment, we want to be sure - # that substituion works correctly, see gh-103054 + # that substitution works correctly, see gh-103054 Callable = self.Callable P = ParamSpec('P') P2 = ParamSpec('P2') @@ -2609,6 +2630,7 @@ def test_errors(self): with self.assertRaisesRegex(TypeError, "few arguments for"): C1[int] + class TypingCallableTests(BaseCallableTests, BaseTestCase): Callable = typing.Callable @@ -2786,6 +2808,7 @@ class Coordinate(Protocol): x: int y: int + @runtime_checkable class Point(Coordinate, Protocol): label: str @@ -3155,6 +3178,21 @@ def x(self): ... with self.assertRaisesRegex(TypeError, only_classes_allowed): issubclass(1, BadPG) + def test_isinstance_against_superproto_doesnt_affect_subproto_instance(self): + @runtime_checkable + class Base(Protocol): + x: int + + @runtime_checkable + class Child(Base, Protocol): + y: str + + class Capybara: + x = 43 + + self.assertIsInstance(Capybara(), Base) + self.assertNotIsInstance(Capybara(), Child) + def test_implicit_issubclass_between_two_protocols(self): @runtime_checkable class CallableMembersProto(Protocol): @@ -3229,7 +3267,7 @@ def meth2(self, x, y): return True self.assertIsSubclass(NotAProtocolButAnImplicitSubclass2, CallableMembersProto) self.assertIsSubclass(NotAProtocolButAnImplicitSubclass3, CallableMembersProto) - @unittest.skip('TODO: RUSTPYTHON; (no gc)') + @unittest.skip("TODO: RUSTPYTHON; (no gc)") def test_isinstance_checks_not_at_whim_of_gc(self): self.addCleanup(gc.enable) gc.disable() @@ -3803,7 +3841,6 @@ def __init__(self): self.assertNotIsInstance(B(), P) self.assertNotIsInstance(C(), P) - @unittest.expectedFailure # TODO: RUSTPYTHON; test doesn't include PEP 649 attrs def test_non_protocol_subclasses(self): class P(Protocol): x = 1 @@ -3849,7 +3886,8 @@ def meth(self): pass acceptable_extra_attrs = { '_is_protocol', '_is_runtime_protocol', '__parameters__', - '__init__', '__annotations__', '__subclasshook__', + '__init__', '__annotations__', '__subclasshook__', '__annotate__', + '__annotations_cache__', '__annotate_func__', } self.assertLessEqual(vars(NonP).keys(), vars(C).keys() | acceptable_extra_attrs) self.assertLessEqual( @@ -4071,8 +4109,8 @@ def test_generic_protocols_repr(self): class P(Protocol[T, S]): pass - self.assertTrue(repr(P[T, S]).endswith('P[~T, ~S]')) - self.assertTrue(repr(P[int, str]).endswith('P[int, str]')) + self.assertEndsWith(repr(P[T, S]), 'P[~T, ~S]') + self.assertEndsWith(repr(P[int, str]), 'P[int, str]') def test_generic_protocols_eq(self): T = TypeVar('T') @@ -4112,12 +4150,12 @@ class PG(Protocol[T]): def meth(self): pass - self.assertTrue(P._is_protocol) - self.assertTrue(PR._is_protocol) - self.assertTrue(PG._is_protocol) - self.assertFalse(P._is_runtime_protocol) - self.assertTrue(PR._is_runtime_protocol) - self.assertTrue(PG[int]._is_protocol) + self.assertIs(P._is_protocol, True) + self.assertIs(PR._is_protocol, True) + self.assertIs(PG._is_protocol, True) + self.assertIs(P._is_runtime_protocol, False) + self.assertIs(PR._is_runtime_protocol, True) + self.assertIs(PG[int]._is_protocol, True) self.assertEqual(typing._get_protocol_attrs(P), {'meth'}) self.assertEqual(typing._get_protocol_attrs(PR), {'x'}) self.assertEqual(frozenset(typing._get_protocol_attrs(PG)), @@ -4173,7 +4211,7 @@ class P(Protocol): Alias2 = typing.Union[P, typing.Iterable] self.assertEqual(Alias, Alias2) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_protocols_pickleable(self): global P, CP # pickle wants to reference the class by name T = TypeVar('T') @@ -4324,11 +4362,50 @@ def __release_buffer__(self, mv: memoryview) -> None: self.assertNotIsSubclass(C, ReleasableBuffer) self.assertNotIsInstance(C(), ReleasableBuffer) + @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: module 'io' has no attribute 'Reader' + def test_io_reader_protocol_allowed(self): + @runtime_checkable + class CustomReader(io.Reader[bytes], Protocol): + def close(self): ... + + class A: pass + class B: + def read(self, sz=-1): + return b"" + def close(self): + pass + + self.assertIsSubclass(B, CustomReader) + self.assertIsInstance(B(), CustomReader) + self.assertNotIsSubclass(A, CustomReader) + self.assertNotIsInstance(A(), CustomReader) + + @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: module 'io' has no attribute 'Writer' + def test_io_writer_protocol_allowed(self): + @runtime_checkable + class CustomWriter(io.Writer[bytes], Protocol): + def close(self): ... + + class A: pass + class B: + def write(self, b): + pass + def close(self): + pass + + self.assertIsSubclass(B, CustomWriter) + self.assertIsInstance(B(), CustomWriter) + self.assertNotIsSubclass(A, CustomWriter) + self.assertNotIsInstance(A(), CustomWriter) + def test_builtin_protocol_allowlist(self): with self.assertRaises(TypeError): class CustomProtocol(TestCase, Protocol): pass + class CustomPathLikeProtocol(os.PathLike, Protocol): + pass + class CustomContextManager(typing.ContextManager, Protocol): pass @@ -4542,6 +4619,42 @@ class Commentable(Protocol): ) self.assertIs(type(exc.__cause__), CustomError) + def test_isinstance_with_deferred_evaluation_of_annotations(self): + @runtime_checkable + class P(Protocol): + def meth(self): + ... + + class DeferredClass: + x: undefined + + class DeferredClassImplementingP: + x: undefined | int + + def __init__(self): + self.x = 0 + + def meth(self): + ... + + # override meth with a non-method attribute to make it part of __annotations__ instead of __dict__ + class SubProtocol(P, Protocol): + meth: undefined + + + self.assertIsSubclass(SubProtocol, P) + self.assertNotIsInstance(DeferredClass(), P) + self.assertIsInstance(DeferredClassImplementingP(), P) + + def test_deferred_evaluation_of_annotations(self): + class DeferredProto(Protocol): + x: DoesNotExist + self.assertEqual(get_protocol_members(DeferredProto), {"x"}) + self.assertEqual( + annotationlib.get_annotations(DeferredProto, format=annotationlib.Format.STRING), + {'x': 'DoesNotExist'} + ) + class GenericTests(BaseTestCase): @@ -4589,6 +4702,35 @@ class D(Generic[T]): pass with self.assertRaises(TypeError): D[()] + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_generic_init_subclass_not_called_error(self): + notes = ["Note: this exception may have been caused by " + r"'GenericTests.test_generic_init_subclass_not_called_error..Base.__init_subclass__' " + "(or the '__init_subclass__' method on a superclass) not calling 'super().__init_subclass__()'"] + + class Base: + def __init_subclass__(cls) -> None: + # Oops, I forgot super().__init_subclass__()! + pass + + with self.subTest(): + class Sub(Base, Generic[T]): + pass + + with self.assertRaises(AttributeError) as cm: + Sub[int] + + self.assertEqual(cm.exception.__notes__, notes) + + with self.subTest(): + class Sub[U](Base): + pass + + with self.assertRaises(AttributeError) as cm: + Sub[int] + + self.assertEqual(cm.exception.__notes__, notes) + def test_generic_subclass_checks(self): for typ in [list[int], List[int], tuple[int, str], Tuple[int, str], @@ -4660,8 +4802,7 @@ class C(Generic[T]): self.assertNotEqual(Z, Y[int]) self.assertNotEqual(Z, Y[T]) - self.assertTrue(str(Z).endswith( - '.C[typing.Tuple[str, int]]')) + self.assertEndsWith(str(Z), '.C[typing.Tuple[str, int]]') def test_new_repr(self): T = TypeVar('T') @@ -4889,12 +5030,12 @@ class A(Generic[T]): self.assertNotEqual(typing.FrozenSet[A[str]], typing.FrozenSet[mod_generics_cache.B.A[str]]) - self.assertTrue(repr(Tuple[A[str]]).endswith('.A[str]]')) - self.assertTrue(repr(Tuple[B.A[str]]).endswith('.B.A[str]]')) - self.assertTrue(repr(Tuple[mod_generics_cache.A[str]]) - .endswith('mod_generics_cache.A[str]]')) - self.assertTrue(repr(Tuple[mod_generics_cache.B.A[str]]) - .endswith('mod_generics_cache.B.A[str]]')) + self.assertEndsWith(repr(Tuple[A[str]]), '.A[str]]') + self.assertEndsWith(repr(Tuple[B.A[str]]), '.B.A[str]]') + self.assertEndsWith(repr(Tuple[mod_generics_cache.A[str]]), + 'mod_generics_cache.A[str]]') + self.assertEndsWith(repr(Tuple[mod_generics_cache.B.A[str]]), + 'mod_generics_cache.B.A[str]]') def test_extended_generic_rules_eq(self): T = TypeVar('T') @@ -4915,11 +5056,11 @@ class Derived(Base): ... def test_extended_generic_rules_repr(self): T = TypeVar('T') self.assertEqual(repr(Union[Tuple, Callable]).replace('typing.', ''), - 'Union[Tuple, Callable]') + 'Tuple | Callable') self.assertEqual(repr(Union[Tuple, Tuple[int]]).replace('typing.', ''), - 'Union[Tuple, Tuple[int]]') + 'Tuple | Tuple[int]') self.assertEqual(repr(Callable[..., Optional[T]][int]).replace('typing.', ''), - 'Callable[..., Optional[int]]') + 'Callable[..., int | None]') self.assertEqual(repr(Callable[[], List[T]][int]).replace('typing.', ''), 'Callable[[], List[int]]') @@ -4985,7 +5126,7 @@ class C3: def f(x: X): ... self.assertEqual( get_type_hints(f, globals(), locals()), - {'x': list[list[ForwardRef('X')]]} + {'x': list[list[EqualToForwardRef('X')]]} ) def test_pep695_generic_class_with_future_annotations(self): @@ -5099,9 +5240,9 @@ def __contains__(self, item): with self.assertRaises(TypeError): issubclass(Tuple[int, ...], typing.Iterable) - def test_fail_with_bare_union(self): + def test_fail_with_special_forms(self): with self.assertRaises(TypeError): - List[Union] + List[Final] with self.assertRaises(TypeError): Tuple[Optional] with self.assertRaises(TypeError): @@ -5148,7 +5289,7 @@ def test_all_repr_eq_any(self): self.assertNotEqual(repr(base), '') self.assertEqual(base, base) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_pickle(self): global C # pickle wants to reference the class by name T = TypeVar('T') @@ -5205,10 +5346,12 @@ class Node(Generic[T]): ... Tuple[Any, Any], Node[T], Node[int], Node[Any], typing.Iterable[T], typing.Iterable[Any], typing.Iterable[int], typing.Dict[int, str], typing.Dict[T, Any], ClassVar[int], ClassVar[List[T]], Tuple['T', 'T'], - Union['T', int], List['T'], typing.Mapping['T', int]] - for t in things + [Any]: - self.assertEqual(t, copy(t)) - self.assertEqual(t, deepcopy(t)) + Union['T', int], List['T'], typing.Mapping['T', int], + Union[b"x", b"y"], Any] + for t in things: + with self.subTest(thing=t): + self.assertEqual(t, copy(t)) + self.assertEqual(t, deepcopy(t)) def test_immutability_by_copy_and_pickle(self): # Special forms like Union, Any, etc., generic aliases to containers like List, @@ -5644,8 +5787,6 @@ def test_subclass_special_form(self): for obj in ( ClassVar[int], Final[int], - Union[int, float], - Optional[int], Literal[1, 2], Concatenate[int, ParamSpec("P")], TypeGuard[int], @@ -5677,7 +5818,7 @@ class A: __parameters__ = (T,) # Bare classes should be skipped for a in (List, list): - for b in (A, int, TypeVar, TypeVarTuple, ParamSpec, types.GenericAlias, types.UnionType): + for b in (A, int, TypeVar, TypeVarTuple, ParamSpec, types.GenericAlias, Union): with self.subTest(generic=a, sub=b): with self.assertRaisesRegex(TypeError, '.* is not a generic class'): a[b][str] @@ -5696,7 +5837,7 @@ class A: for s in (int, G, A, List, list, TypeVar, TypeVarTuple, ParamSpec, - types.GenericAlias, types.UnionType): + types.GenericAlias, Union): for t in Tuple, tuple: with self.subTest(tuple=t, sub=s): @@ -5714,7 +5855,7 @@ class A: with self.assertRaises(TypeError): a[int] - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: ".+__typing_subst__.+tuple.+int.*" does not match "'TypeAliasType' object is not subscriptable" + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: ".+__typing_subst__.+tuple.+int.*" does not match "'TypeAliasType' object is not subscriptable" def test_return_non_tuple_while_unpacking(self): # GH-138497: GenericAlias objects didn't ensure that __typing_subst__ actually # returned a tuple @@ -5778,6 +5919,7 @@ def test_no_isinstance(self): with self.assertRaises(TypeError): issubclass(int, ClassVar) + class FinalTests(BaseTestCase): def test_basics(self): @@ -5834,7 +5976,7 @@ def test_final_unmodified(self): def func(x): ... self.assertIs(func, final(func)) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_dunder_final(self): @final def func(): ... @@ -5856,7 +5998,7 @@ def __call__(self, *args, **kwargs): @Wrapper def wrapped(): ... self.assertIsInstance(wrapped, Wrapper) - self.assertIs(False, hasattr(wrapped, "__final__")) + self.assertNotHasAttr(wrapped, "__final__") class Meta(type): @property @@ -5868,7 +6010,7 @@ class WithMeta(metaclass=Meta): ... # Builtin classes throw TypeError if you try to set an # attribute. final(int) - self.assertIs(False, hasattr(int, "__final__")) + self.assertNotHasAttr(int, "__final__") # Make sure it works with common builtin decorators class Methods: @@ -5949,19 +6091,19 @@ def static_method_bad_order(): self.assertEqual(Derived.class_method_good_order(), 42) self.assertIs(True, Derived.class_method_good_order.__override__) self.assertEqual(Derived.class_method_bad_order(), 42) - self.assertIs(False, hasattr(Derived.class_method_bad_order, "__override__")) + self.assertNotHasAttr(Derived.class_method_bad_order, "__override__") self.assertEqual(Derived.static_method_good_order(), 42) self.assertIs(True, Derived.static_method_good_order.__override__) self.assertEqual(Derived.static_method_bad_order(), 42) - self.assertIs(False, hasattr(Derived.static_method_bad_order, "__override__")) + self.assertNotHasAttr(Derived.static_method_bad_order, "__override__") # Base object is not changed: - self.assertIs(False, hasattr(Base.normal_method, "__override__")) - self.assertIs(False, hasattr(Base.class_method_good_order, "__override__")) - self.assertIs(False, hasattr(Base.class_method_bad_order, "__override__")) - self.assertIs(False, hasattr(Base.static_method_good_order, "__override__")) - self.assertIs(False, hasattr(Base.static_method_bad_order, "__override__")) + self.assertNotHasAttr(Base.normal_method, "__override__") + self.assertNotHasAttr(Base.class_method_good_order, "__override__") + self.assertNotHasAttr(Base.class_method_bad_order, "__override__") + self.assertNotHasAttr(Base.static_method_good_order, "__override__") + self.assertNotHasAttr(Base.static_method_bad_order, "__override__") def test_property(self): class Base: @@ -5984,10 +6126,10 @@ def wrong(self) -> int: instance = Child() self.assertEqual(instance.correct, 2) - self.assertTrue(Child.correct.fget.__override__) + self.assertIs(Child.correct.fget.__override__, True) self.assertEqual(instance.wrong, 2) - self.assertFalse(hasattr(Child.wrong, "__override__")) - self.assertFalse(hasattr(Child.wrong.fset, "__override__")) + self.assertNotHasAttr(Child.wrong, "__override__") + self.assertNotHasAttr(Child.wrong.fset, "__override__") def test_silent_failure(self): class CustomProp: @@ -6004,7 +6146,7 @@ def some(self): return 1 self.assertEqual(WithOverride.some, 1) - self.assertFalse(hasattr(WithOverride.some, "__override__")) + self.assertNotHasAttr(WithOverride.some, "__override__") def test_multiple_decorators(self): def with_wraps(f): # similar to `lru_cache` definition @@ -6025,9 +6167,9 @@ def on_bottom(self, a: int) -> int: instance = WithOverride() self.assertEqual(instance.on_top(1), 2) - self.assertTrue(instance.on_top.__override__) + self.assertIs(instance.on_top.__override__, True) self.assertEqual(instance.on_bottom(1), 3) - self.assertTrue(instance.on_bottom.__override__) + self.assertIs(instance.on_bottom.__override__, True) class CastTests(BaseTestCase): @@ -6065,8 +6207,6 @@ def test_errors(self): # We need this to make sure that `@no_type_check` respects `__module__` attr: -from test.typinganndata import ann_module8 - @no_type_check class NoTypeCheck_Outer: Inner = ann_module8.NoTypeCheck_Outer.Inner @@ -6076,474 +6216,168 @@ class NoTypeCheck_WithFunction: NoTypeCheck_function = ann_module8.NoTypeCheck_function -class ForwardRefTests(BaseTestCase): - - def test_basics(self): +class NoTypeCheckTests(BaseTestCase): + def test_no_type_check(self): - class Node(Generic[T]): + @no_type_check + def foo(a: 'whatevers') -> {}: + pass - def __init__(self, label: T): - self.label = label - self.left = self.right = None + th = get_type_hints(foo) + self.assertEqual(th, {}) - def add_both(self, - left: 'Optional[Node[T]]', - right: 'Node[T]' = None, - stuff: int = None, - blah=None): - self.left = left - self.right = right + def test_no_type_check_class(self): - def add_left(self, node: Optional['Node[T]']): - self.add_both(node, None) + @no_type_check + class C: + def foo(a: 'whatevers') -> {}: + pass - def add_right(self, node: 'Node[T]' = None): - self.add_both(None, node) + cth = get_type_hints(C.foo) + self.assertEqual(cth, {}) + ith = get_type_hints(C().foo) + self.assertEqual(ith, {}) - t = Node[int] - both_hints = get_type_hints(t.add_both, globals(), locals()) - self.assertEqual(both_hints['left'], Optional[Node[T]]) - self.assertEqual(both_hints['right'], Node[T]) - self.assertEqual(both_hints['stuff'], int) - self.assertNotIn('blah', both_hints) + def test_no_type_check_no_bases(self): + class C: + def meth(self, x: int): ... + @no_type_check + class D(C): + c = C - left_hints = get_type_hints(t.add_left, globals(), locals()) - self.assertEqual(left_hints['node'], Optional[Node[T]]) + # verify that @no_type_check never affects bases + self.assertEqual(get_type_hints(C.meth), {'x': int}) - right_hints = get_type_hints(t.add_right, globals(), locals()) - self.assertEqual(right_hints['node'], Node[T]) + # and never child classes: + class Child(D): + def foo(self, x: int): ... - def test_forwardref_instance_type_error(self): - fr = typing.ForwardRef('int') - with self.assertRaises(TypeError): - isinstance(42, fr) + self.assertEqual(get_type_hints(Child.foo), {'x': int}) - def test_forwardref_subclass_type_error(self): - fr = typing.ForwardRef('int') - with self.assertRaises(TypeError): - issubclass(int, fr) + def test_no_type_check_nested_types(self): + # See https://bugs.python.org/issue46571 + class Other: + o: int + class B: # Has the same `__name__`` as `A.B` and different `__qualname__` + o: int + @no_type_check + class A: + a: int + class B: + b: int + class C: + c: int + class D: + d: int - def test_forwardref_only_str_arg(self): - with self.assertRaises(TypeError): - typing.ForwardRef(1) # only `str` type is allowed + Other = Other - def test_forward_equality(self): - fr = typing.ForwardRef('int') - self.assertEqual(fr, typing.ForwardRef('int')) - self.assertNotEqual(List['int'], List[int]) - self.assertNotEqual(fr, typing.ForwardRef('int', module=__name__)) - frm = typing.ForwardRef('int', module=__name__) - self.assertEqual(frm, typing.ForwardRef('int', module=__name__)) - self.assertNotEqual(frm, typing.ForwardRef('int', module='__other_name__')) + for klass in [A, A.B, A.B.C, A.D]: + with self.subTest(klass=klass): + self.assertIs(klass.__no_type_check__, True) + self.assertEqual(get_type_hints(klass), {}) - def test_forward_equality_gth(self): - c1 = typing.ForwardRef('C') - c1_gth = typing.ForwardRef('C') - c2 = typing.ForwardRef('C') - c2_gth = typing.ForwardRef('C') + for not_modified in [Other, B]: + with self.subTest(not_modified=not_modified): + with self.assertRaises(AttributeError): + not_modified.__no_type_check__ + self.assertNotEqual(get_type_hints(not_modified), {}) - class C: - pass - def foo(a: c1_gth, b: c2_gth): - pass + def test_no_type_check_class_and_static_methods(self): + @no_type_check + class Some: + @staticmethod + def st(x: int) -> int: ... + @classmethod + def cl(cls, y: int) -> int: ... - self.assertEqual(get_type_hints(foo, globals(), locals()), {'a': C, 'b': C}) - self.assertEqual(c1, c2) - self.assertEqual(c1, c1_gth) - self.assertEqual(c1_gth, c2_gth) - self.assertEqual(List[c1], List[c1_gth]) - self.assertNotEqual(List[c1], List[C]) - self.assertNotEqual(List[c1_gth], List[C]) - self.assertEqual(Union[c1, c1_gth], Union[c1]) - self.assertEqual(Union[c1, c1_gth, int], Union[c1, int]) - - def test_forward_equality_hash(self): - c1 = typing.ForwardRef('int') - c1_gth = typing.ForwardRef('int') - c2 = typing.ForwardRef('int') - c2_gth = typing.ForwardRef('int') - - def foo(a: c1_gth, b: c2_gth): - pass - get_type_hints(foo, globals(), locals()) + self.assertIs(Some.st.__no_type_check__, True) + self.assertEqual(get_type_hints(Some.st), {}) + self.assertIs(Some.cl.__no_type_check__, True) + self.assertEqual(get_type_hints(Some.cl), {}) - self.assertEqual(hash(c1), hash(c2)) - self.assertEqual(hash(c1_gth), hash(c2_gth)) - self.assertEqual(hash(c1), hash(c1_gth)) + def test_no_type_check_other_module(self): + self.assertIs(NoTypeCheck_Outer.__no_type_check__, True) + with self.assertRaises(AttributeError): + ann_module8.NoTypeCheck_Outer.__no_type_check__ + with self.assertRaises(AttributeError): + ann_module8.NoTypeCheck_Outer.Inner.__no_type_check__ - c3 = typing.ForwardRef('int', module=__name__) - c4 = typing.ForwardRef('int', module='__other_name__') + self.assertIs(NoTypeCheck_WithFunction.__no_type_check__, True) + with self.assertRaises(AttributeError): + ann_module8.NoTypeCheck_function.__no_type_check__ - self.assertNotEqual(hash(c3), hash(c1)) - self.assertNotEqual(hash(c3), hash(c1_gth)) - self.assertNotEqual(hash(c3), hash(c4)) - self.assertEqual(hash(c3), hash(typing.ForwardRef('int', module=__name__))) + def test_no_type_check_foreign_functions(self): + # We should not modify this function: + def some(*args: int) -> int: + ... - def test_forward_equality_namespace(self): + @no_type_check class A: - pass - def namespace1(): - a = typing.ForwardRef('A') - def fun(x: a): - pass - get_type_hints(fun, globals(), locals()) - return a + some_alias = some + some_class = classmethod(some) + some_static = staticmethod(some) - def namespace2(): - a = typing.ForwardRef('A') + with self.assertRaises(AttributeError): + some.__no_type_check__ + self.assertEqual(get_type_hints(some), {'args': int, 'return': int}) - class A: - pass - def fun(x: a): - pass + def test_no_type_check_lambda(self): + @no_type_check + class A: + # Corner case: `lambda` is both an assignment and a function: + bar: Callable[[int], int] = lambda arg: arg - get_type_hints(fun, globals(), locals()) - return a + self.assertIs(A.bar.__no_type_check__, True) + self.assertEqual(get_type_hints(A.bar), {}) - self.assertEqual(namespace1(), namespace1()) - self.assertNotEqual(namespace1(), namespace2()) + def test_no_type_check_TypeError(self): + # This simply should not fail with + # `TypeError: can't set attributes of built-in/extension type 'dict'` + no_type_check(dict) - def test_forward_repr(self): - self.assertEqual(repr(List['int']), "typing.List[ForwardRef('int')]") - self.assertEqual(repr(List[ForwardRef('int', module='mod')]), - "typing.List[ForwardRef('int', module='mod')]") + def test_no_type_check_forward_ref_as_string(self): + class C: + foo: typing.ClassVar[int] = 7 + class D: + foo: ClassVar[int] = 7 + class E: + foo: 'typing.ClassVar[int]' = 7 + class F: + foo: 'ClassVar[int]' = 7 - def test_union_forward(self): + expected_result = {'foo': typing.ClassVar[int]} + for clazz in [C, D, E, F]: + self.assertEqual(get_type_hints(clazz), expected_result) - def foo(a: Union['T']): - pass + def test_meta_no_type_check(self): + depr_msg = ( + "'typing.no_type_check_decorator' is deprecated " + "and slated for removal in Python 3.15" + ) + with self.assertWarnsRegex(DeprecationWarning, depr_msg): + @no_type_check_decorator + def magic_decorator(func): + return func - self.assertEqual(get_type_hints(foo, globals(), locals()), - {'a': Union[T]}) + self.assertEqual(magic_decorator.__name__, 'magic_decorator') - def foo(a: tuple[ForwardRef('T')] | int): + @magic_decorator + def foo(a: 'whatevers') -> {}: pass - self.assertEqual(get_type_hints(foo, globals(), locals()), - {'a': tuple[T] | int}) + @magic_decorator + class C: + def foo(a: 'whatevers') -> {}: + pass - def test_tuple_forward(self): - - def foo(a: Tuple['T']): - pass - - self.assertEqual(get_type_hints(foo, globals(), locals()), - {'a': Tuple[T]}) - - def foo(a: tuple[ForwardRef('T')]): - pass - - self.assertEqual(get_type_hints(foo, globals(), locals()), - {'a': tuple[T]}) - - def test_double_forward(self): - def foo(a: 'List[\'int\']'): - pass - self.assertEqual(get_type_hints(foo, globals(), locals()), - {'a': List[int]}) - - def test_forward_recursion_actually(self): - def namespace1(): - a = typing.ForwardRef('A') - A = a - def fun(x: a): pass - - ret = get_type_hints(fun, globals(), locals()) - return a - - def namespace2(): - a = typing.ForwardRef('A') - A = a - def fun(x: a): pass - - ret = get_type_hints(fun, globals(), locals()) - return a - - def cmp(o1, o2): - return o1 == o2 - - with infinite_recursion(25): - r1 = namespace1() - r2 = namespace2() - self.assertIsNot(r1, r2) - self.assertRaises(RecursionError, cmp, r1, r2) - - def test_union_forward_recursion(self): - ValueList = List['Value'] - Value = Union[str, ValueList] - - class C: - foo: List[Value] - class D: - foo: Union[Value, ValueList] - class E: - foo: Union[List[Value], ValueList] - class F: - foo: Union[Value, List[Value], ValueList] - - self.assertEqual(get_type_hints(C, globals(), locals()), get_type_hints(C, globals(), locals())) - self.assertEqual(get_type_hints(C, globals(), locals()), - {'foo': List[Union[str, List[Union[str, List['Value']]]]]}) - self.assertEqual(get_type_hints(D, globals(), locals()), - {'foo': Union[str, List[Union[str, List['Value']]]]}) - self.assertEqual(get_type_hints(E, globals(), locals()), - {'foo': Union[ - List[Union[str, List[Union[str, List['Value']]]]], - List[Union[str, List['Value']]] - ] - }) - self.assertEqual(get_type_hints(F, globals(), locals()), - {'foo': Union[ - str, - List[Union[str, List['Value']]], - List[Union[str, List[Union[str, List['Value']]]]] - ] - }) - - def test_callable_forward(self): - - def foo(a: Callable[['T'], 'T']): - pass - - self.assertEqual(get_type_hints(foo, globals(), locals()), - {'a': Callable[[T], T]}) - - def test_callable_with_ellipsis_forward(self): - - def foo(a: 'Callable[..., T]'): - pass - - self.assertEqual(get_type_hints(foo, globals(), locals()), - {'a': Callable[..., T]}) - - def test_special_forms_forward(self): - - class C: - a: Annotated['ClassVar[int]', (3, 5)] = 4 - b: Annotated['Final[int]', "const"] = 4 - x: 'ClassVar' = 4 - y: 'Final' = 4 - - class CF: - b: List['Final[int]'] = 4 - - self.assertEqual(get_type_hints(C, globals())['a'], ClassVar[int]) - self.assertEqual(get_type_hints(C, globals())['b'], Final[int]) - self.assertEqual(get_type_hints(C, globals())['x'], ClassVar) - self.assertEqual(get_type_hints(C, globals())['y'], Final) - with self.assertRaises(TypeError): - get_type_hints(CF, globals()), - - def test_syntax_error(self): - - with self.assertRaises(SyntaxError): - Generic['/T'] - - def test_delayed_syntax_error(self): - - def foo(a: 'Node[T'): - pass - - with self.assertRaises(SyntaxError): - get_type_hints(foo) - - def test_syntax_error_empty_string(self): - for form in [typing.List, typing.Set, typing.Type, typing.Deque]: - with self.subTest(form=form): - with self.assertRaises(SyntaxError): - form[''] - - def test_name_error(self): - - def foo(a: 'Noode[T]'): - pass - - with self.assertRaises(NameError): - get_type_hints(foo, locals()) - - def test_no_type_check(self): - - @no_type_check - def foo(a: 'whatevers') -> {}: - pass - - th = get_type_hints(foo) - self.assertEqual(th, {}) - - def test_no_type_check_class(self): - - @no_type_check - class C: - def foo(a: 'whatevers') -> {}: - pass - - cth = get_type_hints(C.foo) - self.assertEqual(cth, {}) - ith = get_type_hints(C().foo) - self.assertEqual(ith, {}) - - def test_no_type_check_no_bases(self): - class C: - def meth(self, x: int): ... - @no_type_check - class D(C): - c = C - - # verify that @no_type_check never affects bases - self.assertEqual(get_type_hints(C.meth), {'x': int}) - - # and never child classes: - class Child(D): - def foo(self, x: int): ... - - self.assertEqual(get_type_hints(Child.foo), {'x': int}) - - def test_no_type_check_nested_types(self): - # See https://bugs.python.org/issue46571 - class Other: - o: int - class B: # Has the same `__name__`` as `A.B` and different `__qualname__` - o: int - @no_type_check - class A: - a: int - class B: - b: int - class C: - c: int - class D: - d: int - - Other = Other - - for klass in [A, A.B, A.B.C, A.D]: - with self.subTest(klass=klass): - self.assertTrue(klass.__no_type_check__) - self.assertEqual(get_type_hints(klass), {}) - - for not_modified in [Other, B]: - with self.subTest(not_modified=not_modified): - with self.assertRaises(AttributeError): - not_modified.__no_type_check__ - self.assertNotEqual(get_type_hints(not_modified), {}) - - def test_no_type_check_class_and_static_methods(self): - @no_type_check - class Some: - @staticmethod - def st(x: int) -> int: ... - @classmethod - def cl(cls, y: int) -> int: ... - - self.assertTrue(Some.st.__no_type_check__) - self.assertEqual(get_type_hints(Some.st), {}) - self.assertTrue(Some.cl.__no_type_check__) - self.assertEqual(get_type_hints(Some.cl), {}) - - def test_no_type_check_other_module(self): - self.assertTrue(NoTypeCheck_Outer.__no_type_check__) - with self.assertRaises(AttributeError): - ann_module8.NoTypeCheck_Outer.__no_type_check__ - with self.assertRaises(AttributeError): - ann_module8.NoTypeCheck_Outer.Inner.__no_type_check__ - - self.assertTrue(NoTypeCheck_WithFunction.__no_type_check__) - with self.assertRaises(AttributeError): - ann_module8.NoTypeCheck_function.__no_type_check__ - - def test_no_type_check_foreign_functions(self): - # We should not modify this function: - def some(*args: int) -> int: - ... - - @no_type_check - class A: - some_alias = some - some_class = classmethod(some) - some_static = staticmethod(some) - - with self.assertRaises(AttributeError): - some.__no_type_check__ - self.assertEqual(get_type_hints(some), {'args': int, 'return': int}) - - def test_no_type_check_lambda(self): - @no_type_check - class A: - # Corner case: `lambda` is both an assignment and a function: - bar: Callable[[int], int] = lambda arg: arg - - self.assertTrue(A.bar.__no_type_check__) - self.assertEqual(get_type_hints(A.bar), {}) - - def test_no_type_check_TypeError(self): - # This simply should not fail with - # `TypeError: can't set attributes of built-in/extension type 'dict'` - no_type_check(dict) - - def test_no_type_check_forward_ref_as_string(self): - class C: - foo: typing.ClassVar[int] = 7 - class D: - foo: ClassVar[int] = 7 - class E: - foo: 'typing.ClassVar[int]' = 7 - class F: - foo: 'ClassVar[int]' = 7 - - expected_result = {'foo': typing.ClassVar[int]} - for clazz in [C, D, E, F]: - self.assertEqual(get_type_hints(clazz), expected_result) - - def test_meta_no_type_check(self): - depr_msg = ( - "'typing.no_type_check_decorator' is deprecated " - "and slated for removal in Python 3.15" - ) - with self.assertWarnsRegex(DeprecationWarning, depr_msg): - @no_type_check_decorator - def magic_decorator(func): - return func - - self.assertEqual(magic_decorator.__name__, 'magic_decorator') - - @magic_decorator - def foo(a: 'whatevers') -> {}: - pass - - @magic_decorator - class C: - def foo(a: 'whatevers') -> {}: - pass - - self.assertEqual(foo.__name__, 'foo') - th = get_type_hints(foo) - self.assertEqual(th, {}) - cth = get_type_hints(C.foo) - self.assertEqual(cth, {}) - ith = get_type_hints(C().foo) - self.assertEqual(ith, {}) - - def test_default_globals(self): - code = ("class C:\n" - " def foo(self, a: 'C') -> 'D': pass\n" - "class D:\n" - " def bar(self, b: 'D') -> C: pass\n" - ) - ns = {} - exec(code, ns) - hints = get_type_hints(ns['C'].foo) - self.assertEqual(hints, {'a': ns['C'], 'return': ns['D']}) - - def test_final_forward_ref(self): - self.assertEqual(gth(Loop, globals())['attr'], Final[Loop]) - self.assertNotEqual(gth(Loop, globals())['attr'], Final[int]) - self.assertNotEqual(gth(Loop, globals())['attr'], Final) - - def test_or(self): - X = ForwardRef('X') - # __or__/__ror__ itself - self.assertEqual(X | "x", Union[X, "x"]) - self.assertEqual("x" | X, Union["x", X]) + self.assertEqual(foo.__name__, 'foo') + th = get_type_hints(foo) + self.assertEqual(th, {}) + cth = get_type_hints(C.foo) + self.assertEqual(cth, {}) + ith = get_type_hints(C().foo) + self.assertEqual(ith, {}) class InternalsTests(BaseTestCase): @@ -6581,6 +6415,16 @@ def test_collect_parameters(self): typing._collect_parameters self.assertEqual(cm.filename, __file__) + @cpython_only + def test_lazy_import(self): + import_helper.ensure_lazy_imports("typing", { + "warnings", + "inspect", + "re", + "contextlib", + "annotationlib", + }) + @lru_cache() def cached_func(x, y): @@ -6687,10 +6531,6 @@ def test_overload_registry_repeated(self): self.assertEqual(list(get_overloads(impl)), overloads) -from test.typinganndata import ( - ann_module, ann_module2, ann_module3, ann_module5, ann_module6, -) - T_a = TypeVar('T_a') class AwaitableWrapper(typing.Awaitable[T_a]): @@ -6843,7 +6683,7 @@ def nested(self: 'ForRefExample'): pass -class GetTypeHintTests(BaseTestCase): +class GetTypeHintsTests(BaseTestCase): def test_get_type_hints_from_various_objects(self): # For invalid objects should fail with TypeError (not AttributeError etc). with self.assertRaises(TypeError): @@ -6853,9 +6693,8 @@ def test_get_type_hints_from_various_objects(self): with self.assertRaises(TypeError): gth(None) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_get_type_hints_modules(self): - ann_module_type_hints = {1: 2, 'f': Tuple[int, int], 'x': int, 'y': str, 'u': int | float} + ann_module_type_hints = {'f': Tuple[int, int], 'x': int, 'y': str, 'u': int | float} self.assertEqual(gth(ann_module), ann_module_type_hints) self.assertEqual(gth(ann_module2), {}) self.assertEqual(gth(ann_module3), {}) @@ -6869,12 +6708,12 @@ def test_get_type_hints_modules_forwardref(self): 'default_b': Optional[mod_generics_cache.B]} self.assertEqual(gth(mod_generics_cache), mgc_hints) - @unittest.expectedFailure # TODO: RUSTPYTHON; test expects outdated result + @unittest.expectedFailure # TODO: RUSTPYTHON; + {'x': } def test_get_type_hints_classes(self): self.assertEqual(gth(ann_module.C), # gth will find the right globalns {'y': Optional[ann_module.C]}) self.assertIsInstance(gth(ann_module.j_class), dict) - self.assertEqual(gth(ann_module.M), {'123': 123, 'o': type}) + self.assertEqual(gth(ann_module.M), {'o': type}) self.assertEqual(gth(ann_module.D), {'j': str, 'k': str, 'y': Optional[ann_module.C]}) self.assertEqual(gth(ann_module.Y), {'z': int}) @@ -6905,8 +6744,8 @@ def test_respect_no_type_check(self): class NoTpCheck: class Inn: def __init__(self, x: 'not a type'): ... - self.assertTrue(NoTpCheck.__no_type_check__) - self.assertTrue(NoTpCheck.Inn.__init__.__no_type_check__) + self.assertIs(NoTpCheck.__no_type_check__, True) + self.assertIs(NoTpCheck.Inn.__init__.__no_type_check__, True) self.assertEqual(gth(ann_module2.NTC.meth), {}) class ABase(Generic[T]): def meth(x: int): ... @@ -7052,111 +6891,320 @@ def __iand__(self, other: Const["MySet[T]"]) -> "MySet[T]": {'other': MySet[T], 'return': MySet[T]} ) - def test_get_type_hints_annotated_with_none_default(self): - # See: https://bugs.python.org/issue46195 - def annotated_with_none_default(x: Annotated[int, 'data'] = None): ... - self.assertEqual( - get_type_hints(annotated_with_none_default), - {'x': int}, - ) - self.assertEqual( - get_type_hints(annotated_with_none_default, include_extras=True), - {'x': Annotated[int, 'data']}, - ) + def test_get_type_hints_annotated_with_none_default(self): + # See: https://bugs.python.org/issue46195 + def annotated_with_none_default(x: Annotated[int, 'data'] = None): ... + self.assertEqual( + get_type_hints(annotated_with_none_default), + {'x': int}, + ) + self.assertEqual( + get_type_hints(annotated_with_none_default, include_extras=True), + {'x': Annotated[int, 'data']}, + ) + + def test_get_type_hints_classes_str_annotations(self): + class Foo: + y = str + x: 'y' + # This previously raised an error under PEP 563. + self.assertEqual(get_type_hints(Foo), {'x': str}) + + def test_get_type_hints_bad_module(self): + # bpo-41515 + class BadModule: + pass + BadModule.__module__ = 'bad' # Something not in sys.modules + self.assertNotIn('bad', sys.modules) + self.assertEqual(get_type_hints(BadModule), {}) + + def test_get_type_hints_annotated_bad_module(self): + # See https://bugs.python.org/issue44468 + class BadBase: + foo: tuple + class BadType(BadBase): + bar: list + BadType.__module__ = BadBase.__module__ = 'bad' + self.assertNotIn('bad', sys.modules) + self.assertEqual(get_type_hints(BadType), {'foo': tuple, 'bar': list}) + + def test_forward_ref_and_final(self): + # https://bugs.python.org/issue45166 + hints = get_type_hints(ann_module5) + self.assertEqual(hints, {'name': Final[str]}) + + hints = get_type_hints(ann_module5.MyClass) + self.assertEqual(hints, {'value': Final}) + + def test_top_level_class_var(self): + # This is not meaningful but we don't raise for it. + # https://github.com/python/cpython/issues/133959 + hints = get_type_hints(ann_module6) + self.assertEqual(hints, {'wrong': ClassVar[int]}) + + def test_get_type_hints_typeddict(self): + self.assertEqual(get_type_hints(TotalMovie), {'title': str, 'year': int}) + self.assertEqual(get_type_hints(TotalMovie, include_extras=True), { + 'title': str, + 'year': NotRequired[int], + }) + + self.assertEqual(get_type_hints(AnnotatedMovie), {'title': str, 'year': int}) + self.assertEqual(get_type_hints(AnnotatedMovie, include_extras=True), { + 'title': Annotated[Required[str], "foobar"], + 'year': NotRequired[Annotated[int, 2000]], + }) + + self.assertEqual(get_type_hints(DeeplyAnnotatedMovie), {'title': str, 'year': int}) + self.assertEqual(get_type_hints(DeeplyAnnotatedMovie, include_extras=True), { + 'title': Annotated[Required[str], "foobar", "another level"], + 'year': NotRequired[Annotated[int, 2000]], + }) + + self.assertEqual(get_type_hints(WeirdlyQuotedMovie), {'title': str, 'year': int}) + self.assertEqual(get_type_hints(WeirdlyQuotedMovie, include_extras=True), { + 'title': Annotated[Required[str], "foobar", "another level"], + 'year': NotRequired[Annotated[int, 2000]], + }) + + self.assertEqual(get_type_hints(_typed_dict_helper.VeryAnnotated), {'a': int}) + self.assertEqual(get_type_hints(_typed_dict_helper.VeryAnnotated, include_extras=True), { + 'a': Annotated[Required[int], "a", "b", "c"] + }) + + self.assertEqual(get_type_hints(ChildTotalMovie), {"title": str, "year": int}) + self.assertEqual(get_type_hints(ChildTotalMovie, include_extras=True), { + "title": Required[str], "year": NotRequired[int] + }) + + self.assertEqual(get_type_hints(ChildDeeplyAnnotatedMovie), {"title": str, "year": int}) + self.assertEqual(get_type_hints(ChildDeeplyAnnotatedMovie, include_extras=True), { + "title": Annotated[Required[str], "foobar", "another level"], + "year": NotRequired[Annotated[int, 2000]] + }) + + def test_get_type_hints_collections_abc_callable(self): + # https://github.com/python/cpython/issues/91621 + P = ParamSpec('P') + def f(x: collections.abc.Callable[[int], int]): ... + def g(x: collections.abc.Callable[..., int]): ... + def h(x: collections.abc.Callable[P, int]): ... + + self.assertEqual(get_type_hints(f), {'x': collections.abc.Callable[[int], int]}) + self.assertEqual(get_type_hints(g), {'x': collections.abc.Callable[..., int]}) + self.assertEqual(get_type_hints(h), {'x': collections.abc.Callable[P, int]}) + + def test_get_type_hints_format(self): + class C: + x: undefined + + with self.assertRaises(NameError): + get_type_hints(C) + + with self.assertRaises(NameError): + get_type_hints(C, format=annotationlib.Format.VALUE) + + annos = get_type_hints(C, format=annotationlib.Format.FORWARDREF) + self.assertIsInstance(annos, dict) + self.assertEqual(list(annos), ['x']) + self.assertIsInstance(annos['x'], annotationlib.ForwardRef) + self.assertEqual(annos['x'].__arg__, 'undefined') + + self.assertEqual(get_type_hints(C, format=annotationlib.Format.STRING), + {'x': 'undefined'}) + # Make sure using an int as format also works: + self.assertEqual(get_type_hints(C, format=4), {'x': 'undefined'}) + + def test_get_type_hints_format_function(self): + def func(x: undefined) -> undefined: ... + + # VALUE + with self.assertRaises(NameError): + get_type_hints(func) + with self.assertRaises(NameError): + get_type_hints(func, format=annotationlib.Format.VALUE) + + # FORWARDREF + self.assertEqual( + get_type_hints(func, format=annotationlib.Format.FORWARDREF), + {'x': EqualToForwardRef('undefined', owner=func), + 'return': EqualToForwardRef('undefined', owner=func)}, + ) + + # STRING + self.assertEqual(get_type_hints(func, format=annotationlib.Format.STRING), + {'x': 'undefined', 'return': 'undefined'}) + + def test_callable_with_ellipsis_forward(self): + + def foo(a: 'Callable[..., T]'): + pass + + self.assertEqual(get_type_hints(foo, globals(), locals()), + {'a': Callable[..., T]}) + + def test_special_forms_no_forward(self): + def f(x: ClassVar[int]): + pass + self.assertEqual(get_type_hints(f), {'x': ClassVar[int]}) + + def test_special_forms_forward(self): + + class C: + a: Annotated['ClassVar[int]', (3, 5)] = 4 + b: Annotated['Final[int]', "const"] = 4 + x: 'ClassVar' = 4 + y: 'Final' = 4 + + class CF: + b: List['Final[int]'] = 4 + + self.assertEqual(get_type_hints(C, globals())['a'], ClassVar[int]) + self.assertEqual(get_type_hints(C, globals())['b'], Final[int]) + self.assertEqual(get_type_hints(C, globals())['x'], ClassVar) + self.assertEqual(get_type_hints(C, globals())['y'], Final) + lfi = get_type_hints(CF, globals())['b'] + self.assertIs(get_origin(lfi), list) + self.assertEqual(get_args(lfi), (Final[int],)) + + def test_union_forward_recursion(self): + ValueList = List['Value'] + Value = Union[str, ValueList] + + class C: + foo: List[Value] + class D: + foo: Union[Value, ValueList] + class E: + foo: Union[List[Value], ValueList] + class F: + foo: Union[Value, List[Value], ValueList] + + self.assertEqual(get_type_hints(C, globals(), locals()), get_type_hints(C, globals(), locals())) + self.assertEqual(get_type_hints(C, globals(), locals()), + {'foo': List[Union[str, List[Union[str, List['Value']]]]]}) + self.assertEqual(get_type_hints(D, globals(), locals()), + {'foo': Union[str, List[Union[str, List['Value']]]]}) + self.assertEqual(get_type_hints(E, globals(), locals()), + {'foo': Union[ + List[Union[str, List[Union[str, List['Value']]]]], + List[Union[str, List['Value']]] + ] + }) + self.assertEqual(get_type_hints(F, globals(), locals()), + {'foo': Union[ + str, + List[Union[str, List['Value']]], + List[Union[str, List[Union[str, List['Value']]]]] + ] + }) + + def test_tuple_forward(self): + + def foo(a: Tuple['T']): + pass + + self.assertEqual(get_type_hints(foo, globals(), locals()), + {'a': Tuple[T]}) + + def foo(a: tuple[ForwardRef('T')]): + pass + + self.assertEqual(get_type_hints(foo, globals(), locals()), + {'a': tuple[T]}) + + def test_double_forward(self): + def foo(a: 'List[\'int\']'): + pass + self.assertEqual(get_type_hints(foo, globals(), locals()), + {'a': List[int]}) + + def test_union_forward(self): - def test_get_type_hints_classes_str_annotations(self): - class Foo: - y = str - x: 'y' - # This previously raised an error under PEP 563. - self.assertEqual(get_type_hints(Foo), {'x': str}) + def foo(a: Union['T']): + pass - def test_get_type_hints_bad_module(self): - # bpo-41515 - class BadModule: + self.assertEqual(get_type_hints(foo, globals(), locals()), + {'a': Union[T]}) + + def foo(a: tuple[ForwardRef('T')] | int): pass - BadModule.__module__ = 'bad' # Something not in sys.modules - self.assertNotIn('bad', sys.modules) - self.assertEqual(get_type_hints(BadModule), {}) - def test_get_type_hints_annotated_bad_module(self): - # See https://bugs.python.org/issue44468 - class BadBase: - foo: tuple - class BadType(BadBase): - bar: list - BadType.__module__ = BadBase.__module__ = 'bad' - self.assertNotIn('bad', sys.modules) - self.assertEqual(get_type_hints(BadType), {'foo': tuple, 'bar': list}) + self.assertEqual(get_type_hints(foo, globals(), locals()), + {'a': tuple[T] | int}) - def test_forward_ref_and_final(self): - # https://bugs.python.org/issue45166 - hints = get_type_hints(ann_module5) - self.assertEqual(hints, {'name': Final[str]}) + def test_default_globals(self): + code = ("class C:\n" + " def foo(self, a: 'C') -> 'D': pass\n" + "class D:\n" + " def bar(self, b: 'D') -> C: pass\n" + ) + ns = {} + exec(code, ns) + hints = get_type_hints(ns['C'].foo) + self.assertEqual(hints, {'a': ns['C'], 'return': ns['D']}) - hints = get_type_hints(ann_module5.MyClass) - self.assertEqual(hints, {'value': Final}) + def test_final_forward_ref(self): + gth = get_type_hints + self.assertEqual(gth(Loop, globals())['attr'], Final[Loop]) + self.assertNotEqual(gth(Loop, globals())['attr'], Final[int]) + self.assertNotEqual(gth(Loop, globals())['attr'], Final) - def test_top_level_class_var(self): - # https://bugs.python.org/issue45166 - with self.assertRaisesRegex( - TypeError, - r'typing.ClassVar\[int\] is not valid as type argument', - ): - get_type_hints(ann_module6) + def test_name_error(self): - @unittest.expectedFailure # TODO: RUSTPYTHON - def test_get_type_hints_typeddict(self): - self.assertEqual(get_type_hints(TotalMovie), {'title': str, 'year': int}) - self.assertEqual(get_type_hints(TotalMovie, include_extras=True), { - 'title': str, - 'year': NotRequired[int], - }) + def foo(a: 'Noode[T]'): + pass - self.assertEqual(get_type_hints(AnnotatedMovie), {'title': str, 'year': int}) - self.assertEqual(get_type_hints(AnnotatedMovie, include_extras=True), { - 'title': Annotated[Required[str], "foobar"], - 'year': NotRequired[Annotated[int, 2000]], - }) + with self.assertRaises(NameError): + get_type_hints(foo, locals()) - self.assertEqual(get_type_hints(DeeplyAnnotatedMovie), {'title': str, 'year': int}) - self.assertEqual(get_type_hints(DeeplyAnnotatedMovie, include_extras=True), { - 'title': Annotated[Required[str], "foobar", "another level"], - 'year': NotRequired[Annotated[int, 2000]], - }) + def test_basics(self): - self.assertEqual(get_type_hints(WeirdlyQuotedMovie), {'title': str, 'year': int}) - self.assertEqual(get_type_hints(WeirdlyQuotedMovie, include_extras=True), { - 'title': Annotated[Required[str], "foobar", "another level"], - 'year': NotRequired[Annotated[int, 2000]], - }) + class Node(Generic[T]): - self.assertEqual(get_type_hints(_typed_dict_helper.VeryAnnotated), {'a': int}) - self.assertEqual(get_type_hints(_typed_dict_helper.VeryAnnotated, include_extras=True), { - 'a': Annotated[Required[int], "a", "b", "c"] - }) + def __init__(self, label: T): + self.label = label + self.left = self.right = None - self.assertEqual(get_type_hints(ChildTotalMovie), {"title": str, "year": int}) - self.assertEqual(get_type_hints(ChildTotalMovie, include_extras=True), { - "title": Required[str], "year": NotRequired[int] - }) + def add_both(self, + left: 'Optional[Node[T]]', + right: 'Node[T]' = None, + stuff: int = None, + blah=None): + self.left = left + self.right = right - self.assertEqual(get_type_hints(ChildDeeplyAnnotatedMovie), {"title": str, "year": int}) - self.assertEqual(get_type_hints(ChildDeeplyAnnotatedMovie, include_extras=True), { - "title": Annotated[Required[str], "foobar", "another level"], - "year": NotRequired[Annotated[int, 2000]] - }) + def add_left(self, node: Optional['Node[T]']): + self.add_both(node, None) - def test_get_type_hints_collections_abc_callable(self): - # https://github.com/python/cpython/issues/91621 - P = ParamSpec('P') - def f(x: collections.abc.Callable[[int], int]): ... - def g(x: collections.abc.Callable[..., int]): ... - def h(x: collections.abc.Callable[P, int]): ... + def add_right(self, node: 'Node[T]' = None): + self.add_both(None, node) - self.assertEqual(get_type_hints(f), {'x': collections.abc.Callable[[int], int]}) - self.assertEqual(get_type_hints(g), {'x': collections.abc.Callable[..., int]}) - self.assertEqual(get_type_hints(h), {'x': collections.abc.Callable[P, int]}) + t = Node[int] + both_hints = get_type_hints(t.add_both, globals(), locals()) + self.assertEqual(both_hints['left'], Optional[Node[T]]) + self.assertEqual(both_hints['right'], Node[T]) + self.assertEqual(both_hints['stuff'], int) + self.assertNotIn('blah', both_hints) + + left_hints = get_type_hints(t.add_left, globals(), locals()) + self.assertEqual(left_hints['node'], Optional[Node[T]]) + + right_hints = get_type_hints(t.add_right, globals(), locals()) + self.assertEqual(right_hints['node'], Node[T]) + + def test_stringified_typeddict(self): + ns = run_code( + """ + from __future__ import annotations + from typing import TypedDict + class TD[UniqueT](TypedDict): + a: UniqueT + """ + ) + TD = ns['TD'] + self.assertEqual(TD.__annotations__, {'a': EqualToForwardRef('UniqueT', owner=TD, module=TD.__module__)}) + self.assertEqual(get_type_hints(TD), {'a': TD.__type_params__[0]}) class GetUtilitiesTestCase(TestCase): @@ -7181,7 +7229,7 @@ class C(Generic[T]): pass self.assertIs(get_origin(Callable), collections.abc.Callable) self.assertIs(get_origin(list[int]), list) self.assertIs(get_origin(list), None) - self.assertIs(get_origin(list | str), types.UnionType) + self.assertIs(get_origin(list | str), Union) self.assertIs(get_origin(P.args), P) self.assertIs(get_origin(P.kwargs), P) self.assertIs(get_origin(Required[int]), Required) @@ -7260,6 +7308,125 @@ class C(Generic[T]): pass self.assertEqual(get_args(Unpack[tuple[Unpack[Ts]]]), (tuple[Unpack[Ts]],)) +class EvaluateForwardRefTests(BaseTestCase): + def test_evaluate_forward_ref(self): + int_ref = ForwardRef('int') + self.assertIs(typing.evaluate_forward_ref(int_ref), int) + self.assertIs( + typing.evaluate_forward_ref(int_ref, type_params=()), + int, + ) + self.assertIs( + typing.evaluate_forward_ref(int_ref, format=annotationlib.Format.VALUE), + int, + ) + self.assertIs( + typing.evaluate_forward_ref( + int_ref, format=annotationlib.Format.FORWARDREF, + ), + int, + ) + self.assertEqual( + typing.evaluate_forward_ref( + int_ref, format=annotationlib.Format.STRING, + ), + 'int', + ) + + def test_evaluate_forward_ref_undefined(self): + missing = ForwardRef('missing') + with self.assertRaises(NameError): + typing.evaluate_forward_ref(missing) + self.assertIs( + typing.evaluate_forward_ref( + missing, format=annotationlib.Format.FORWARDREF, + ), + missing, + ) + self.assertEqual( + typing.evaluate_forward_ref( + missing, format=annotationlib.Format.STRING, + ), + "missing", + ) + + def test_evaluate_forward_ref_nested(self): + ref = ForwardRef("int | list['str']") + self.assertEqual( + typing.evaluate_forward_ref(ref), + int | list[str], + ) + self.assertEqual( + typing.evaluate_forward_ref(ref, format=annotationlib.Format.FORWARDREF), + int | list[str], + ) + self.assertEqual( + typing.evaluate_forward_ref(ref, format=annotationlib.Format.STRING), + "int | list['str']", + ) + + why = ForwardRef('"\'str\'"') + self.assertIs(typing.evaluate_forward_ref(why), str) + + def test_evaluate_forward_ref_none(self): + none_ref = ForwardRef('None') + self.assertIs(typing.evaluate_forward_ref(none_ref), None) + + def test_globals(self): + A = "str" + ref = ForwardRef('list[A]') + with self.assertRaises(NameError): + typing.evaluate_forward_ref(ref) + self.assertEqual( + typing.evaluate_forward_ref(ref, globals={'A': A}), + list[str], + ) + + def test_owner(self): + ref = ForwardRef("A") + + with self.assertRaises(NameError): + typing.evaluate_forward_ref(ref) + + # We default to the globals of `owner`, + # so it no longer raises `NameError` + self.assertIs( + typing.evaluate_forward_ref(ref, owner=Loop), A + ) + + def test_inherited_owner(self): + # owner passed to evaluate_forward_ref + ref = ForwardRef("list['A']") + self.assertEqual( + typing.evaluate_forward_ref(ref, owner=Loop), + list[A], + ) + + # owner set on the ForwardRef + ref = ForwardRef("list['A']", owner=Loop) + self.assertEqual( + typing.evaluate_forward_ref(ref), + list[A], + ) + + def test_partial_evaluation(self): + ref = ForwardRef("list[A]") + with self.assertRaises(NameError): + typing.evaluate_forward_ref(ref) + + self.assertEqual( + typing.evaluate_forward_ref(ref, format=annotationlib.Format.FORWARDREF), + list[EqualToForwardRef('A')], + ) + + @unittest.expectedFailure # TODO: RUSTPYTHON; ImportError: cannot import name 'fwdref_module' + def test_with_module(self): + from test.typinganndata import fwdref_module + + typing.evaluate_forward_ref( + fwdref_module.fw,) + + class CollectionsAbcTests(BaseTestCase): def test_hashable(self): @@ -7987,6 +8154,48 @@ class XMethBad2(NamedTuple): def _source(self): return 'no chance for this as well' + def test_annotation_type_check(self): + # These are rejected by _type_check + with self.assertRaises(TypeError): + class X(NamedTuple): + a: Final + with self.assertRaises(TypeError): + class Y(NamedTuple): + a: (1, 2) + + # Conversion by _type_convert + class Z(NamedTuple): + a: None + b: "str" + annos = {'a': type(None), 'b': EqualToForwardRef("str")} + self.assertEqual(Z.__annotations__, annos) + self.assertEqual(Z.__annotate__(annotationlib.Format.VALUE), annos) + self.assertEqual(Z.__annotate__(annotationlib.Format.FORWARDREF), annos) + self.assertEqual(Z.__annotate__(annotationlib.Format.STRING), {"a": "None", "b": "str"}) + + def test_future_annotations(self): + code = """ + from __future__ import annotations + from typing import NamedTuple + class X(NamedTuple): + a: int + b: None + """ + ns = run_code(textwrap.dedent(code)) + X = ns['X'] + self.assertEqual(X.__annotations__, {'a': EqualToForwardRef("int"), 'b': EqualToForwardRef("None")}) + + def test_deferred_annotations(self): + class X(NamedTuple): + y: undefined + + self.assertEqual(X._fields, ('y',)) + with self.assertRaises(NameError): + X.__annotations__ + + undefined = int + self.assertEqual(X.__annotations__, {'y': int}) + def test_multiple_inheritance(self): class A: pass @@ -8233,7 +8442,7 @@ class Bar(NamedTuple): self.assertIsInstance(bar.attr, Vanilla) self.assertEqual(bar.attr.name, "attr") - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_setname_raises_the_same_as_on_other_classes(self): class CustomException(BaseException): pass @@ -8288,6 +8497,23 @@ class VeryAnnoying(metaclass=Meta): pass class Foo(NamedTuple): attr = very_annoying + def test_super_explicitly_disallowed(self): + expected_message = ( + "uses of super() and __class__ are unsupported " + "in methods of NamedTuple subclasses" + ) + + with self.assertRaises(TypeError, msg=expected_message): + class ThisWontWork(NamedTuple): + def __repr__(self): + return super().__repr__() + + with self.assertRaises(TypeError, msg=expected_message): + class ThisWontWorkEither(NamedTuple): + @property + def name(self): + return __class__.__name__ + class TypedDictTests(BaseTestCase): def test_basics_functional_syntax(self): @@ -8302,7 +8528,11 @@ def test_basics_functional_syntax(self): self.assertEqual(Emp.__name__, 'Emp') self.assertEqual(Emp.__module__, __name__) self.assertEqual(Emp.__bases__, (dict,)) - self.assertEqual(Emp.__annotations__, {'name': str, 'id': int}) + annos = {'name': str, 'id': int} + self.assertEqual(Emp.__annotations__, annos) + self.assertEqual(Emp.__annotate__(annotationlib.Format.VALUE), annos) + self.assertEqual(Emp.__annotate__(annotationlib.Format.FORWARDREF), annos) + self.assertEqual(Emp.__annotate__(annotationlib.Format.STRING), {'name': 'str', 'id': 'int'}) self.assertEqual(Emp.__total__, True) self.assertEqual(Emp.__required_keys__, {'name', 'id'}) self.assertIsInstance(Emp.__required_keys__, frozenset) @@ -8501,6 +8731,36 @@ class Child(Base1, Base2): self.assertEqual(Child.__required_keys__, frozenset(['a'])) self.assertEqual(Child.__optional_keys__, frozenset()) + def test_inheritance_pep563(self): + def _make_td(future, class_name, annos, base, extra_names=None): + lines = [] + if future: + lines.append('from __future__ import annotations') + lines.append('from typing import TypedDict') + lines.append(f'class {class_name}({base}):') + for name, anno in annos.items(): + lines.append(f' {name}: {anno}') + code = '\n'.join(lines) + ns = run_code(code, extra_names) + return ns[class_name] + + for base_future in (True, False): + for child_future in (True, False): + with self.subTest(base_future=base_future, child_future=child_future): + base = _make_td( + base_future, "Base", {"base": "int"}, "TypedDict" + ) + self.assertIsNotNone(base.__annotate__) + child = _make_td( + child_future, "Child", {"child": "int"}, "Base", {"Base": base} + ) + base_anno = ForwardRef("int", module="builtins", owner=base) if base_future else int + child_anno = ForwardRef("int", module="builtins", owner=child) if child_future else int + self.assertEqual(base.__annotations__, {'base': base_anno}) + self.assertEqual( + child.__annotations__, {'child': child_anno, 'base': base_anno} + ) + def test_required_notrequired_keys(self): self.assertEqual(NontotalMovie.__required_keys__, frozenset({"title"})) @@ -8649,14 +8909,14 @@ class NewGeneric[T](TypedDict): # The TypedDict constructor is not itself a TypedDict self.assertIs(is_typeddict(TypedDict), False) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_get_type_hints(self): self.assertEqual( get_type_hints(Bar), {'a': typing.Optional[int], 'b': int} ) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_get_type_hints_generic(self): self.assertEqual( get_type_hints(BarGeneric), @@ -8681,6 +8941,8 @@ class A[T](TypedDict): self.assertEqual(A.__bases__, (Generic, dict)) self.assertEqual(A.__orig_bases__, (TypedDict, Generic[T])) self.assertEqual(A.__mro__, (A, Generic, dict, object)) + self.assertEqual(A.__annotations__, {'a': T}) + self.assertEqual(A.__annotate__(annotationlib.Format.STRING), {'a': 'T'}) self.assertEqual(A.__parameters__, (T,)) self.assertEqual(A[str].__parameters__, ()) self.assertEqual(A[str].__args__, (str,)) @@ -8692,6 +8954,8 @@ class A(TypedDict, Generic[T]): self.assertEqual(A.__bases__, (Generic, dict)) self.assertEqual(A.__orig_bases__, (TypedDict, Generic[T])) self.assertEqual(A.__mro__, (A, Generic, dict, object)) + self.assertEqual(A.__annotations__, {'a': T}) + self.assertEqual(A.__annotate__(annotationlib.Format.STRING), {'a': 'T'}) self.assertEqual(A.__parameters__, (T,)) self.assertEqual(A[str].__parameters__, ()) self.assertEqual(A[str].__args__, (str,)) @@ -8702,6 +8966,8 @@ class A2(Generic[T], TypedDict): self.assertEqual(A2.__bases__, (Generic, dict)) self.assertEqual(A2.__orig_bases__, (Generic[T], TypedDict)) self.assertEqual(A2.__mro__, (A2, Generic, dict, object)) + self.assertEqual(A2.__annotations__, {'a': T}) + self.assertEqual(A2.__annotate__(annotationlib.Format.STRING), {'a': 'T'}) self.assertEqual(A2.__parameters__, (T,)) self.assertEqual(A2[str].__parameters__, ()) self.assertEqual(A2[str].__args__, (str,)) @@ -8712,6 +8978,8 @@ class B(A[KT], total=False): self.assertEqual(B.__bases__, (Generic, dict)) self.assertEqual(B.__orig_bases__, (A[KT],)) self.assertEqual(B.__mro__, (B, Generic, dict, object)) + self.assertEqual(B.__annotations__, {'a': T, 'b': KT}) + self.assertEqual(B.__annotate__(annotationlib.Format.STRING), {'a': 'T', 'b': 'KT'}) self.assertEqual(B.__parameters__, (KT,)) self.assertEqual(B.__total__, False) self.assertEqual(B.__optional_keys__, frozenset(['b'])) @@ -8736,6 +9004,11 @@ class C(B[int]): 'b': KT, 'c': int, }) + self.assertEqual(C.__annotate__(annotationlib.Format.STRING), { + 'a': 'T', + 'b': 'KT', + 'c': 'int', + }) with self.assertRaises(TypeError): C[str] @@ -8755,6 +9028,11 @@ class Point3D(Point2DGeneric[T], Generic[T, KT]): 'b': T, 'c': KT, }) + self.assertEqual(Point3D.__annotate__(annotationlib.Format.STRING), { + 'a': 'T', + 'b': 'T', + 'c': 'KT', + }) self.assertEqual(Point3D[int, str].__origin__, Point3D) with self.assertRaises(TypeError): @@ -8786,10 +9064,15 @@ class WithImplicitAny(B): 'b': KT, 'c': int, }) + self.assertEqual(WithImplicitAny.__annotate__(annotationlib.Format.STRING), { + 'a': 'T', + 'b': 'KT', + 'c': 'int', + }) with self.assertRaises(TypeError): WithImplicitAny[str] - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_non_generic_subscript(self): # For backward compatibility, subscription works # on arbitrary TypedDict types. @@ -8917,7 +9200,7 @@ class Child(Base): self.assertEqual(Child.__readonly_keys__, frozenset()) self.assertEqual(Child.__mutable_keys__, frozenset({'a'})) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_combine_qualifiers(self): class AllTheThings(TypedDict): a: Annotated[Required[ReadOnly[int]], "why not"] @@ -8944,6 +9227,54 @@ class AllTheThings(TypedDict): }, ) + def test_annotations(self): + # _type_check is applied + with self.assertRaisesRegex(TypeError, "Plain typing.Final is not valid as type argument"): + class X(TypedDict): + a: Final + + # _type_convert is applied + class Y(TypedDict): + a: None + b: "int" + fwdref = EqualToForwardRef('int', module=__name__) + self.assertEqual(Y.__annotations__, {'a': type(None), 'b': fwdref}) + self.assertEqual(Y.__annotate__(annotationlib.Format.FORWARDREF), {'a': type(None), 'b': fwdref}) + + # _type_check is also applied later + class Z(TypedDict): + a: undefined + + with self.assertRaises(NameError): + Z.__annotations__ + + undefined = Final + with self.assertRaisesRegex(TypeError, "Plain typing.Final is not valid as type argument"): + Z.__annotations__ + + undefined = None + self.assertEqual(Z.__annotations__, {'a': type(None)}) + + def test_deferred_evaluation(self): + class A(TypedDict): + x: NotRequired[undefined] + y: ReadOnly[undefined] + z: Required[undefined] + + self.assertEqual(A.__required_keys__, frozenset({'y', 'z'})) + self.assertEqual(A.__optional_keys__, frozenset({'x'})) + self.assertEqual(A.__readonly_keys__, frozenset({'y'})) + self.assertEqual(A.__mutable_keys__, frozenset({'x', 'z'})) + + with self.assertRaises(NameError): + A.__annotations__ + + self.assertEqual( + A.__annotate__(annotationlib.Format.STRING), + {'x': 'NotRequired[undefined]', 'y': 'ReadOnly[undefined]', + 'z': 'Required[undefined]'}, + ) + class RequiredTests(BaseTestCase): @@ -9111,7 +9442,7 @@ def test_repr(self): self.assertEqual(repr(Match[str]), 'typing.Match[str]') self.assertEqual(repr(Match[bytes]), 'typing.Match[bytes]') - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_cannot_subclass(self): with self.assertRaisesRegex( TypeError, @@ -9588,6 +9919,19 @@ class B(str): ... self.assertIs(type(field_c2.__metadata__[0]), float) self.assertIs(type(field_c3.__metadata__[0]), bool) + def test_forwardref_partial_evaluation(self): + # Test that Annotated partially evaluates if it contains a ForwardRef + # See: https://github.com/python/cpython/issues/137706 + def f(x: Annotated[undefined, '']): pass + + ann = annotationlib.get_annotations(f, format=annotationlib.Format.FORWARDREF) + + # Test that the attributes are retrievable from the partially evaluated annotation + x_ann = ann['x'] + self.assertIs(get_origin(x_ann), Annotated) + self.assertEqual(x_ann.__origin__, EqualToForwardRef('undefined', owner=f)) + self.assertEqual(x_ann.__metadata__, ('',)) + class TypeAliasTests(BaseTestCase): def test_canonical_usage_with_variable_annotation(self): @@ -10094,6 +10438,7 @@ def test_var_substitution(self): self.assertEqual(C[Concatenate[str, P2]], Concatenate[int, str, P2]) self.assertEqual(C[...], Concatenate[int, ...]) + class TypeGuardTests(BaseTestCase): def test_basics(self): TypeGuard[int] # OK @@ -10285,7 +10630,6 @@ def test_special_attrs(self): typing.ClassVar: 'ClassVar', typing.Concatenate: 'Concatenate', typing.Final: 'Final', - typing.ForwardRef: 'ForwardRef', typing.Literal: 'Literal', typing.NewType: 'NewType', typing.NoReturn: 'NoReturn', @@ -10295,9 +10639,8 @@ def test_special_attrs(self): typing.TypeGuard: 'TypeGuard', typing.TypeIs: 'TypeIs', typing.TypeVar: 'TypeVar', - typing.Union: 'Union', typing.Self: 'Self', - # Subscribed special forms + # Subscripted special forms typing.Annotated[Any, "Annotation"]: 'Annotated', typing.Annotated[int, 'Annotation']: 'Annotated', typing.ClassVar[Any]: 'ClassVar', @@ -10306,13 +10649,12 @@ def test_special_attrs(self): typing.Literal[Any]: 'Literal', typing.Literal[1, 2]: 'Literal', typing.Literal[True, 2]: 'Literal', - typing.Optional[Any]: 'Optional', + typing.Optional[Any]: 'Union', typing.TypeGuard[Any]: 'TypeGuard', typing.TypeIs[Any]: 'TypeIs', typing.Union[Any]: 'Any', typing.Union[int, float]: 'Union', # Incompatible special forms (tested in test_special_attrs2) - # - typing.ForwardRef('set[Any]') # - typing.NewType('TypeName', Any) # - typing.ParamSpec('SpecialAttrsP') # - typing.TypeVar('T') @@ -10326,23 +10668,14 @@ def test_special_attrs(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): s = pickle.dumps(cls, proto) loaded = pickle.loads(s) - self.assertIs(cls, loaded) + if isinstance(cls, Union): + self.assertEqual(cls, loaded) + else: + self.assertIs(cls, loaded) TypeName = typing.NewType('SpecialAttrsTests.TypeName', Any) def test_special_attrs2(self): - # Forward refs provide a different introspection API. __name__ and - # __qualname__ make little sense for forward refs as they can store - # complex typing expressions. - fr = typing.ForwardRef('set[Any]') - self.assertFalse(hasattr(fr, '__name__')) - self.assertFalse(hasattr(fr, '__qualname__')) - self.assertEqual(fr.__module__, 'typing') - # Forward refs are currently unpicklable. - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - with self.assertRaises(TypeError): - pickle.dumps(fr, proto) - self.assertEqual(SpecialAttrsTests.TypeName.__name__, 'TypeName') self.assertEqual( SpecialAttrsTests.TypeName.__qualname__, @@ -10363,7 +10696,7 @@ def test_special_attrs2(self): # to the variable name to which it is assigned". Thus, providing # __qualname__ is unnecessary. self.assertEqual(SpecialAttrsT.__name__, 'SpecialAttrsT') - self.assertFalse(hasattr(SpecialAttrsT, '__qualname__')) + self.assertNotHasAttr(SpecialAttrsT, '__qualname__') self.assertEqual(SpecialAttrsT.__module__, __name__) # Module-level type variables are picklable. for proto in range(pickle.HIGHEST_PROTOCOL + 1): @@ -10372,7 +10705,7 @@ def test_special_attrs2(self): self.assertIs(SpecialAttrsT, loaded) self.assertEqual(SpecialAttrsP.__name__, 'SpecialAttrsP') - self.assertFalse(hasattr(SpecialAttrsP, '__qualname__')) + self.assertNotHasAttr(SpecialAttrsP, '__qualname__') self.assertEqual(SpecialAttrsP.__module__, __name__) # Module-level ParamSpecs are picklable. for proto in range(pickle.HIGHEST_PROTOCOL + 1): @@ -10521,7 +10854,7 @@ def test_no_call(self): with self.assertRaises(TypeError): NoDefault() - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_no_attributes(self): with self.assertRaises(AttributeError): NoDefault.foo = 3 @@ -10597,7 +10930,7 @@ class TypeIterationTests(BaseTestCase): Annotated[T, ''], ) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_cannot_iterate(self): expected_error_regex = "object is not iterable" for test_type in self._UNITERABLE_TYPES: @@ -10615,6 +10948,37 @@ def test_is_not_instance_of_iterable(self): self.assertNotIsInstance(type_to_test, collections.abc.Iterable) +class UnionGenericAliasTests(BaseTestCase): + def test_constructor(self): + # Used e.g. in typer, pydantic + with self.assertWarns(DeprecationWarning): + inst = typing._UnionGenericAlias(typing.Union, (int, str)) + self.assertEqual(inst, int | str) + with self.assertWarns(DeprecationWarning): + # name is accepted but ignored + inst = typing._UnionGenericAlias(typing.Union, (int, None), name="Optional") + self.assertEqual(inst, int | None) + + def test_isinstance(self): + # Used e.g. in pydantic + with self.assertWarns(DeprecationWarning): + self.assertTrue(isinstance(Union[int, str], typing._UnionGenericAlias)) + with self.assertWarns(DeprecationWarning): + self.assertFalse(isinstance(int, typing._UnionGenericAlias)) + + def test_eq(self): + # type(t) == _UnionGenericAlias is used in vyos + with self.assertWarns(DeprecationWarning): + self.assertEqual(Union, typing._UnionGenericAlias) + with self.assertWarns(DeprecationWarning): + self.assertEqual(typing._UnionGenericAlias, typing._UnionGenericAlias) + with self.assertWarns(DeprecationWarning): + self.assertNotEqual(int, typing._UnionGenericAlias) + + def test_hashable(self): + self.assertEqual(hash(typing._UnionGenericAlias), hash(Union)) + + def load_tests(loader, tests, pattern): import doctest tests.addTests(doctest.DocTestSuite(typing)) diff --git a/Lib/typing.py b/Lib/typing.py index 77caac9eed1..92b78defd11 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -27,7 +27,7 @@ import operator import sys import types -from types import WrapperDescriptorType, MethodWrapperType, MethodDescriptorType, GenericAlias +from types import GenericAlias from _typing import ( _idfunc, @@ -38,6 +38,7 @@ ParamSpecKwargs, TypeAliasType, Generic, + Union, NoDefault, ) @@ -126,6 +127,7 @@ 'cast', 'clear_overloads', 'dataclass_transform', + 'evaluate_forward_ref', 'final', 'get_args', 'get_origin', @@ -160,7 +162,6 @@ 'Unpack', ] - class _LazyAnnotationLib: def __getattr__(self, attr): global _lazy_annotationlib @@ -176,7 +177,7 @@ def _type_convert(arg, module=None, *, allow_special_forms=False, owner=None): if arg is None: return type(None) if isinstance(arg, str): - return ForwardRef(arg, module=module, is_class=allow_special_forms) + return _make_forward_ref(arg, module=module, is_class=allow_special_forms, owner=owner) return arg @@ -250,21 +251,10 @@ def _type_repr(obj): typically enough to uniquely identify a type. For everything else, we fall back on repr(obj). """ - # When changing this function, don't forget about - # `_collections_abc._type_repr`, which does the same thing - # and must be consistent with this one. - if isinstance(obj, type): - if obj.__module__ == 'builtins': - return obj.__qualname__ - return f'{obj.__module__}.{obj.__qualname__}' - if obj is ...: - return '...' - if isinstance(obj, types.FunctionType): - return obj.__name__ if isinstance(obj, tuple): # Special case for `repr` of types with `ParamSpec`: return '[' + ', '.join(_type_repr(t) for t in obj) + ']' - return repr(obj) + return _lazy_annotationlib.type_repr(obj) def _collect_type_parameters(args, *, enforce_default_ordering: bool = True): @@ -366,41 +356,11 @@ def _deduplicate(params, *, unhashable_fallback=False): if not unhashable_fallback: raise # Happens for cases like `Annotated[dict, {'x': IntValidator()}]` - return _deduplicate_unhashable(params) - -def _deduplicate_unhashable(unhashable_params): - new_unhashable = [] - for t in unhashable_params: - if t not in new_unhashable: - new_unhashable.append(t) - return new_unhashable - -def _compare_args_orderless(first_args, second_args): - first_unhashable = _deduplicate_unhashable(first_args) - second_unhashable = _deduplicate_unhashable(second_args) - t = list(second_unhashable) - try: - for elem in first_unhashable: - t.remove(elem) - except ValueError: - return False - return not t - -def _remove_dups_flatten(parameters): - """Internal helper for Union creation and substitution. - - Flatten Unions among parameters, then remove duplicates. - """ - # Flatten out Union[Union[...], ...]. - params = [] - for p in parameters: - if isinstance(p, (_UnionGenericAlias, types.UnionType)): - params.extend(p.__args__) - else: - params.append(p) - - return tuple(_deduplicate(params, unhashable_fallback=True)) - + new_unhashable = [] + for t in params: + if t not in new_unhashable: + new_unhashable.append(t) + return new_unhashable def _flatten_literal_params(parameters): """Internal helper for Literal creation: flatten Literals among parameters.""" @@ -470,7 +430,8 @@ def __repr__(self): _sentinel = _Sentinel() -def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=frozenset()): +def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=frozenset(), + format=None, owner=None, parent_fwdref=None, prefer_fwd_module=False): """Evaluate all forward references in the given type t. For use of globalns and localns see the docstring for get_type_hints(). @@ -480,12 +441,30 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f if type_params is _sentinel: _deprecation_warning_for_no_type_params_passed("typing._eval_type") type_params = () - if isinstance(t, ForwardRef): - return t._evaluate(globalns, localns, type_params, recursive_guard=recursive_guard) - if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)): + if isinstance(t, _lazy_annotationlib.ForwardRef): + # If the forward_ref has __forward_module__ set, evaluate() infers the globals + # from the module, and it will probably pick better than the globals we have here. + # We do this only for calls from get_type_hints() (which opts in through the + # prefer_fwd_module flag), so that the default behavior remains more straightforward. + if prefer_fwd_module and t.__forward_module__ is not None: + globalns = None + # If there are type params on the owner, we need to add them back, because + # annotationlib won't. + if owner_type_params := getattr(owner, "__type_params__", None): + globalns = getattr( + sys.modules.get(t.__forward_module__, None), "__dict__", None + ) + if globalns is not None: + globalns = dict(globalns) + for type_param in owner_type_params: + globalns[type_param.__name__] = type_param + return evaluate_forward_ref(t, globals=globalns, locals=localns, + type_params=type_params, owner=owner, + _recursive_guard=recursive_guard, format=format) + if isinstance(t, (_GenericAlias, GenericAlias, Union)): if isinstance(t, GenericAlias): args = tuple( - ForwardRef(arg) if isinstance(arg, str) else arg + _make_forward_ref(arg, parent_fwdref=parent_fwdref) if isinstance(arg, str) else arg for arg in t.__args__ ) is_unpacked = t.__unpacked__ @@ -498,7 +477,8 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f ev_args = tuple( _eval_type( - a, globalns, localns, type_params, recursive_guard=recursive_guard + a, globalns, localns, type_params, recursive_guard=recursive_guard, + format=format, owner=owner, prefer_fwd_module=prefer_fwd_module, ) for a in t.__args__ ) @@ -506,7 +486,7 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f return t if isinstance(t, GenericAlias): return GenericAlias(t.__origin__, ev_args) - if isinstance(t, types.UnionType): + if isinstance(t, Union): return functools.reduce(operator.or_, ev_args) else: return t.copy_with(ev_args) @@ -760,59 +740,6 @@ class FastConnector(Connection): item = _type_check(parameters, f'{self} accepts only single type.', allow_special_forms=True) return _GenericAlias(self, (item,)) -@_SpecialForm -def Union(self, parameters): - """Union type; Union[X, Y] means either X or Y. - - On Python 3.10 and higher, the | operator - can also be used to denote unions; - X | Y means the same thing to the type checker as Union[X, Y]. - - To define a union, use e.g. Union[int, str]. Details: - - The arguments must be types and there must be at least one. - - None as an argument is a special case and is replaced by - type(None). - - Unions of unions are flattened, e.g.:: - - assert Union[Union[int, str], float] == Union[int, str, float] - - - Unions of a single argument vanish, e.g.:: - - assert Union[int] == int # The constructor actually returns int - - - Redundant arguments are skipped, e.g.:: - - assert Union[int, str, int] == Union[int, str] - - - When comparing unions, the argument order is ignored, e.g.:: - - assert Union[int, str] == Union[str, int] - - - You cannot subclass or instantiate a union. - - You can use Optional[X] as a shorthand for Union[X, None]. - """ - if parameters == (): - raise TypeError("Cannot take a Union of no types.") - if not isinstance(parameters, tuple): - parameters = (parameters,) - msg = "Union[arg, ...]: each arg must be a type." - parameters = tuple(_type_check(p, msg) for p in parameters) - parameters = _remove_dups_flatten(parameters) - if len(parameters) == 1: - return parameters[0] - if len(parameters) == 2 and type(None) in parameters: - return _UnionGenericAlias(self, parameters, name="Optional") - return _UnionGenericAlias(self, parameters) - -def _make_union(left, right): - """Used from the C implementation of TypeVar. - - TypeVar.__or__ calls this instead of returning types.UnionType - because we want to allow unions between TypeVars and strings - (forward references). - """ - return Union[left, right] - @_SpecialForm def Optional(self, parameters): """Optional[X] is equivalent to Union[X, None].""" @@ -1022,116 +949,85 @@ def run(arg: Child | Unrelated): return _GenericAlias(self, (item,)) -class ForwardRef(_Final, _root=True): - """Internal wrapper to hold a forward reference.""" - - __slots__ = ('__forward_arg__', '__forward_code__', - '__forward_evaluated__', '__forward_value__', - '__forward_is_argument__', '__forward_is_class__', - '__forward_module__') - - def __init__(self, arg, is_argument=True, module=None, *, is_class=False): - if not isinstance(arg, str): - raise TypeError(f"Forward reference must be a string -- got {arg!r}") +def _make_forward_ref(code, *, parent_fwdref=None, **kwargs): + if parent_fwdref is not None: + if parent_fwdref.__forward_module__ is not None: + kwargs['module'] = parent_fwdref.__forward_module__ + if parent_fwdref.__owner__ is not None: + kwargs['owner'] = parent_fwdref.__owner__ + forward_ref = _lazy_annotationlib.ForwardRef(code, **kwargs) + # For compatibility, eagerly compile the forwardref's code. + forward_ref.__forward_code__ + return forward_ref - # If we do `def f(*args: *Ts)`, then we'll have `arg = '*Ts'`. - # Unfortunately, this isn't a valid expression on its own, so we - # do the unpacking manually. - if arg.startswith('*'): - arg_to_compile = f'({arg},)[0]' # E.g. (*Ts,)[0] or (*tuple[int, int],)[0] - else: - arg_to_compile = arg - try: - code = compile(arg_to_compile, '', 'eval') - except SyntaxError: - raise SyntaxError(f"Forward reference must be an expression -- got {arg!r}") - - self.__forward_arg__ = arg - self.__forward_code__ = code - self.__forward_evaluated__ = False - self.__forward_value__ = None - self.__forward_is_argument__ = is_argument - self.__forward_is_class__ = is_class - self.__forward_module__ = module - - def _evaluate(self, globalns, localns, type_params=_sentinel, *, recursive_guard): - if type_params is _sentinel: - _deprecation_warning_for_no_type_params_passed("typing.ForwardRef._evaluate") - type_params = () - if self.__forward_arg__ in recursive_guard: - return self - if not self.__forward_evaluated__ or localns is not globalns: - if globalns is None and localns is None: - globalns = localns = {} - elif globalns is None: - globalns = localns - elif localns is None: - localns = globalns - if self.__forward_module__ is not None: - globalns = getattr( - sys.modules.get(self.__forward_module__, None), '__dict__', globalns - ) - - # type parameters require some special handling, - # as they exist in their own scope - # but `eval()` does not have a dedicated parameter for that scope. - # For classes, names in type parameter scopes should override - # names in the global scope (which here are called `localns`!), - # but should in turn be overridden by names in the class scope - # (which here are called `globalns`!) - if type_params: - globalns, localns = dict(globalns), dict(localns) - for param in type_params: - param_name = param.__name__ - if not self.__forward_is_class__ or param_name not in globalns: - globalns[param_name] = param - localns.pop(param_name, None) - - type_ = _type_check( - eval(self.__forward_code__, globalns, localns), - "Forward references must evaluate to types.", - is_argument=self.__forward_is_argument__, - allow_special_forms=self.__forward_is_class__, - ) - self.__forward_value__ = _eval_type( - type_, - globalns, - localns, - type_params, - recursive_guard=(recursive_guard | {self.__forward_arg__}), - ) - self.__forward_evaluated__ = True - return self.__forward_value__ - def __eq__(self, other): - if not isinstance(other, ForwardRef): - return NotImplemented - if self.__forward_evaluated__ and other.__forward_evaluated__: - return (self.__forward_arg__ == other.__forward_arg__ and - self.__forward_value__ == other.__forward_value__) - return (self.__forward_arg__ == other.__forward_arg__ and - self.__forward_module__ == other.__forward_module__) - - def __hash__(self): - return hash((self.__forward_arg__, self.__forward_module__)) - - def __or__(self, other): - return Union[self, other] - - def __ror__(self, other): - return Union[other, self] +def evaluate_forward_ref( + forward_ref, + *, + owner=None, + globals=None, + locals=None, + type_params=None, + format=None, + _recursive_guard=frozenset(), +): + """Evaluate a forward reference as a type hint. + + This is similar to calling the ForwardRef.evaluate() method, + but unlike that method, evaluate_forward_ref() also + recursively evaluates forward references nested within the type hint. + + *forward_ref* must be an instance of ForwardRef. *owner*, if given, + should be the object that holds the annotations that the forward reference + derived from, such as a module, class object, or function. It is used to + infer the namespaces to use for looking up names. *globals* and *locals* + can also be explicitly given to provide the global and local namespaces. + *type_params* is a tuple of type parameters that are in scope when + evaluating the forward reference. This parameter should be provided (though + it may be an empty tuple) if *owner* is not given and the forward reference + does not already have an owner set. *format* specifies the format of the + annotation and is a member of the annotationlib.Format enum, defaulting to + VALUE. - def __repr__(self): - if self.__forward_module__ is None: - module_repr = '' - else: - module_repr = f', module={self.__forward_module__!r}' - return f'ForwardRef({self.__forward_arg__!r}{module_repr})' + """ + if format == _lazy_annotationlib.Format.STRING: + return forward_ref.__forward_arg__ + if forward_ref.__forward_arg__ in _recursive_guard: + return forward_ref + + if format is None: + format = _lazy_annotationlib.Format.VALUE + value = forward_ref.evaluate(globals=globals, locals=locals, + type_params=type_params, owner=owner, format=format) + + if (isinstance(value, _lazy_annotationlib.ForwardRef) + and format == _lazy_annotationlib.Format.FORWARDREF): + return value + + if isinstance(value, str): + value = _make_forward_ref(value, module=forward_ref.__forward_module__, + owner=owner or forward_ref.__owner__, + is_argument=forward_ref.__forward_is_argument__, + is_class=forward_ref.__forward_is_class__) + if owner is None: + owner = forward_ref.__owner__ + return _eval_type( + value, + globals, + locals, + type_params, + recursive_guard=_recursive_guard | {forward_ref.__forward_arg__}, + format=format, + owner=owner, + parent_fwdref=forward_ref, + ) def _is_unpacked_typevartuple(x: Any) -> bool: + # Need to check 'is True' here + # See: https://github.com/python/cpython/issues/137706 return ((not isinstance(x, type)) and - getattr(x, '__typing_is_unpacked_typevartuple__', False)) + getattr(x, '__typing_is_unpacked_typevartuple__', False) is True) def _is_typevar_like(x: Any) -> bool: @@ -1201,7 +1097,7 @@ def _paramspec_prepare_subst(self, alias, args): params = alias.__parameters__ i = params.index(self) if i == len(args) and self.has_default(): - args = [*args, self.__default__] + args = (*args, self.__default__) if i >= len(args): raise TypeError(f"Too few arguments for {alias}") # Special case where Z[[int, str, bool]] == Z[int, str, bool] in PEP 612. @@ -1246,14 +1142,26 @@ def _generic_class_getitem(cls, args): f"Parameters to {cls.__name__}[...] must all be unique") else: # Subscripting a regular Generic subclass. - for param in cls.__parameters__: + try: + parameters = cls.__parameters__ + except AttributeError as e: + init_subclass = getattr(cls, '__init_subclass__', None) + if init_subclass not in {None, Generic.__init_subclass__}: + e.add_note( + f"Note: this exception may have been caused by " + f"{init_subclass.__qualname__!r} (or the " + f"'__init_subclass__' method on a superclass) not " + f"calling 'super().__init_subclass__()'" + ) + raise + for param in parameters: prepare = getattr(param, '__typing_prepare_subst__', None) if prepare is not None: args = prepare(cls, args) _check_generic_specialization(cls, args) new_args = [] - for param, new_arg in zip(cls.__parameters__, args): + for param, new_arg in zip(parameters, args): if isinstance(param, TypeVarTuple): new_args.extend(new_arg) else: @@ -1768,45 +1676,41 @@ def __getitem__(self, params): return self.copy_with(params) -class _UnionGenericAlias(_NotIterable, _GenericAlias, _root=True): - def copy_with(self, params): - return Union[params] +class _UnionGenericAliasMeta(type): + def __instancecheck__(self, inst: object) -> bool: + import warnings + warnings._deprecated("_UnionGenericAlias", remove=(3, 17)) + return isinstance(inst, Union) + + def __subclasscheck__(self, inst: type) -> bool: + import warnings + warnings._deprecated("_UnionGenericAlias", remove=(3, 17)) + return issubclass(inst, Union) def __eq__(self, other): - if not isinstance(other, (_UnionGenericAlias, types.UnionType)): - return NotImplemented - try: # fast path - return set(self.__args__) == set(other.__args__) - except TypeError: # not hashable, slow path - return _compare_args_orderless(self.__args__, other.__args__) + import warnings + warnings._deprecated("_UnionGenericAlias", remove=(3, 17)) + if other is _UnionGenericAlias or other is Union: + return True + return NotImplemented def __hash__(self): - return hash(frozenset(self.__args__)) + return hash(Union) - def __repr__(self): - args = self.__args__ - if len(args) == 2: - if args[0] is type(None): - return f'typing.Optional[{_type_repr(args[1])}]' - elif args[1] is type(None): - return f'typing.Optional[{_type_repr(args[0])}]' - return super().__repr__() - def __instancecheck__(self, obj): - for arg in self.__args__: - if isinstance(obj, arg): - return True - return False +class _UnionGenericAlias(metaclass=_UnionGenericAliasMeta): + """Compatibility hack. - def __subclasscheck__(self, cls): - for arg in self.__args__: - if issubclass(cls, arg): - return True - return False + A class named _UnionGenericAlias used to be used to implement + typing.Union. This class exists to serve as a shim to preserve + the meaning of some code that used to use _UnionGenericAlias + directly. - def __reduce__(self): - func, (origin, args) = super().__reduce__() - return func, (Union, args) + """ + def __new__(cls, self_cls, parameters, /, *, name=None): + import warnings + warnings._deprecated("_UnionGenericAlias", remove=(3, 17)) + return Union[parameters] def _value_and_type_iter(parameters): @@ -1945,7 +1849,13 @@ def _get_protocol_attrs(cls): for base in cls.__mro__[:-1]: # without object if base.__name__ in {'Protocol', 'Generic'}: continue - annotations = getattr(base, '__annotations__', {}) + try: + annotations = base.__annotations__ + except Exception: + # Only go through annotationlib to handle deferred annotations if we need to + annotations = _lazy_annotationlib.get_annotations( + base, format=_lazy_annotationlib.Format.FORWARDREF + ) for attr in (*base.__dict__, *annotations): if not attr.startswith('_abc_') and attr not in EXCLUDED_ATTRIBUTES: attrs.add(attr) @@ -1998,8 +1908,7 @@ def _allow_reckless_class_checks(depth=2): The abc and functools modules indiscriminately call isinstance() and issubclass() on the whole MRO of a user class, which may contain protocols. """ - # XXX: RUSTPYTHON; https://github.com/python/cpython/pull/136115 - return _caller(depth) in {'abc', '_py_abc', 'functools', None} + return _caller(depth) in {'abc', 'functools', None} _PROTO_ALLOWLIST = { @@ -2009,6 +1918,8 @@ def _allow_reckless_class_checks(depth=2): 'Reversible', 'Buffer', ], 'contextlib': ['AbstractContextManager', 'AbstractAsyncContextManager'], + 'io': ['Reader', 'Writer'], + 'os': ['PathLike'], } @@ -2161,11 +2072,17 @@ def _proto_hook(cls, other): break # ...or in annotations, if it is a sub-protocol. - annotations = getattr(base, '__annotations__', {}) - if (isinstance(annotations, collections.abc.Mapping) and - attr in annotations and - issubclass(other, Generic) and getattr(other, '_is_protocol', False)): - break + if issubclass(other, Generic) and getattr(other, "_is_protocol", False): + # We avoid the slower path through annotationlib here because in most + # cases it should be unnecessary. + try: + annos = base.__annotations__ + except Exception: + annos = _lazy_annotationlib.get_annotations( + base, format=_lazy_annotationlib.Format.FORWARDREF + ) + if attr in annos: + break else: return NotImplemented return True @@ -2228,7 +2145,7 @@ class _AnnotatedAlias(_NotIterable, _GenericAlias, _root=True): """Runtime representation of an annotated type. At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't' - with extra annotations. The alias behaves like a normal typing alias. + with extra metadata. The alias behaves like a normal typing alias. Instantiating is the same as instantiating the underlying type; binding it to types is also the same. @@ -2407,12 +2324,8 @@ def greet(name: str) -> None: return val -_allowed_types = (types.FunctionType, types.BuiltinFunctionType, - types.MethodType, types.ModuleType, - WrapperDescriptorType, MethodWrapperType, MethodDescriptorType) - - -def get_type_hints(obj, globalns=None, localns=None, include_extras=False): +def get_type_hints(obj, globalns=None, localns=None, include_extras=False, + *, format=None): """Return type hints for an object. This is often the same as obj.__annotations__, but it handles @@ -2445,17 +2358,21 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): """ if getattr(obj, '__no_type_check__', None): return {} + Format = _lazy_annotationlib.Format + if format is None: + format = Format.VALUE # Classes require a special treatment. if isinstance(obj, type): hints = {} for base in reversed(obj.__mro__): + ann = _lazy_annotationlib.get_annotations(base, format=format) + if format == Format.STRING: + hints.update(ann) + continue if globalns is None: base_globals = getattr(sys.modules.get(base.__module__, None), '__dict__', {}) else: base_globals = globalns - ann = _lazy_annotationlib.get_annotations(base) - if isinstance(ann, types.GetSetDescriptorType): - ann = {} base_locals = dict(vars(base)) if localns is None else localns if localns is None and globalns is None: # This is surprising, but required. Before Python 3.10, @@ -2465,14 +2382,33 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): # *base_globals* first rather than *base_locals*. # This only affects ForwardRefs. base_globals, base_locals = base_locals, base_globals + type_params = base.__type_params__ + base_globals, base_locals = _add_type_params_to_scope( + type_params, base_globals, base_locals, True) for name, value in ann.items(): + if isinstance(value, str): + value = _make_forward_ref(value, is_argument=False, is_class=True) + value = _eval_type(value, base_globals, base_locals, (), + format=format, owner=obj, prefer_fwd_module=True) if value is None: value = type(None) - if isinstance(value, str): - value = ForwardRef(value, is_argument=False, is_class=True) - value = _eval_type(value, base_globals, base_locals, base.__type_params__) hints[name] = value - return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()} + if include_extras or format == Format.STRING: + return hints + else: + return {k: _strip_annotations(t) for k, t in hints.items()} + + hints = _lazy_annotationlib.get_annotations(obj, format=format) + if ( + not hints + and not isinstance(obj, types.ModuleType) + and not callable(obj) + and not hasattr(obj, '__annotations__') + and not hasattr(obj, '__annotate__') + ): + raise TypeError(f"{obj!r} is not a module, class, or callable.") + if format == Format.STRING: + return hints if globalns is None: if isinstance(obj, types.ModuleType): @@ -2487,34 +2423,38 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): localns = globalns elif localns is None: localns = globalns - try: - hints = _lazy_annotationlib.get_annotations(obj) - except TypeError: - hints = getattr(obj, '__annotations__', None) - if hints is None: - # Return empty annotations for something that _could_ have them. - if isinstance(obj, _allowed_types): - return {} - else: - raise TypeError('{!r} is not a module, class, method, ' - 'or function.'.format(obj)) - hints = dict(hints) type_params = getattr(obj, "__type_params__", ()) + globalns, localns = _add_type_params_to_scope(type_params, globalns, localns, False) for name, value in hints.items(): - if value is None: - value = type(None) if isinstance(value, str): # class-level forward refs were handled above, this must be either # a module-level annotation or a function argument annotation - value = ForwardRef( + value = _make_forward_ref( value, is_argument=not isinstance(obj, types.ModuleType), is_class=False, ) - hints[name] = _eval_type(value, globalns, localns, type_params) + value = _eval_type(value, globalns, localns, (), format=format, owner=obj, prefer_fwd_module=True) + if value is None: + value = type(None) + hints[name] = value return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()} +# Add type parameters to the globals and locals scope. This is needed for +# compatibility. +def _add_type_params_to_scope(type_params, globalns, localns, is_class): + if not type_params: + return globalns, localns + globalns = dict(globalns) + localns = dict(localns) + for param in type_params: + if not is_class or param.__name__ not in globalns: + globalns[param.__name__] = param + localns.pop(param.__name__, None) + return globalns, localns + + def _strip_annotations(t): """Strip the annotations from a given type.""" if isinstance(t, _AnnotatedAlias): @@ -2531,7 +2471,7 @@ def _strip_annotations(t): if stripped_args == t.__args__: return t return GenericAlias(t.__origin__, stripped_args) - if isinstance(t, types.UnionType): + if isinstance(t, Union): stripped_args = tuple(_strip_annotations(a) for a in t.__args__) if stripped_args == t.__args__: return t @@ -2565,8 +2505,8 @@ def get_origin(tp): return tp.__origin__ if tp is Generic: return Generic - if isinstance(tp, types.UnionType): - return types.UnionType + if isinstance(tp, Union): + return Union return None @@ -2591,7 +2531,7 @@ def get_args(tp): if _should_unflatten_callable_args(tp, res): res = (list(res[:-1]), res[-1]) return res - if isinstance(tp, types.UnionType): + if isinstance(tp, Union): return tp.__args__ return () @@ -2858,7 +2798,7 @@ class Other(Leaf): # Error reported by type checker Sequence = _alias(collections.abc.Sequence, 1) MutableSequence = _alias(collections.abc.MutableSequence, 1) ByteString = _DeprecatedGenericAlias( - collections.abc.ByteString, 0, removal_version=(3, 14) # Not generic. + collections.abc.ByteString, 0, removal_version=(3, 17) # Not generic. ) # Tuple accepts variable number of parameters. Tuple = _TupleType(tuple, -1, inst=False, name='Tuple') @@ -2991,16 +2931,27 @@ def __round__(self, ndigits: int = 0) -> T: pass -def _make_nmtuple(name, types, module, defaults = ()): - fields = [n for n, t in types] - types = {n: _type_check(t, f"field {n} annotation must be a type") - for n, t in types} +def _make_nmtuple(name, fields, annotate_func, module, defaults = ()): nm_tpl = collections.namedtuple(name, fields, defaults=defaults, module=module) - nm_tpl.__annotations__ = nm_tpl.__new__.__annotations__ = types + nm_tpl.__annotate__ = nm_tpl.__new__.__annotate__ = annotate_func return nm_tpl +def _make_eager_annotate(types): + checked_types = {key: _type_check(val, f"field {key} annotation must be a type") + for key, val in types.items()} + def annotate(format): + match format: + case _lazy_annotationlib.Format.VALUE | _lazy_annotationlib.Format.FORWARDREF: + return checked_types + case _lazy_annotationlib.Format.STRING: + return _lazy_annotationlib.annotations_to_string(types) + case _: + raise NotImplementedError(format) + return annotate + + # attributes prohibited to set in NamedTuple class syntax _prohibited = frozenset({'__new__', '__init__', '__slots__', '__getnewargs__', '_fields', '_field_defaults', @@ -3013,6 +2964,9 @@ def _make_nmtuple(name, types, module, defaults = ()): class NamedTupleMeta(type): def __new__(cls, typename, bases, ns): assert _NamedTuple in bases + if "__classcell__" in ns: + raise TypeError( + "uses of super() and __class__ are unsupported in methods of NamedTuple subclasses") for base in bases: if base is not _NamedTuple and base is not Generic: raise TypeError( @@ -3020,13 +2974,30 @@ def __new__(cls, typename, bases, ns): bases = tuple(tuple if base is _NamedTuple else base for base in bases) if "__annotations__" in ns: types = ns["__annotations__"] - elif (annotate := _lazy_annotationlib.get_annotate_from_class_namespace(ns)) is not None: + field_names = list(types) + annotate = _make_eager_annotate(types) + elif (original_annotate := _lazy_annotationlib.get_annotate_from_class_namespace(ns)) is not None: types = _lazy_annotationlib.call_annotate_function( - annotate, _lazy_annotationlib.Format.VALUE) + original_annotate, _lazy_annotationlib.Format.FORWARDREF) + field_names = list(types) + + # For backward compatibility, type-check all the types at creation time + for typ in types.values(): + _type_check(typ, "field annotation must be a type") + + def annotate(format): + annos = _lazy_annotationlib.call_annotate_function( + original_annotate, format) + if format != _lazy_annotationlib.Format.STRING: + return {key: _type_check(val, f"field {key} annotation must be a type") + for key, val in annos.items()} + return annos else: - types = {} + # Empty NamedTuple + field_names = [] + annotate = lambda format: {} default_names = [] - for field_name in types: + for field_name in field_names: if field_name in ns: default_names.append(field_name) elif default_names: @@ -3034,7 +3005,7 @@ def __new__(cls, typename, bases, ns): f"cannot follow default field" f"{'s' if len(default_names) > 1 else ''} " f"{', '.join(default_names)}") - nm_tpl = _make_nmtuple(typename, types.items(), + nm_tpl = _make_nmtuple(typename, field_names, annotate, defaults=[ns[n] for n in default_names], module=ns['__module__']) nm_tpl.__bases__ = bases @@ -3125,7 +3096,11 @@ class Employee(NamedTuple): import warnings warnings._deprecated(deprecated_thing, message=deprecation_msg, remove=(3, 15)) fields = kwargs.items() - nt = _make_nmtuple(typename, fields, module=_caller()) + types = {n: _type_check(t, f"field {n} annotation must be a type") + for n, t in fields} + field_names = [n for n, _ in fields] + + nt = _make_nmtuple(typename, field_names, _make_eager_annotate(types), module=_caller()) nt.__orig_bases__ = (NamedTuple,) return nt @@ -3574,7 +3549,7 @@ def readline(self, limit: int = -1) -> AnyStr: pass @abstractmethod - def readlines(self, hint: int = -1) -> List[AnyStr]: + def readlines(self, hint: int = -1) -> list[AnyStr]: pass @abstractmethod @@ -3590,7 +3565,7 @@ def tell(self) -> int: pass @abstractmethod - def truncate(self, size: int = None) -> int: + def truncate(self, size: int | None = None) -> int: pass @abstractmethod @@ -3602,11 +3577,11 @@ def write(self, s: AnyStr) -> int: pass @abstractmethod - def writelines(self, lines: List[AnyStr]) -> None: + def writelines(self, lines: list[AnyStr]) -> None: pass @abstractmethod - def __enter__(self) -> 'IO[AnyStr]': + def __enter__(self) -> IO[AnyStr]: pass @abstractmethod @@ -3620,11 +3595,11 @@ class BinaryIO(IO[bytes]): __slots__ = () @abstractmethod - def write(self, s: Union[bytes, bytearray]) -> int: + def write(self, s: bytes | bytearray) -> int: pass @abstractmethod - def __enter__(self) -> 'BinaryIO': + def __enter__(self) -> BinaryIO: pass @@ -3645,7 +3620,7 @@ def encoding(self) -> str: @property @abstractmethod - def errors(self) -> Optional[str]: + def errors(self) -> str | None: pass @property @@ -3659,7 +3634,7 @@ def newlines(self) -> Any: pass @abstractmethod - def __enter__(self) -> 'TextIO': + def __enter__(self) -> TextIO: pass @@ -3855,7 +3830,9 @@ def __getattr__(attr): Soft-deprecated objects which are costly to create are only created on-demand here. """ - if attr in {"Pattern", "Match"}: + if attr == "ForwardRef": + obj = _lazy_annotationlib.ForwardRef + elif attr in {"Pattern", "Match"}: import re obj = _alias(getattr(re, attr), 1) elif attr in {"ContextManager", "AsyncContextManager"}: diff --git a/crates/vm/src/builtins/genericalias.rs b/crates/vm/src/builtins/genericalias.rs index e9150e4c088..21034e08f0e 100644 --- a/crates/vm/src/builtins/genericalias.rs +++ b/crates/vm/src/builtins/genericalias.rs @@ -235,11 +235,11 @@ impl PyGenericAlias { Err(vm.new_type_error("issubclass() argument 2 cannot be a parameterized generic")) } - fn __ror__(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + fn __ror__(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { type_::or_(other, zelf, vm) } - fn __or__(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + fn __or__(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { type_::or_(zelf, other, vm) } } @@ -509,7 +509,7 @@ impl AsMapping for PyGenericAlias { impl AsNumber for PyGenericAlias { fn as_number() -> &'static PyNumberMethods { static AS_NUMBER: PyNumberMethods = PyNumberMethods { - or: Some(|a, b, vm| Ok(PyGenericAlias::__or__(a.to_owned(), b.to_owned(), vm))), + or: Some(|a, b, vm| PyGenericAlias::__or__(a.to_owned(), b.to_owned(), vm)), ..PyNumberMethods::NOT_IMPLEMENTED }; &AS_NUMBER diff --git a/crates/vm/src/builtins/mappingproxy.rs b/crates/vm/src/builtins/mappingproxy.rs index f7fb64fa6ab..1036dcfdaf9 100644 --- a/crates/vm/src/builtins/mappingproxy.rs +++ b/crates/vm/src/builtins/mappingproxy.rs @@ -3,13 +3,14 @@ use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func, class::PyClassImpl, + common::hash, convert::ToPyObject, function::{ArgMapping, OptionalArg, PyComparisonValue}, object::{Traverse, TraverseFn}, protocol::{PyMappingMethods, PyNumberMethods, PySequenceMethods}, types::{ - AsMapping, AsNumber, AsSequence, Comparable, Constructor, Iterable, PyComparisonOp, - Representable, + AsMapping, AsNumber, AsSequence, Comparable, Constructor, Hashable, Iterable, + PyComparisonOp, Representable, }, }; use std::sync::LazyLock; @@ -83,6 +84,7 @@ impl Constructor for PyMappingProxy { Constructor, AsSequence, Comparable, + Hashable, AsNumber, Representable ))] @@ -215,6 +217,15 @@ impl Comparable for PyMappingProxy { } } +impl Hashable for PyMappingProxy { + #[inline] + fn hash(zelf: &Py, vm: &VirtualMachine) -> PyResult { + // Delegate hash to the underlying mapping + let obj = zelf.to_object(vm)?; + obj.hash(vm) + } +} + impl AsMapping for PyMappingProxy { fn as_mapping() -> &'static PyMappingMethods { static AS_MAPPING: LazyLock = LazyLock::new(|| PyMappingMethods { diff --git a/crates/vm/src/builtins/mod.rs b/crates/vm/src/builtins/mod.rs index c40f09bdef0..e9ca1f8b403 100644 --- a/crates/vm/src/builtins/mod.rs +++ b/crates/vm/src/builtins/mod.rs @@ -92,7 +92,7 @@ pub(crate) mod zip; pub use zip::PyZip; #[path = "union.rs"] pub(crate) mod union_; -pub use union_::PyUnion; +pub use union_::{PyUnion, make_union}; pub(crate) mod descriptor; pub use float::try_to_bigint as try_f64_to_bigint; diff --git a/crates/vm/src/builtins/namespace.rs b/crates/vm/src/builtins/namespace.rs index 03969c35e7b..2cc1693302a 100644 --- a/crates/vm/src/builtins/namespace.rs +++ b/crates/vm/src/builtins/namespace.rs @@ -42,15 +42,76 @@ impl PyNamespace { ); result.into_pytuple(vm) } + + #[pymethod] + fn __replace__(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + if !args.args.is_empty() { + return Err(vm.new_type_error("__replace__() takes no positional arguments")); + } + + // Create a new instance of the same type + let cls: PyObjectRef = zelf.class().to_owned().into(); + let result = cls.call((), vm)?; + + // Copy the current namespace dict to the new instance + let src_dict = zelf.dict().unwrap(); + let dst_dict = result.dict().unwrap(); + for (key, value) in src_dict { + dst_dict.set_item(&*key, value, vm)?; + } + + // Update with the provided kwargs + for (name, value) in args.kwargs { + let name = vm.ctx.new_str(name); + result.set_attr(&name, value, vm)?; + } + + Ok(result) + } } impl Initializer for PyNamespace { type Args = FuncArgs; fn init(zelf: PyRef, args: Self::Args, vm: &VirtualMachine) -> PyResult<()> { - if !args.args.is_empty() { - return Err(vm.new_type_error("no positional arguments expected")); + // SimpleNamespace accepts 0 or 1 positional argument (a mapping) + if args.args.len() > 1 { + return Err(vm.new_type_error(format!( + "{} expected at most 1 positional argument, got {}", + zelf.class().name(), + args.args.len() + ))); } + + // If there's a positional argument, treat it as a mapping + if let Some(mapping) = args.args.first() { + // Convert to dict if not already + let dict: PyRef = if let Some(d) = mapping.downcast_ref::() { + d.to_owned() + } else { + // Call dict() on the mapping + let dict_type: PyObjectRef = vm.ctx.types.dict_type.to_owned().into(); + dict_type + .call((mapping.clone(),), vm)? + .downcast() + .map_err(|_| vm.new_type_error("dict() did not return a dict"))? + }; + + // Validate keys are strings and set attributes + for (key, value) in dict.into_iter() { + let key_str = key + .downcast_ref::() + .ok_or_else(|| { + vm.new_type_error(format!( + "keywords must be strings, not '{}'", + key.class().name() + )) + })?; + zelf.as_object().set_attr(key_str, value, vm)?; + } + } + + // Apply keyword arguments (these override positional mapping values) for (name, value) in args.kwargs { let name = vm.ctx.new_str(name); zelf.as_object().set_attr(&name, value, vm)?; diff --git a/crates/vm/src/builtins/type.rs b/crates/vm/src/builtins/type.rs index 829f7d4439a..510c3fb8491 100644 --- a/crates/vm/src/builtins/type.rs +++ b/crates/vm/src/builtins/type.rs @@ -20,7 +20,6 @@ use crate::{ borrow::BorrowedValue, lock::{PyRwLock, PyRwLockReadGuard}, }, - convert::ToPyResult, function::{FuncArgs, KwArgs, OptionalArg, PyMethodDef, PySetterValue}, object::{Traverse, TraverseFn}, protocol::{PyIterReturn, PyNumberMethods}, @@ -1038,11 +1037,11 @@ impl PyType { ) } - pub fn __ror__(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + pub fn __ror__(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { or_(other, zelf, vm) } - pub fn __or__(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + pub fn __or__(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { or_(zelf, other, vm) } @@ -1850,7 +1849,7 @@ impl Callable for PyType { impl AsNumber for PyType { fn as_number() -> &'static PyNumberMethods { static AS_NUMBER: PyNumberMethods = PyNumberMethods { - or: Some(|a, b, vm| or_(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)), + or: Some(|a, b, vm| or_(a.to_owned(), b.to_owned(), vm)), ..PyNumberMethods::NOT_IMPLEMENTED }; &AS_NUMBER @@ -2013,9 +2012,9 @@ pub(crate) fn call_slot_new( slot_new(subtype, args, vm) } -pub(crate) fn or_(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { +pub(crate) fn or_(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { if !union_::is_unionable(zelf.clone(), vm) || !union_::is_unionable(other.clone(), vm) { - return vm.ctx.not_implemented(); + return Ok(vm.ctx.not_implemented()); } let tuple = PyTuple::new_ref(vec![zelf, other], &vm.ctx); diff --git a/crates/vm/src/builtins/union.rs b/crates/vm/src/builtins/union.rs index b5e12dcb3c8..9856235ecf4 100644 --- a/crates/vm/src/builtins/union.rs +++ b/crates/vm/src/builtins/union.rs @@ -2,10 +2,10 @@ use super::{genericalias, type_}; use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func, - builtins::{PyFrozenSet, PyGenericAlias, PyStr, PyTuple, PyTupleRef, PyType}, + builtins::{PyFrozenSet, PySet, PyStr, PyTuple, PyTupleRef, PyType}, class::PyClassImpl, common::hash, - convert::{ToPyObject, ToPyResult}, + convert::ToPyObject, function::PyComparisonValue, protocol::{PyMappingMethods, PyNumberMethods}, stdlib::typing::TypeAliasType, @@ -16,9 +16,13 @@ use std::sync::LazyLock; const CLS_ATTRS: &[&str] = &["__module__"]; -#[pyclass(module = "types", name = "UnionType", traverse)] +#[pyclass(module = "typing", name = "Union", traverse)] pub struct PyUnion { args: PyTupleRef, + /// Frozenset of hashable args, or None if all args were hashable + hashable_args: Option>, + /// Tuple of initially unhashable args, or None if all args were hashable + unhashable_args: Option, parameters: PyTupleRef, } @@ -36,9 +40,15 @@ impl PyPayload for PyUnion { } impl PyUnion { - pub fn new(args: PyTupleRef, vm: &VirtualMachine) -> Self { - let parameters = make_parameters(&args, vm); - Self { args, parameters } + /// Create a new union from dedup result (internal use) + fn from_components(result: UnionComponents, vm: &VirtualMachine) -> PyResult { + let parameters = make_parameters(&result.args, vm)?; + Ok(Self { + args: result.args, + hashable_args: result.hashable_args, + unhashable_args: result.unhashable_args, + parameters, + }) } /// Direct access to args field, matching CPython's _Py_union_args @@ -88,10 +98,25 @@ impl PyUnion { } #[pyclass( - flags(BASETYPE), + flags(DISALLOW_INSTANTIATION), with(Hashable, Comparable, AsMapping, AsNumber, Representable) )] impl PyUnion { + #[pygetset] + fn __name__(&self, vm: &VirtualMachine) -> PyObjectRef { + vm.ctx.new_str("Union").into() + } + + #[pygetset] + fn __qualname__(&self, vm: &VirtualMachine) -> PyObjectRef { + vm.ctx.new_str("Union").into() + } + + #[pygetset] + fn __origin__(&self, vm: &VirtualMachine) -> PyObjectRef { + vm.ctx.types.union_type.to_owned().into() + } + #[pygetset] fn __parameters__(&self) -> PyObjectRef { self.parameters.clone().into() @@ -136,17 +161,35 @@ impl PyUnion { } } - fn __or__(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + fn __or__(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { type_::or_(zelf, other, vm) } + #[pymethod] + fn __mro_entries__(zelf: PyRef, _args: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error(format!("Cannot subclass {}", zelf.repr(vm)?))) + } + #[pyclassmethod] fn __class_getitem__( - cls: crate::builtins::PyTypeRef, + _cls: crate::builtins::PyTypeRef, args: PyObjectRef, vm: &VirtualMachine, - ) -> PyGenericAlias { - PyGenericAlias::from_args(cls, args, vm) + ) -> PyResult { + // Convert args to tuple if not already + let args_tuple = if let Some(tuple) = args.downcast_ref::() { + tuple.to_owned() + } else { + PyTuple::new_ref(vec![args], &vm.ctx) + }; + + // Check for empty union + if args_tuple.is_empty() { + return Err(vm.new_type_error("Cannot create empty Union")); + } + + // Create union using make_union to properly handle None -> NoneType conversion + make_union(&args_tuple, vm) } } @@ -159,9 +202,10 @@ pub fn is_unionable(obj: PyObjectRef, vm: &VirtualMachine) -> bool { || obj.downcast_ref::().is_some() } -fn make_parameters(args: &Py, vm: &VirtualMachine) -> PyTupleRef { +fn make_parameters(args: &Py, vm: &VirtualMachine) -> PyResult { let parameters = genericalias::make_parameters(args, vm); - dedup_and_flatten_args(¶meters, vm) + let result = dedup_and_flatten_args(¶meters, vm)?; + Ok(result.args) } fn flatten_args(args: &Py, vm: &VirtualMachine) -> PyTupleRef { @@ -180,6 +224,12 @@ fn flatten_args(args: &Py, vm: &VirtualMachine) -> PyTupleRef { flattened_args.extend(pyref.args.iter().cloned()); } else if vm.is_none(arg) { flattened_args.push(vm.ctx.types.none_type.to_owned().into()); + } else if arg.downcast_ref::().is_some() { + // Convert string to ForwardRef + match string_to_forwardref(arg.clone(), vm) { + Ok(fr) => flattened_args.push(fr), + Err(_) => flattened_args.push(arg.clone()), + } } else { flattened_args.push(arg.clone()); }; @@ -188,31 +238,105 @@ fn flatten_args(args: &Py, vm: &VirtualMachine) -> PyTupleRef { PyTuple::new_ref(flattened_args, &vm.ctx) } -fn dedup_and_flatten_args(args: &Py, vm: &VirtualMachine) -> PyTupleRef { +fn string_to_forwardref(arg: PyObjectRef, vm: &VirtualMachine) -> PyResult { + // Import annotationlib.ForwardRef and create a ForwardRef + let annotationlib = vm.import("annotationlib", 0)?; + let forwardref_cls = annotationlib.get_attr("ForwardRef", vm)?; + forwardref_cls.call((arg,), vm) +} + +/// Components for creating a PyUnion after deduplication +struct UnionComponents { + /// All unique args in order + args: PyTupleRef, + /// Frozenset of hashable args (for fast equality comparison) + hashable_args: Option>, + /// Tuple of unhashable args at creation time (for hash error message) + unhashable_args: Option, +} + +fn dedup_and_flatten_args(args: &Py, vm: &VirtualMachine) -> PyResult { let args = flatten_args(args, vm); + // Use set-based deduplication like CPython: + // - For hashable elements: use Python's set semantics (hash + equality) + // - For unhashable elements: use equality comparison + // + // This avoids calling __eq__ when hashes differ, matching CPython behavior + // where `int | BadType` doesn't raise even if BadType.__eq__ raises. + let mut new_args: Vec = Vec::with_capacity(args.len()); + + // Track hashable elements using a Python set (uses hash + equality) + let hashable_set = PySet::default().into_ref(&vm.ctx); + let mut hashable_list: Vec = Vec::new(); + let mut unhashable_list: Vec = Vec::new(); + for arg in &*args { - if !new_args.iter().any(|param| { - param - .rich_compare_bool(arg, PyComparisonOp::Eq, vm) - .unwrap_or_default() - }) { - new_args.push(arg.clone()); + // Try to hash the element first + match arg.hash(vm) { + Ok(_) => { + // Element is hashable - use set for deduplication + // Set membership uses hash first, then equality only if hashes match + let contains = vm + .call_method(hashable_set.as_ref(), "__contains__", (arg.clone(),)) + .and_then(|r| r.try_to_bool(vm))?; + if !contains { + hashable_set.add(arg.clone(), vm)?; + hashable_list.push(arg.clone()); + new_args.push(arg.clone()); + } + } + Err(_) => { + // Element is unhashable - use equality comparison + let mut is_duplicate = false; + for existing in &unhashable_list { + match existing.rich_compare_bool(arg, PyComparisonOp::Eq, vm) { + Ok(true) => { + is_duplicate = true; + break; + } + Ok(false) => continue, + Err(e) => return Err(e), + } + } + if !is_duplicate { + unhashable_list.push(arg.clone()); + new_args.push(arg.clone()); + } + } } } new_args.shrink_to_fit(); - PyTuple::new_ref(new_args, &vm.ctx) + // Create hashable_args frozenset if there are hashable elements + let hashable_args = if !hashable_list.is_empty() { + Some(PyFrozenSet::from_iter(vm, hashable_list.into_iter())?.into_ref(&vm.ctx)) + } else { + None + }; + + // Create unhashable_args tuple if there are unhashable elements + let unhashable_args = if !unhashable_list.is_empty() { + Some(PyTuple::new_ref(unhashable_list, &vm.ctx)) + } else { + None + }; + + Ok(UnionComponents { + args: PyTuple::new_ref(new_args, &vm.ctx), + hashable_args, + unhashable_args, + }) } -pub fn make_union(args: &Py, vm: &VirtualMachine) -> PyObjectRef { - let args = dedup_and_flatten_args(args, vm); - match args.len() { - 1 => args[0].to_owned(), - _ => PyUnion::new(args, vm).to_pyobject(vm), - } +pub fn make_union(args: &Py, vm: &VirtualMachine) -> PyResult { + let result = dedup_and_flatten_args(args, vm)?; + Ok(match result.args.len() { + 1 => result.args[0].to_owned(), + _ => PyUnion::from_components(result, vm)?.to_pyobject(vm), + }) } impl PyUnion { @@ -224,14 +348,15 @@ impl PyUnion { needle, vm, )?; - let mut res; + let res; if new_args.is_empty() { - res = make_union(&new_args, vm); + res = make_union(&new_args, vm)?; } else { - res = new_args[0].to_owned(); + let mut tmp = new_args[0].to_owned(); for arg in new_args.iter().skip(1) { - res = vm._or(&res, arg)?; + tmp = vm._or(&tmp, arg)?; } + res = tmp; } Ok(res) @@ -254,7 +379,7 @@ impl AsMapping for PyUnion { impl AsNumber for PyUnion { fn as_number() -> &'static PyNumberMethods { static AS_NUMBER: PyNumberMethods = PyNumberMethods { - or: Some(|a, b, vm| PyUnion::__or__(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)), + or: Some(|a, b, vm| PyUnion::__or__(a.to_owned(), b.to_owned(), vm)), ..PyNumberMethods::NOT_IMPLEMENTED }; &AS_NUMBER @@ -270,15 +395,62 @@ impl Comparable for PyUnion { ) -> PyResult { op.eq_only(|| { let other = class_or_notimplemented!(Self, other); - let a = PyFrozenSet::from_iter(vm, zelf.args.into_iter().cloned())?; - let b = PyFrozenSet::from_iter(vm, other.args.into_iter().cloned())?; - Ok(PyComparisonValue::Implemented( - a.into_pyobject(vm).as_object().rich_compare_bool( - b.into_pyobject(vm).as_object(), - PyComparisonOp::Eq, - vm, - )?, - )) + + // Check if lengths are equal + if zelf.args.len() != other.args.len() { + return Ok(PyComparisonValue::Implemented(false)); + } + + // Fast path: if both unions have all hashable args, compare frozensets directly + // Always use Eq here since eq_only handles Ne by negating the result + if zelf.unhashable_args.is_none() + && other.unhashable_args.is_none() + && let (Some(a), Some(b)) = (&zelf.hashable_args, &other.hashable_args) + { + let eq = a + .as_object() + .rich_compare_bool(b.as_object(), PyComparisonOp::Eq, vm)?; + return Ok(PyComparisonValue::Implemented(eq)); + } + + // Slow path: O(n^2) nested loop comparison for unhashable elements + // Check if all elements in zelf.args are in other.args + for arg_a in &*zelf.args { + let mut found = false; + for arg_b in &*other.args { + match arg_a.rich_compare_bool(arg_b, PyComparisonOp::Eq, vm) { + Ok(true) => { + found = true; + break; + } + Ok(false) => continue, + Err(e) => return Err(e), // Propagate comparison errors + } + } + if !found { + return Ok(PyComparisonValue::Implemented(false)); + } + } + + // Check if all elements in other.args are in zelf.args (for symmetry) + for arg_b in &*other.args { + let mut found = false; + for arg_a in &*zelf.args { + match arg_b.rich_compare_bool(arg_a, PyComparisonOp::Eq, vm) { + Ok(true) => { + found = true; + break; + } + Ok(false) => continue, + Err(e) => return Err(e), // Propagate comparison errors + } + } + if !found { + return Ok(PyComparisonValue::Implemented(false)); + } + } + + Ok(PyComparisonValue::Implemented(true)) }) } } @@ -286,7 +458,36 @@ impl Comparable for PyUnion { impl Hashable for PyUnion { #[inline] fn hash(zelf: &Py, vm: &VirtualMachine) -> PyResult { - let set = PyFrozenSet::from_iter(vm, zelf.args.into_iter().cloned())?; + // If there are any unhashable args from creation time, the union is unhashable + if let Some(ref unhashable_args) = zelf.unhashable_args { + let n = unhashable_args.len(); + // Try to hash each previously unhashable arg to get an error + for arg in unhashable_args.iter() { + arg.hash(vm)?; + } + // All previously unhashable args somehow became hashable + // But still raise an error to maintain consistent hashing + return Err(vm.new_type_error(format!( + "union contains {} unhashable element{}", + n, + if n > 1 { "s" } else { "" } + ))); + } + + // If we have a stored frozenset of hashable args, use that + if let Some(ref hashable_args) = zelf.hashable_args { + return PyFrozenSet::hash(hashable_args, vm); + } + + // Fallback: compute hash from args + let mut args_to_hash = Vec::new(); + for arg in &*zelf.args { + match arg.hash(vm) { + Ok(_) => args_to_hash.push(arg.clone()), + Err(e) => return Err(e), + } + } + let set = PyFrozenSet::from_iter(vm, args_to_hash.into_iter())?; PyFrozenSet::hash(&set.into_ref(&vm.ctx), vm) } } diff --git a/crates/vm/src/stdlib/_abc.rs b/crates/vm/src/stdlib/_abc.rs new file mode 100644 index 00000000000..ef528d95731 --- /dev/null +++ b/crates/vm/src/stdlib/_abc.rs @@ -0,0 +1,481 @@ +//! Implementation of the `_abc` module. +//! +//! This module provides the C implementation of Abstract Base Classes (ABCs) +//! as defined in PEP 3119. + +pub(crate) use _abc::make_module; + +#[pymodule] +mod _abc { + use crate::{ + AsObject, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + builtins::{PyFrozenSet, PyList, PySet, PyStr, PyTupleRef, PyTypeRef, PyWeak}, + common::lock::PyRwLock, + convert::ToPyObject, + protocol::PyIterReturn, + types::Constructor, + }; + use std::sync::atomic::{AtomicU64, Ordering}; + + // Global invalidation counter + static ABC_INVALIDATION_COUNTER: AtomicU64 = AtomicU64::new(0); + + fn get_invalidation_counter() -> u64 { + ABC_INVALIDATION_COUNTER.load(Ordering::SeqCst) + } + + fn increment_invalidation_counter() { + ABC_INVALIDATION_COUNTER.fetch_add(1, Ordering::SeqCst); + } + + /// Internal state held by ABC machinery. + #[pyattr] + #[pyclass(name = "_abc_data", module = "_abc")] + #[derive(Debug, PyPayload)] + struct AbcData { + // WeakRef sets for registry and caches + registry: PyRwLock>>, + cache: PyRwLock>>, + negative_cache: PyRwLock>>, + negative_cache_version: AtomicU64, + } + + #[pyclass(with(Constructor))] + impl AbcData { + fn new() -> Self { + AbcData { + registry: PyRwLock::new(None), + cache: PyRwLock::new(None), + negative_cache: PyRwLock::new(None), + negative_cache_version: AtomicU64::new(get_invalidation_counter()), + } + } + + fn get_cache_version(&self) -> u64 { + self.negative_cache_version.load(Ordering::SeqCst) + } + + fn set_cache_version(&self, version: u64) { + self.negative_cache_version.store(version, Ordering::SeqCst); + } + } + + impl Constructor for AbcData { + type Args = (); + + fn py_new( + _cls: &crate::Py, + _args: Self::Args, + _vm: &VirtualMachine, + ) -> PyResult { + Ok(AbcData::new()) + } + } + + /// Get the _abc_impl attribute from an ABC class + fn get_impl(cls: &PyObject, vm: &VirtualMachine) -> PyResult> { + let impl_obj = cls.get_attr("_abc_impl", vm)?; + impl_obj + .downcast::() + .map_err(|_| vm.new_type_error("_abc_impl is set to a wrong type".to_owned())) + } + + /// Check if obj is in the weak set + fn in_weak_set( + set_lock: &PyRwLock>>, + obj: &PyObject, + vm: &VirtualMachine, + ) -> PyResult { + let set_opt = set_lock.read(); + let set = match &*set_opt { + Some(s) if !s.elements().is_empty() => s.clone(), + _ => return Ok(false), + }; + drop(set_opt); + + // Create a weak reference to the object + let weak_ref = match obj.downgrade(None, vm) { + Ok(w) => w, + Err(e) => { + // If we can't create a weakref (e.g., TypeError), the object can't be in the set + if e.class().is(vm.ctx.exceptions.type_error) { + return Ok(false); + } + return Err(e); + } + }; + + // Use vm.call_method to call __contains__ + let weak_ref_obj: PyObjectRef = weak_ref.into(); + vm.call_method(set.as_ref(), "__contains__", (weak_ref_obj,))? + .try_to_bool(vm) + } + + /// Add obj to the weak set + fn add_to_weak_set( + set_lock: &PyRwLock>>, + obj: &PyObject, + vm: &VirtualMachine, + ) -> PyResult<()> { + let mut set_opt = set_lock.write(); + let set = match &*set_opt { + Some(s) => s.clone(), + None => { + let new_set = PySet::default().into_ref(&vm.ctx); + *set_opt = Some(new_set.clone()); + new_set + } + }; + drop(set_opt); + + // Create a weak reference to the object + let weak_ref = obj.downgrade(None, vm)?; + set.add(weak_ref.into(), vm)?; + Ok(()) + } + + /// Returns the current ABC cache token. + #[pyfunction] + fn get_cache_token() -> u64 { + get_invalidation_counter() + } + + /// Compute set of abstract method names. + fn compute_abstract_methods(cls: &PyObject, vm: &VirtualMachine) -> PyResult<()> { + let mut abstracts = Vec::new(); + + // Stage 1: direct abstract methods + let ns = cls.get_attr("__dict__", vm)?; + let items = vm.call_method(&ns, "items", ())?; + let iter = items.get_iter(vm)?; + + while let PyIterReturn::Return(item) = iter.next(vm)? { + let tuple: PyTupleRef = item + .downcast() + .map_err(|_| vm.new_type_error("items() returned non-tuple".to_owned()))?; + let elements = tuple.as_slice(); + if elements.len() != 2 { + return Err( + vm.new_type_error("items() returned item which size is not 2".to_owned()) + ); + } + let key = &elements[0]; + let value = &elements[1]; + + // Check if value has __isabstractmethod__ = True + if let Ok(is_abstract) = value.get_attr("__isabstractmethod__", vm) + && is_abstract.try_to_bool(vm)? + { + abstracts.push(key.clone()); + } + } + + // Stage 2: inherited abstract methods + let bases: PyTupleRef = cls + .get_attr("__bases__", vm)? + .downcast() + .map_err(|_| vm.new_type_error("__bases__ is not a tuple".to_owned()))?; + + for base in bases.iter() { + if let Ok(base_abstracts) = base.get_attr("__abstractmethods__", vm) { + let iter = base_abstracts.get_iter(vm)?; + while let PyIterReturn::Return(key) = iter.next(vm)? { + // Try to get the attribute from cls - key should be a string + if let Some(key_str) = key.downcast_ref::() + && let Some(value) = vm.get_attribute_opt(cls.to_owned(), key_str)? + && let Ok(is_abstract) = value.get_attr("__isabstractmethod__", vm) + && is_abstract.try_to_bool(vm)? + { + abstracts.push(key); + } + } + } + } + + // Set __abstractmethods__ + let abstracts_set = PyFrozenSet::from_iter(vm, abstracts.into_iter())?; + cls.set_attr("__abstractmethods__", abstracts_set.into_pyobject(vm), vm)?; + + Ok(()) + } + + /// Internal ABC helper for class set-up. Should be never used outside abc module. + #[pyfunction] + fn _abc_init(cls: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + compute_abstract_methods(&cls, vm)?; + + // Set up inheritance registry + let data = AbcData::new(); + cls.set_attr("_abc_impl", data.to_pyobject(vm), vm)?; + + Ok(()) + } + + /// Internal ABC helper for subclass registration. Should be never used outside abc module. + #[pyfunction] + fn _abc_register( + cls: PyObjectRef, + subclass: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult { + // Type check + if !subclass.class().fast_issubclass(vm.ctx.types.type_type) { + return Err(vm.new_type_error("Can only register classes".to_owned())); + } + + // Check if already a subclass + if subclass.is_subclass(&cls, vm)? { + return Ok(subclass); + } + + // Check for cycles + if cls.is_subclass(&subclass, vm)? { + return Err(vm.new_runtime_error("Refusing to create an inheritance cycle".to_owned())); + } + + // Add to registry + let impl_data = get_impl(&cls, vm)?; + add_to_weak_set(&impl_data.registry, &subclass, vm)?; + + // Invalidate negative cache + increment_invalidation_counter(); + + Ok(subclass) + } + + /// Internal ABC helper for instance checks. Should be never used outside abc module. + #[pyfunction] + fn _abc_instancecheck( + cls: PyObjectRef, + instance: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult { + let impl_data = get_impl(&cls, vm)?; + + // Get instance.__class__ + let subclass = instance.get_attr("__class__", vm)?; + + // Check cache + if in_weak_set(&impl_data.cache, &subclass, vm)? { + return Ok(vm.ctx.true_value.clone().into()); + } + + let subtype: PyObjectRef = instance.class().to_owned().into(); + if subtype.is(&subclass) { + let invalidation_counter = get_invalidation_counter(); + if impl_data.get_cache_version() == invalidation_counter + && in_weak_set(&impl_data.negative_cache, &subclass, vm)? + { + return Ok(vm.ctx.false_value.clone().into()); + } + // Fall back to __subclasscheck__ + return vm.call_method(&cls, "__subclasscheck__", (subclass,)); + } + + // Call __subclasscheck__ on subclass + let result = vm.call_method(&cls, "__subclasscheck__", (subclass.clone(),))?; + + match result.clone().try_to_bool(vm) { + Ok(true) => Ok(result), + Ok(false) => { + // Also try with subtype + vm.call_method(&cls, "__subclasscheck__", (subtype,)) + } + Err(e) => Err(e), + } + } + + /// Check if subclass is in registry (recursive) + fn subclasscheck_check_registry( + impl_data: &AbcData, + subclass: &PyObject, + vm: &VirtualMachine, + ) -> PyResult> { + // Fast path: check if subclass is in weakref directly + if in_weak_set(&impl_data.registry, subclass, vm)? { + return Ok(Some(true)); + } + + let registry_opt = impl_data.registry.read(); + let registry = match &*registry_opt { + Some(s) => s.clone(), + None => return Ok(None), + }; + drop(registry_opt); + + // Make a local copy to protect against concurrent modifications + let registry_copy = PyFrozenSet::from_iter(vm, registry.elements().into_iter())?; + + for weak_ref_obj in registry_copy.elements() { + if let Ok(weak_ref) = weak_ref_obj.downcast::() + && let Some(rkey) = weak_ref.upgrade() + && subclass.to_owned().is_subclass(&rkey, vm)? + { + add_to_weak_set(&impl_data.cache, subclass, vm)?; + return Ok(Some(true)); + } + } + + Ok(None) + } + + /// Internal ABC helper for subclass checks. Should be never used outside abc module. + #[pyfunction] + fn _abc_subclasscheck( + cls: PyObjectRef, + subclass: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult { + // Type check + if !subclass.class().fast_issubclass(vm.ctx.types.type_type) { + return Err(vm.new_type_error("issubclass() arg 1 must be a class".to_owned())); + } + + let impl_data = get_impl(&cls, vm)?; + + // 1. Check cache + if in_weak_set(&impl_data.cache, &subclass, vm)? { + return Ok(true); + } + + // 2. Check negative cache; may have to invalidate + let invalidation_counter = get_invalidation_counter(); + if impl_data.get_cache_version() < invalidation_counter { + // Invalidate the negative cache + // Clone set ref and drop lock before calling into VM to avoid reentrancy + let set = impl_data.negative_cache.read().clone(); + if let Some(ref set) = set { + vm.call_method(set.as_ref(), "clear", ())?; + } + impl_data.set_cache_version(invalidation_counter); + } else if in_weak_set(&impl_data.negative_cache, &subclass, vm)? { + return Ok(false); + } + + // 3. Check the subclass hook + let ok = vm.call_method(&cls, "__subclasshook__", (subclass.clone(),))?; + if ok.is(&vm.ctx.true_value) { + add_to_weak_set(&impl_data.cache, &subclass, vm)?; + return Ok(true); + } + if ok.is(&vm.ctx.false_value) { + add_to_weak_set(&impl_data.negative_cache, &subclass, vm)?; + return Ok(false); + } + if !ok.is(&vm.ctx.not_implemented) { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.assertion_error.to_owned(), + "__subclasshook__ must return either False, True, or NotImplemented".to_owned(), + )); + } + + // 4. Check if it's a direct subclass + let subclass_type: PyTypeRef = subclass + .clone() + .downcast() + .map_err(|_| vm.new_type_error("expected a type object".to_owned()))?; + let cls_type: PyTypeRef = cls + .clone() + .downcast() + .map_err(|_| vm.new_type_error("expected a type object".to_owned()))?; + if subclass_type.fast_issubclass(&cls_type) { + add_to_weak_set(&impl_data.cache, &subclass, vm)?; + return Ok(true); + } + + // 5. Check if it's a subclass of a registered class (recursive) + if let Some(result) = subclasscheck_check_registry(&impl_data, &subclass, vm)? { + return Ok(result); + } + + // 6. Check if it's a subclass of a subclass (recursive) + let subclasses: PyRef = vm + .call_method(&cls, "__subclasses__", ())? + .downcast() + .map_err(|_| vm.new_type_error("__subclasses__() must return a list".to_owned()))?; + + for scls in subclasses.borrow_vec().iter() { + if subclass.is_subclass(scls, vm)? { + add_to_weak_set(&impl_data.cache, &subclass, vm)?; + return Ok(true); + } + } + + // No dice; update negative cache + add_to_weak_set(&impl_data.negative_cache, &subclass, vm)?; + Ok(false) + } + + /// Internal ABC helper for cache and registry debugging. + #[pyfunction] + fn _get_dump(cls: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let impl_data = get_impl(&cls, vm)?; + + let registry = { + let r = impl_data.registry.read(); + match &*r { + Some(s) => { + // Use copy method to get a shallow copy + vm.call_method(s.as_ref(), "copy", ())? + } + None => PySet::default().to_pyobject(vm), + } + }; + + let cache = { + let c = impl_data.cache.read(); + match &*c { + Some(s) => vm.call_method(s.as_ref(), "copy", ())?, + None => PySet::default().to_pyobject(vm), + } + }; + + let negative_cache = { + let nc = impl_data.negative_cache.read(); + match &*nc { + Some(s) => vm.call_method(s.as_ref(), "copy", ())?, + None => PySet::default().to_pyobject(vm), + } + }; + + let version = impl_data.get_cache_version(); + + Ok(vm.ctx.new_tuple(vec![ + registry, + cache, + negative_cache, + vm.ctx.new_int(version).into(), + ])) + } + + /// Internal ABC helper to reset registry of a given class. + #[pyfunction] + fn _reset_registry(cls: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + let impl_data = get_impl(&cls, vm)?; + // Clone set ref and drop lock before calling into VM to avoid reentrancy + let set = impl_data.registry.read().clone(); + if let Some(ref set) = set { + vm.call_method(set.as_ref(), "clear", ())?; + } + Ok(()) + } + + /// Internal ABC helper to reset both caches of a given class. + #[pyfunction] + fn _reset_caches(cls: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + let impl_data = get_impl(&cls, vm)?; + + // Clone set refs and drop locks before calling into VM to avoid reentrancy + let cache = impl_data.cache.read().clone(); + if let Some(ref set) = cache { + vm.call_method(set.as_ref(), "clear", ())?; + } + + let negative_cache = impl_data.negative_cache.read().clone(); + if let Some(ref set) = negative_cache { + vm.call_method(set.as_ref(), "clear", ())?; + } + + Ok(()) + } +} diff --git a/crates/vm/src/stdlib/mod.rs b/crates/vm/src/stdlib/mod.rs index e46f333a28b..85c28983d74 100644 --- a/crates/vm/src/stdlib/mod.rs +++ b/crates/vm/src/stdlib/mod.rs @@ -1,3 +1,4 @@ +mod _abc; #[cfg(feature = "ast")] pub(crate) mod ast; pub mod atexit; @@ -87,6 +88,7 @@ pub fn get_module_inits() -> StdlibMap { modules! { #[cfg(all())] { + "_abc" => _abc::make_module, "atexit" => atexit::make_module, "_codecs" => codecs::make_module, "_collections" => collections::make_module, diff --git a/crates/vm/src/stdlib/typevar.rs b/crates/vm/src/stdlib/typevar.rs index 65249bfd075..e83eaf83555 100644 --- a/crates/vm/src/stdlib/typevar.rs +++ b/crates/vm/src/stdlib/typevar.rs @@ -1,7 +1,7 @@ // spell-checker:ignore typevarobject funcobj use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, - builtins::{PyTupleRef, PyType, PyTypeRef, pystr::AsPyStr}, + builtins::{PyTuple, PyTupleRef, PyType, PyTypeRef, make_union, pystr::AsPyStr}, common::lock::PyMutex, function::{FuncArgs, IntoFuncArgs, PyComparisonValue}, protocol::PyNumberMethods, @@ -250,7 +250,8 @@ impl AsNumber for TypeVar { fn as_number() -> &'static PyNumberMethods { static AS_NUMBER: PyNumberMethods = PyNumberMethods { or: Some(|a, b, vm| { - _call_typing_func_object(vm, "_make_union", (a.to_owned(), b.to_owned())) + let args = PyTuple::new_ref(vec![a.to_owned(), b.to_owned()], &vm.ctx); + make_union(&args, vm) }), ..PyNumberMethods::NOT_IMPLEMENTED }; @@ -525,7 +526,8 @@ impl AsNumber for ParamSpec { fn as_number() -> &'static PyNumberMethods { static AS_NUMBER: PyNumberMethods = PyNumberMethods { or: Some(|a, b, vm| { - _call_typing_func_object(vm, "_make_union", (a.to_owned(), b.to_owned())) + let args = PyTuple::new_ref(vec![a.to_owned(), b.to_owned()], &vm.ctx); + make_union(&args, vm) }), ..PyNumberMethods::NOT_IMPLEMENTED }; diff --git a/crates/vm/src/stdlib/typing.rs b/crates/vm/src/stdlib/typing.rs index f11acce3490..b8048b16d94 100644 --- a/crates/vm/src/stdlib/typing.rs +++ b/crates/vm/src/stdlib/typing.rs @@ -28,6 +28,7 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { "ParamSpecArgs" => ParamSpecArgs::class(&vm.ctx).to_owned(), "ParamSpecKwargs" => ParamSpecKwargs::class(&vm.ctx).to_owned(), "Generic" => Generic::class(&vm.ctx).to_owned(), + "Union" => vm.ctx.types.union_type.to_owned(), }); module } @@ -37,7 +38,6 @@ pub(crate) mod decl { use crate::{ Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, builtins::{PyStrRef, PyTupleRef, PyType, PyTypeRef, pystr::AsPyStr, type_}, - convert::ToPyResult, function::{FuncArgs, IntoFuncArgs}, protocol::PyNumberMethods, types::{AsNumber, Constructor, Representable}, @@ -188,7 +188,7 @@ pub(crate) mod decl { impl AsNumber for TypeAliasType { fn as_number() -> &'static PyNumberMethods { static AS_NUMBER: PyNumberMethods = PyNumberMethods { - or: Some(|a, b, vm| type_::or_(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)), + or: Some(|a, b, vm| type_::or_(a.to_owned(), b.to_owned(), vm)), ..PyNumberMethods::NOT_IMPLEMENTED }; &AS_NUMBER diff --git a/crates/vm/src/types/slot.rs b/crates/vm/src/types/slot.rs index b637bfc40b6..2f45d9dcff1 100644 --- a/crates/vm/src/types/slot.rs +++ b/crates/vm/src/types/slot.rs @@ -483,7 +483,7 @@ fn hash_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> PyResult { /// Marks a type as unhashable. Similar to PyObject_HashNotImplemented in CPython pub fn hash_not_implemented(zelf: &PyObject, vm: &VirtualMachine) -> PyResult { - Err(vm.new_type_error(format!("unhashable type: {}", zelf.class().name()))) + Err(vm.new_type_error(format!("unhashable type: '{}'", zelf.class().name()))) } fn call_wrapper(zelf: &PyObject, args: FuncArgs, vm: &VirtualMachine) -> PyResult {