diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 12db84a1209..4d237bdd1a8 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -86,6 +86,7 @@ def test_field_recursive_repr(self): self.assertIn(",type=...,", repr_output) + @unittest.expectedFailure # TODO: RUSTPYTHON; recursive annotation type not shown as ... def test_recursive_annotation(self): class C: pass diff --git a/Lib/test/test_genericalias.py b/Lib/test/test_genericalias.py index 3da8c2b1eea..cc0ca93e79b 100644 --- a/Lib/test/test_genericalias.py +++ b/Lib/test/test_genericalias.py @@ -57,6 +57,11 @@ from weakref import WeakSet, ReferenceType, ref import typing from typing import Unpack +try: + from tkinter import Event +except ImportError: + Event = None +from string.templatelib import Template, Interpolation from typing import TypeVar T = TypeVar('T') @@ -96,7 +101,7 @@ class BaseTest(unittest.TestCase): """Test basics.""" - generic_types = [type, tuple, list, dict, set, frozenset, enumerate, + generic_types = [type, tuple, list, dict, set, frozenset, enumerate, memoryview, defaultdict, deque, SequenceMatcher, dircmp, @@ -133,13 +138,21 @@ class BaseTest(unittest.TestCase): Future, _WorkItem, Morsel, DictReader, DictWriter, - array] + array, + staticmethod, + classmethod, + Template, + Interpolation, + ] if ctypes is not None: - generic_types.extend((ctypes.Array, ctypes.LibraryLoader)) + generic_types.extend((ctypes.Array, ctypes.LibraryLoader, ctypes.py_object)) if ValueProxy is not None: generic_types.extend((ValueProxy, DictProxy, ListProxy, ApplyResult, MPSimpleQueue, MPQueue, MPJoinableQueue)) + if Event is not None: + generic_types.append(Event) + @unittest.expectedFailure # TODO: RUSTPYTHON; memoryview, Template, Interpolation, py_object not subscriptable def test_subscriptable(self): for t in self.generic_types: if t is None: @@ -209,7 +222,6 @@ class MyList(list): self.assertEqual(t.__args__, (int,)) self.assertEqual(t.__parameters__, ()) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_repr(self): class MyList(list): pass @@ -225,13 +237,63 @@ class MyGeneric: self.assertEqual(repr(x2), 'tuple[*tuple[int, str]]') x3 = tuple[*tuple[int, ...]] self.assertEqual(repr(x3), 'tuple[*tuple[int, ...]]') - self.assertTrue(repr(MyList[int]).endswith('.BaseTest.test_repr..MyList[int]')) + self.assertEndsWith(repr(MyList[int]), '.BaseTest.test_repr..MyList[int]') self.assertEqual(repr(list[str]()), '[]') # instances should keep their normal repr # gh-105488 - self.assertTrue(repr(MyGeneric[int]).endswith('MyGeneric[int]')) - self.assertTrue(repr(MyGeneric[[]]).endswith('MyGeneric[[]]')) - self.assertTrue(repr(MyGeneric[[int, str]]).endswith('MyGeneric[[int, str]]')) + self.assertEndsWith(repr(MyGeneric[int]), 'MyGeneric[int]') + self.assertEndsWith(repr(MyGeneric[[]]), 'MyGeneric[[]]') + self.assertEndsWith(repr(MyGeneric[[int, str]]), 'MyGeneric[[int, str]]') + + def test_evil_repr1(self): + # gh-143635 + class Zap: + def __init__(self, container): + self.container = container + def __getattr__(self, name): + if name == "__origin__": + self.container.clear() + return None + if name == "__args__": + return () + raise AttributeError + + params = [] + params.append(Zap(params)) + alias = GenericAlias(list, (params,)) + repr_str = repr(alias) + self.assertTrue(repr_str.startswith("list[["), repr_str) + + def test_evil_repr2(self): + class Zap: + def __init__(self, container): + self.container = container + def __getattr__(self, name): + if name == "__qualname__": + self.container.clear() + return "abcd" + if name == "__module__": + return None + raise AttributeError + + params = [] + params.append(Zap(params)) + alias = GenericAlias(list, (params,)) + repr_str = repr(alias) + self.assertTrue(repr_str.startswith("list[["), repr_str) + + def test_evil_repr3(self): + # gh-143823 + lst = [] + class X: + def __repr__(self): + lst.clear() + return "x" + + lst += [X(), 1] + ga = GenericAlias(int, lst) + with self.assertRaises(IndexError): + repr(ga) def test_exposed_type(self): import types @@ -333,7 +395,6 @@ def test_parameter_chaining(self): with self.assertRaises(TypeError): dict[T, T][str, int] - @unittest.expectedFailure # TODO: RUSTPYTHON def test_equality(self): self.assertEqual(list[int], list[int]) self.assertEqual(dict[str, int], dict[str, int]) @@ -352,7 +413,7 @@ def test_isinstance(self): def test_issubclass(self): class L(list): ... - self.assertTrue(issubclass(L, list)) + self.assertIsSubclass(L, list) with self.assertRaises(TypeError): issubclass(L, list[str]) @@ -424,7 +485,6 @@ def test_union_generic(self): self.assertEqual(a.__args__, (list[T], tuple[T, ...])) self.assertEqual(a.__parameters__, (T,)) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_dir(self): ga = list[int] dir_of_gen_alias = set(dir(ga)) @@ -447,6 +507,7 @@ def test_dir(self): with self.subTest(entry=entry): getattr(ga, entry) # must not raise `AttributeError` + @unittest.expectedFailure # TODO: RUSTPYTHON; memoryview, Template, Interpolation, py_object not subscriptable def test_weakref(self): for t in self.generic_types: if t is None: @@ -490,6 +551,76 @@ def test_del_iter(self): iter_x = iter(t) del iter_x + def test_paramspec_specialization(self): + # gh-124445 + T = TypeVar("T") + U = TypeVar("U") + type X[**P] = Callable[P, int] + + generic = X[[T]] + self.assertEqual(generic.__args__, ([T],)) + self.assertEqual(generic.__parameters__, (T,)) + specialized = generic[str] + self.assertEqual(specialized.__args__, ([str],)) + self.assertEqual(specialized.__parameters__, ()) + + generic = X[(T,)] + self.assertEqual(generic.__args__, (T,)) + self.assertEqual(generic.__parameters__, (T,)) + specialized = generic[str] + self.assertEqual(specialized.__args__, (str,)) + self.assertEqual(specialized.__parameters__, ()) + + generic = X[[T, U]] + self.assertEqual(generic.__args__, ([T, U],)) + self.assertEqual(generic.__parameters__, (T, U)) + specialized = generic[str, int] + self.assertEqual(specialized.__args__, ([str, int],)) + self.assertEqual(specialized.__parameters__, ()) + + generic = X[(T, U)] + self.assertEqual(generic.__args__, (T, U)) + self.assertEqual(generic.__parameters__, (T, U)) + specialized = generic[str, int] + self.assertEqual(specialized.__args__, (str, int)) + self.assertEqual(specialized.__parameters__, ()) + + def test_nested_paramspec_specialization(self): + # gh-124445 + type X[**P, T] = Callable[P, T] + + x_list = X[[int, str], float] + self.assertEqual(x_list.__args__, ([int, str], float)) + self.assertEqual(x_list.__parameters__, ()) + + x_tuple = X[(int, str), float] + self.assertEqual(x_tuple.__args__, ((int, str), float)) + self.assertEqual(x_tuple.__parameters__, ()) + + U = TypeVar("U") + V = TypeVar("V") + + multiple_params_list = X[[int, U], V] + self.assertEqual(multiple_params_list.__args__, ([int, U], V)) + self.assertEqual(multiple_params_list.__parameters__, (U, V)) + multiple_params_list_specialized = multiple_params_list[str, float] + self.assertEqual(multiple_params_list_specialized.__args__, ([int, str], float)) + self.assertEqual(multiple_params_list_specialized.__parameters__, ()) + + multiple_params_tuple = X[(int, U), V] + self.assertEqual(multiple_params_tuple.__args__, ((int, U), V)) + self.assertEqual(multiple_params_tuple.__parameters__, (U, V)) + multiple_params_tuple_specialized = multiple_params_tuple[str, float] + self.assertEqual(multiple_params_tuple_specialized.__args__, ((int, str), float)) + self.assertEqual(multiple_params_tuple_specialized.__parameters__, ()) + + deeply_nested = X[[U, [V], int], V] + self.assertEqual(deeply_nested.__args__, ([U, [V], int], V)) + self.assertEqual(deeply_nested.__parameters__, (U, V)) + deeply_nested_specialized = deeply_nested[str, float] + self.assertEqual(deeply_nested_specialized.__args__, ([str, [float], int], float)) + self.assertEqual(deeply_nested_specialized.__parameters__, ()) + class TypeIterationTests(unittest.TestCase): _UNITERABLE_TYPES = (list, tuple) diff --git a/Lib/test/test_reprlib.py b/Lib/test/test_reprlib.py index 34aba4cfd35..3396b54cc9f 100644 --- a/Lib/test/test_reprlib.py +++ b/Lib/test/test_reprlib.py @@ -845,7 +845,6 @@ def __repr__(self): self.assertIs(X.f, X.__repr__.__wrapped__) - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'TypeVar' object has no attribute '__name__' def test__type_params__(self): class My: @recursive_repr() diff --git a/Lib/test/test_type_aliases.py b/Lib/test/test_type_aliases.py new file mode 100644 index 00000000000..ee1791bc1d0 --- /dev/null +++ b/Lib/test/test_type_aliases.py @@ -0,0 +1,415 @@ +import pickle +import types +import unittest +from test.support import check_syntax_error, run_code +from test.typinganndata import mod_generics_cache + +from typing import ( + Callable, TypeAliasType, TypeVar, TypeVarTuple, ParamSpec, Unpack, get_args, +) + + +class TypeParamsInvalidTest(unittest.TestCase): + def test_name_collisions(self): + check_syntax_error(self, 'type TA1[A, **A] = None', "duplicate type parameter 'A'") + check_syntax_error(self, 'type T[A, *A] = None', "duplicate type parameter 'A'") + check_syntax_error(self, 'type T[*A, **A] = None', "duplicate type parameter 'A'") + + def test_name_non_collision_02(self): + ns = run_code("""type TA1[A] = lambda A: A""") + self.assertIsInstance(ns["TA1"], TypeAliasType) + self.assertTrue(callable(ns["TA1"].__value__)) + self.assertEqual("arg", ns["TA1"].__value__("arg")) + + def test_name_non_collision_03(self): + ns = run_code(""" + class Outer[A]: + type TA1[A] = None + """ + ) + outer_A, = ns["Outer"].__type_params__ + inner_A, = ns["Outer"].TA1.__type_params__ + self.assertIsNot(outer_A, inner_A) + + +class TypeParamsAccessTest(unittest.TestCase): + def test_alias_access_01(self): + ns = run_code("type TA1[A, B] = dict[A, B]") + alias = ns["TA1"] + self.assertIsInstance(alias, TypeAliasType) + self.assertEqual(alias.__type_params__, get_args(alias.__value__)) + + def test_alias_access_02(self): + ns = run_code(""" + type TA1[A, B] = TA1[A, B] | int + """ + ) + alias = ns["TA1"] + self.assertIsInstance(alias, TypeAliasType) + A, B = alias.__type_params__ + self.assertEqual(alias.__value__, alias[A, B] | int) + + def test_alias_access_03(self): + ns = run_code(""" + class Outer[A]: + def inner[B](self): + type TA1[C] = TA1[A, B] | int + return TA1 + """ + ) + cls = ns["Outer"] + A, = cls.__type_params__ + B, = cls.inner.__type_params__ + alias = cls.inner(None) + self.assertIsInstance(alias, TypeAliasType) + alias2 = cls.inner(None) + self.assertIsNot(alias, alias2) + self.assertEqual(len(alias.__type_params__), 1) + + self.assertEqual(alias.__value__, alias[A, B] | int) + + +class TypeParamsAliasValueTest(unittest.TestCase): + def test_alias_value_01(self): + type TA1 = int + + self.assertIsInstance(TA1, TypeAliasType) + self.assertEqual(TA1.__value__, int) + self.assertEqual(TA1.__parameters__, ()) + self.assertEqual(TA1.__type_params__, ()) + + type TA2 = TA1 | str + + self.assertIsInstance(TA2, TypeAliasType) + a, b = TA2.__value__.__args__ + self.assertEqual(a, TA1) + self.assertEqual(b, str) + self.assertEqual(TA2.__parameters__, ()) + self.assertEqual(TA2.__type_params__, ()) + + def test_alias_value_02(self): + class Parent[A]: + type TA1[B] = dict[A, B] + + self.assertIsInstance(Parent.TA1, TypeAliasType) + self.assertEqual(len(Parent.TA1.__parameters__), 1) + self.assertEqual(len(Parent.__parameters__), 1) + a, = Parent.__parameters__ + b, = Parent.TA1.__parameters__ + self.assertEqual(Parent.__type_params__, (a,)) + self.assertEqual(Parent.TA1.__type_params__, (b,)) + self.assertEqual(Parent.TA1.__value__, dict[a, b]) + + def test_alias_value_03(self): + def outer[A](): + type TA1[B] = dict[A, B] + return TA1 + + o = outer() + self.assertIsInstance(o, TypeAliasType) + self.assertEqual(len(o.__parameters__), 1) + self.assertEqual(len(outer.__type_params__), 1) + b = o.__parameters__[0] + self.assertEqual(o.__type_params__, (b,)) + + def test_alias_value_04(self): + def more_generic[T, *Ts, **P](): + type TA[T2, *Ts2, **P2] = tuple[Callable[P, tuple[T, *Ts]], Callable[P2, tuple[T2, *Ts2]]] + return TA + + alias = more_generic() + self.assertIsInstance(alias, TypeAliasType) + T2, Ts2, P2 = alias.__type_params__ + self.assertEqual(alias.__parameters__, (T2, *Ts2, P2)) + T, Ts, P = more_generic.__type_params__ + self.assertEqual(alias.__value__, tuple[Callable[P, tuple[T, *Ts]], Callable[P2, tuple[T2, *Ts2]]]) + + def test_subscripting(self): + type NonGeneric = int + type Generic[A] = dict[A, A] + type VeryGeneric[T, *Ts, **P] = Callable[P, tuple[T, *Ts]] + + with self.assertRaises(TypeError): + NonGeneric[int] + + specialized = Generic[int] + self.assertIsInstance(specialized, types.GenericAlias) + self.assertIs(specialized.__origin__, Generic) + self.assertEqual(specialized.__args__, (int,)) + + specialized2 = VeryGeneric[int, str, float, [bool, range]] + self.assertIsInstance(specialized2, types.GenericAlias) + self.assertIs(specialized2.__origin__, VeryGeneric) + self.assertEqual(specialized2.__args__, (int, str, float, [bool, range])) + + def test_repr(self): + type Simple = int + type VeryGeneric[T, *Ts, **P] = Callable[P, tuple[T, *Ts]] + + self.assertEqual(repr(Simple), "Simple") + self.assertEqual(repr(VeryGeneric), "VeryGeneric") + self.assertEqual(repr(VeryGeneric[int, bytes, str, [float, object]]), + "VeryGeneric[int, bytes, str, [float, object]]") + self.assertEqual(repr(VeryGeneric[int, []]), + "VeryGeneric[int, []]") + self.assertEqual(repr(VeryGeneric[int, [VeryGeneric[int], list[str]]]), + "VeryGeneric[int, [VeryGeneric[int], list[str]]]") + + def test_recursive_repr(self): + type Recursive = Recursive + self.assertEqual(repr(Recursive), "Recursive") + + type X = list[Y] + type Y = list[X] + self.assertEqual(repr(X), "X") + self.assertEqual(repr(Y), "Y") + + type GenericRecursive[X] = list[X | GenericRecursive[X]] + self.assertEqual(repr(GenericRecursive), "GenericRecursive") + self.assertEqual(repr(GenericRecursive[int]), "GenericRecursive[int]") + self.assertEqual(repr(GenericRecursive[GenericRecursive[int]]), + "GenericRecursive[GenericRecursive[int]]") + + def test_raising(self): + type MissingName = list[_My_X] + with self.assertRaisesRegex( + NameError, + "cannot access free variable '_My_X' where it is not associated with a value", + ): + MissingName.__value__ + _My_X = int + self.assertEqual(MissingName.__value__, list[int]) + del _My_X + # Cache should still work: + self.assertEqual(MissingName.__value__, list[int]) + + # Explicit exception: + type ExprException = 1 / 0 + with self.assertRaises(ZeroDivisionError): + ExprException.__value__ + + +class TypeAliasConstructorTest(unittest.TestCase): + def test_basic(self): + TA = TypeAliasType("TA", int) + self.assertEqual(TA.__name__, "TA") + self.assertIs(TA.__value__, int) + self.assertEqual(TA.__type_params__, ()) + self.assertEqual(TA.__module__, __name__) + + def test_attributes_with_exec(self): + ns = {} + exec("type TA = int", ns, ns) + TA = ns["TA"] + self.assertEqual(TA.__name__, "TA") + self.assertIs(TA.__value__, int) + self.assertEqual(TA.__type_params__, ()) + self.assertIs(TA.__module__, None) + + def test_generic(self): + T = TypeVar("T") + TA = TypeAliasType("TA", list[T], type_params=(T,)) + self.assertEqual(TA.__name__, "TA") + self.assertEqual(TA.__value__, list[T]) + self.assertEqual(TA.__type_params__, (T,)) + self.assertEqual(TA.__module__, __name__) + self.assertIs(type(TA[int]), types.GenericAlias) + + def test_not_generic(self): + TA = TypeAliasType("TA", list[int], type_params=()) + self.assertEqual(TA.__name__, "TA") + self.assertEqual(TA.__value__, list[int]) + self.assertEqual(TA.__type_params__, ()) + self.assertEqual(TA.__module__, __name__) + with self.assertRaisesRegex( + TypeError, + "Only generic type aliases are subscriptable", + ): + TA[int] + + def test_type_params_order_with_defaults(self): + HasNoDefaultT = TypeVar("HasNoDefaultT") + WithDefaultT = TypeVar("WithDefaultT", default=int) + + HasNoDefaultP = ParamSpec("HasNoDefaultP") + WithDefaultP = ParamSpec("WithDefaultP", default=HasNoDefaultP) + + HasNoDefaultTT = TypeVarTuple("HasNoDefaultTT") + WithDefaultTT = TypeVarTuple("WithDefaultTT", default=HasNoDefaultTT) + + for type_params in [ + (HasNoDefaultT, WithDefaultT), + (HasNoDefaultP, WithDefaultP), + (HasNoDefaultTT, WithDefaultTT), + ]: + with self.subTest(type_params=type_params): + TypeAliasType("A", int, type_params=type_params) # ok + + msg = "follows default type parameter" + for type_params in [ + (WithDefaultT, HasNoDefaultT), + (WithDefaultP, HasNoDefaultP), + (WithDefaultTT, HasNoDefaultTT), + (WithDefaultT, HasNoDefaultP), # different types + ]: + with self.subTest(type_params=type_params): + with self.assertRaisesRegex(TypeError, msg): + TypeAliasType("A", int, type_params=type_params) + + def test_expects_type_like(self): + T = TypeVar("T") + + msg = "Expected a type param" + with self.assertRaisesRegex(TypeError, msg): + TypeAliasType("A", int, type_params=(1,)) + with self.assertRaisesRegex(TypeError, msg): + TypeAliasType("A", int, type_params=(1, 2)) + with self.assertRaisesRegex(TypeError, msg): + TypeAliasType("A", int, type_params=(T, 2)) + + def test_keywords(self): + TA = TypeAliasType(name="TA", value=int) + self.assertEqual(TA.__name__, "TA") + self.assertIs(TA.__value__, int) + self.assertEqual(TA.__type_params__, ()) + self.assertEqual(TA.__module__, __name__) + + def test_errors(self): + with self.assertRaises(TypeError): + TypeAliasType() + with self.assertRaises(TypeError): + TypeAliasType("TA") + with self.assertRaises(TypeError): + TypeAliasType("TA", list, ()) + with self.assertRaises(TypeError): + TypeAliasType("TA", list, type_params=42) + + +class TypeAliasTypeTest(unittest.TestCase): + def test_immutable(self): + with self.assertRaises(TypeError): + TypeAliasType.whatever = "not allowed" + + def test_no_subclassing(self): + with self.assertRaisesRegex(TypeError, "not an acceptable base type"): + class MyAlias(TypeAliasType): + pass + + def test_union(self): + type Alias1 = int + type Alias2 = str + union = Alias1 | Alias2 + self.assertIsInstance(union, types.UnionType) + self.assertEqual(get_args(union), (Alias1, Alias2)) + union2 = Alias1 | list[float] + self.assertIsInstance(union2, types.UnionType) + self.assertEqual(get_args(union2), (Alias1, list[float])) + union3 = list[range] | Alias1 + self.assertIsInstance(union3, types.UnionType) + self.assertEqual(get_args(union3), (list[range], Alias1)) + + def test_module(self): + self.assertEqual(TypeAliasType.__module__, "typing") + type Alias = int + self.assertEqual(Alias.__module__, __name__) + self.assertEqual(mod_generics_cache.Alias.__module__, + mod_generics_cache.__name__) + self.assertEqual(mod_generics_cache.OldStyle.__module__, + mod_generics_cache.__name__) + + def test_unpack(self): + type Alias = tuple[int, int] + unpacked = (*Alias,)[0] + self.assertEqual(unpacked, Unpack[Alias]) + + class Foo[*Ts]: + pass + + x = Foo[str, *Alias] + self.assertEqual(x.__args__, (str, Unpack[Alias])) + + +# All these type aliases are used for pickling tests: +T = TypeVar('T') +type SimpleAlias = int +type RecursiveAlias = dict[str, RecursiveAlias] +type GenericAlias[X] = list[X] +type GenericAliasMultipleTypes[X, Y] = dict[X, Y] +type RecursiveGenericAlias[X] = dict[str, RecursiveAlias[X]] +type BoundGenericAlias[X: int] = set[X] +type ConstrainedGenericAlias[LongName: (str, bytes)] = list[LongName] +type AllTypesAlias[A, *B, **C] = Callable[C, A] | tuple[*B] + + +class TypeAliasPickleTest(unittest.TestCase): + def test_pickling(self): + things_to_test = [ + SimpleAlias, + RecursiveAlias, + + GenericAlias, + GenericAlias[T], + GenericAlias[int], + + GenericAliasMultipleTypes, + GenericAliasMultipleTypes[str, T], + GenericAliasMultipleTypes[T, str], + GenericAliasMultipleTypes[int, str], + + RecursiveGenericAlias, + RecursiveGenericAlias[T], + RecursiveGenericAlias[int], + + BoundGenericAlias, + BoundGenericAlias[int], + BoundGenericAlias[T], + + ConstrainedGenericAlias, + ConstrainedGenericAlias[str], + ConstrainedGenericAlias[T], + + AllTypesAlias, + AllTypesAlias[int, str, T, [T, object]], + + # Other modules: + mod_generics_cache.Alias, + mod_generics_cache.OldStyle, + ] + for thing in things_to_test: + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(thing=thing, proto=proto): + pickled = pickle.dumps(thing, protocol=proto) + self.assertEqual(pickle.loads(pickled), thing) + + type ClassLevel = str + + def test_pickling_local(self): + type A = int + things_to_test = [ + self.ClassLevel, + A, + ] + for thing in things_to_test: + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(thing=thing, proto=proto): + with self.assertRaises(pickle.PickleError): + pickle.dumps(thing, protocol=proto) + + +class TypeParamsExoticGlobalsTest(unittest.TestCase): + def test_exec_with_unusual_globals(self): + class customdict(dict): + def __missing__(self, key): + return key + + code = compile("type Alias = undefined", "test", "exec") + ns = customdict() + exec(code, ns) + Alias = ns["Alias"] + self.assertEqual(Alias.__value__, "undefined") + + code = compile("class A: type Alias = undefined", "test", "exec") + ns = customdict() + exec(code, ns) + Alias = ns["A"].Alias + self.assertEqual(Alias.__value__, "undefined") diff --git a/Lib/test/test_type_annotations.py b/Lib/test/test_type_annotations.py new file mode 100644 index 00000000000..4ed786cca3a --- /dev/null +++ b/Lib/test/test_type_annotations.py @@ -0,0 +1,877 @@ +import annotationlib +import inspect +import textwrap +import types +import unittest +from test.support import run_code, check_syntax_error, import_helper, cpython_only +from test.test_inspect import inspect_stringized_annotations + + +class TypeAnnotationTests(unittest.TestCase): + + def test_lazy_create_annotations(self): + # type objects lazy create their __annotations__ dict on demand. + # the annotations dict is stored in type.__dict__ (as __annotations_cache__). + # a freshly created type shouldn't have an annotations dict yet. + foo = type("Foo", (), {}) + for i in range(3): + self.assertFalse("__annotations_cache__" in foo.__dict__) + d = foo.__annotations__ + self.assertTrue("__annotations_cache__" in foo.__dict__) + self.assertEqual(foo.__annotations__, d) + self.assertEqual(foo.__dict__['__annotations_cache__'], d) + del foo.__annotations__ + + def test_setting_annotations(self): + foo = type("Foo", (), {}) + for i in range(3): + self.assertFalse("__annotations_cache__" in foo.__dict__) + d = {'a': int} + foo.__annotations__ = d + self.assertTrue("__annotations_cache__" in foo.__dict__) + self.assertEqual(foo.__annotations__, d) + self.assertEqual(foo.__dict__['__annotations_cache__'], d) + del foo.__annotations__ + + def test_annotations_getset_raises(self): + # builtin types don't have __annotations__ (yet!) + with self.assertRaises(AttributeError): + print(float.__annotations__) + with self.assertRaises(TypeError): + float.__annotations__ = {} + with self.assertRaises(TypeError): + del float.__annotations__ + + # double delete + foo = type("Foo", (), {}) + foo.__annotations__ = {} + del foo.__annotations__ + with self.assertRaises(AttributeError): + del foo.__annotations__ + + def test_annotations_are_created_correctly(self): + class C: + a:int=3 + b:str=4 + self.assertEqual(C.__annotations__, {"a": int, "b": str}) + self.assertTrue("__annotations_cache__" in C.__dict__) + del C.__annotations__ + self.assertFalse("__annotations_cache__" in C.__dict__) + + def test_pep563_annotations(self): + isa = inspect_stringized_annotations + self.assertEqual( + isa.__annotations__, {"a": "int", "b": "str"}, + ) + self.assertEqual( + isa.MyClass.__annotations__, {"a": "int", "b": "str"}, + ) + + def test_explicitly_set_annotations(self): + class C: + __annotations__ = {"what": int} + self.assertEqual(C.__annotations__, {"what": int}) + + def test_explicitly_set_annotate(self): + class C: + __annotate__ = lambda format: {"what": int} + self.assertEqual(C.__annotations__, {"what": int}) + self.assertIsInstance(C.__annotate__, types.FunctionType) + self.assertEqual(C.__annotate__(annotationlib.Format.VALUE), {"what": int}) + + def test_del_annotations_and_annotate(self): + # gh-132285 + called = False + class A: + def __annotate__(format): + nonlocal called + called = True + return {'a': int} + + self.assertEqual(A.__annotations__, {'a': int}) + self.assertTrue(called) + self.assertTrue(A.__annotate__) + + del A.__annotations__ + called = False + + self.assertEqual(A.__annotations__, {}) + self.assertFalse(called) + self.assertIs(A.__annotate__, None) + + def test_descriptor_still_works(self): + class C: + def __init__(self, name=None, bases=None, d=None): + self.my_annotations = None + + @property + def __annotations__(self): + if not hasattr(self, 'my_annotations'): + self.my_annotations = {} + if not isinstance(self.my_annotations, dict): + self.my_annotations = {} + return self.my_annotations + + @__annotations__.setter + def __annotations__(self, value): + if not isinstance(value, dict): + raise ValueError("can only set __annotations__ to a dict") + self.my_annotations = value + + @__annotations__.deleter + def __annotations__(self): + if getattr(self, 'my_annotations', False) is None: + raise AttributeError('__annotations__') + self.my_annotations = None + + c = C() + self.assertEqual(c.__annotations__, {}) + d = {'a':'int'} + c.__annotations__ = d + self.assertEqual(c.__annotations__, d) + with self.assertRaises(ValueError): + c.__annotations__ = 123 + del c.__annotations__ + with self.assertRaises(AttributeError): + del c.__annotations__ + self.assertEqual(c.__annotations__, {}) + + + class D(metaclass=C): + pass + + self.assertEqual(D.__annotations__, {}) + d = {'a':'int'} + D.__annotations__ = d + self.assertEqual(D.__annotations__, d) + with self.assertRaises(ValueError): + D.__annotations__ = 123 + del D.__annotations__ + with self.assertRaises(AttributeError): + del D.__annotations__ + self.assertEqual(D.__annotations__, {}) + + def test_partially_executed_module(self): + partialexe = import_helper.import_fresh_module("test.typinganndata.partialexecution") + self.assertEqual( + partialexe.a.__annotations__, + {"v1": int, "v2": int}, + ) + self.assertEqual(partialexe.b.annos, {"v1": int}) + + @cpython_only + def test_no_cell(self): + # gh-130924: Test that uses of annotations in local scopes do not + # create cell variables. + def f(x): + a: x + return x + + self.assertEqual(f.__code__.co_cellvars, ()) + + +def build_module(code: str, name: str = "top") -> types.ModuleType: + ns = run_code(code) + mod = types.ModuleType(name) + mod.__dict__.update(ns) + return mod + + +class TestSetupAnnotations(unittest.TestCase): + def check(self, code: str): + code = textwrap.dedent(code) + for scope in ("module", "class"): + with self.subTest(scope=scope): + if scope == "class": + code = f"class C:\n{textwrap.indent(code, ' ')}" + ns = run_code(code) + annotations = ns["C"].__annotations__ + else: + annotations = build_module(code).__annotations__ + self.assertEqual(annotations, {"x": int}) + + def test_top_level(self): + self.check("x: int = 1") + + def test_blocks(self): + self.check("if True:\n x: int = 1") + self.check(""" + while True: + x: int = 1 + break + """) + self.check(""" + while False: + pass + else: + x: int = 1 + """) + self.check(""" + for i in range(1): + x: int = 1 + """) + self.check(""" + for i in range(1): + pass + else: + x: int = 1 + """) + + def test_try(self): + self.check(""" + try: + x: int = 1 + except: + pass + """) + self.check(""" + try: + pass + except: + pass + else: + x: int = 1 + """) + self.check(""" + try: + pass + except: + pass + finally: + x: int = 1 + """) + self.check(""" + try: + 1/0 + except: + x: int = 1 + """) + + def test_try_star(self): + self.check(""" + try: + x: int = 1 + except* Exception: + pass + """) + self.check(""" + try: + pass + except* Exception: + pass + else: + x: int = 1 + """) + self.check(""" + try: + pass + except* Exception: + pass + finally: + x: int = 1 + """) + self.check(""" + try: + 1/0 + except* Exception: + x: int = 1 + """) + + def test_match(self): + self.check(""" + match 0: + case 0: + x: int = 1 + """) + + +class AnnotateTests(unittest.TestCase): + """See PEP 649.""" + def test_manual_annotate(self): + def f(): + pass + mod = types.ModuleType("mod") + class X: + pass + + for obj in (f, mod, X): + with self.subTest(obj=obj): + self.check_annotations(obj) + + def check_annotations(self, f): + self.assertEqual(f.__annotations__, {}) + self.assertIs(f.__annotate__, None) + + with self.assertRaisesRegex(TypeError, "__annotate__ must be callable or None"): + f.__annotate__ = 42 + f.__annotate__ = lambda: 42 + with self.assertRaisesRegex(TypeError, r"takes 0 positional arguments but 1 was given"): + print(f.__annotations__) + + f.__annotate__ = lambda x: 42 + with self.assertRaisesRegex(TypeError, r"__annotate__ returned non-dict of type 'int'"): + print(f.__annotations__) + + f.__annotate__ = lambda x: {"x": x} + self.assertEqual(f.__annotations__, {"x": 1}) + + # Setting annotate to None does not invalidate the cached __annotations__ + f.__annotate__ = None + self.assertEqual(f.__annotations__, {"x": 1}) + + # But setting it to a new callable does + f.__annotate__ = lambda x: {"y": x} + self.assertEqual(f.__annotations__, {"y": 1}) + + # Setting f.__annotations__ also clears __annotate__ + f.__annotations__ = {"z": 43} + self.assertIs(f.__annotate__, None) + + def test_user_defined_annotate(self): + class X: + a: int + + def __annotate__(format): + return {"a": str} + self.assertEqual(X.__annotate__(annotationlib.Format.VALUE), {"a": str}) + self.assertEqual(annotationlib.get_annotations(X), {"a": str}) + + mod = build_module( + """ + a: int + def __annotate__(format): + return {"a": str} + """ + ) + self.assertEqual(mod.__annotate__(annotationlib.Format.VALUE), {"a": str}) + self.assertEqual(annotationlib.get_annotations(mod), {"a": str}) + + +class DeferredEvaluationTests(unittest.TestCase): + def test_function(self): + def func(x: undefined, /, y: undefined, *args: undefined, z: undefined, **kwargs: undefined) -> undefined: + pass + + with self.assertRaises(NameError): + func.__annotations__ + + undefined = 1 + self.assertEqual(func.__annotations__, { + "x": 1, + "y": 1, + "args": 1, + "z": 1, + "kwargs": 1, + "return": 1, + }) + + def test_async_function(self): + async def func(x: undefined, /, y: undefined, *args: undefined, z: undefined, **kwargs: undefined) -> undefined: + pass + + with self.assertRaises(NameError): + func.__annotations__ + + undefined = 1 + self.assertEqual(func.__annotations__, { + "x": 1, + "y": 1, + "args": 1, + "z": 1, + "kwargs": 1, + "return": 1, + }) + + def test_class(self): + class X: + a: undefined + + with self.assertRaises(NameError): + X.__annotations__ + + undefined = 1 + self.assertEqual(X.__annotations__, {"a": 1}) + + def test_module(self): + ns = run_code("x: undefined = 1") + anno = ns["__annotate__"] + with self.assertRaises(NotImplementedError): + anno(3) + + with self.assertRaises(NameError): + anno(1) + + ns["undefined"] = 1 + self.assertEqual(anno(1), {"x": 1}) + + def test_class_scoping(self): + class Outer: + def meth(self, x: Nested): ... + x: Nested + class Nested: ... + + self.assertEqual(Outer.meth.__annotations__, {"x": Outer.Nested}) + self.assertEqual(Outer.__annotations__, {"x": Outer.Nested}) + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_no_exotic_expressions(self): + preludes = [ + "", + "class X:\n ", + "def f():\n ", + "async def f():\n ", + ] + for prelude in preludes: + with self.subTest(prelude=prelude): + check_syntax_error(self, prelude + "def func(x: (yield)): ...", "yield expression cannot be used within an annotation") + check_syntax_error(self, prelude + "def func(x: (yield from x)): ...", "yield expression cannot be used within an annotation") + check_syntax_error(self, prelude + "def func(x: (y := 3)): ...", "named expression cannot be used within an annotation") + check_syntax_error(self, prelude + "def func(x: (await 42)): ...", "await expression cannot be used within an annotation") + check_syntax_error(self, prelude + "def func(x: [y async for y in x]): ...", "asynchronous comprehension outside of an asynchronous function") + check_syntax_error(self, prelude + "def func(x: {y async for y in x}): ...", "asynchronous comprehension outside of an asynchronous function") + check_syntax_error(self, prelude + "def func(x: {y: y async for y in x}): ...", "asynchronous comprehension outside of an asynchronous function") + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_no_exotic_expressions_in_unevaluated_annotations(self): + preludes = [ + "", + "class X: ", + "def f(): ", + "async def f(): ", + ] + for prelude in preludes: + with self.subTest(prelude=prelude): + check_syntax_error(self, prelude + "(x): (yield)", "yield expression cannot be used within an annotation") + check_syntax_error(self, prelude + "(x): (yield from x)", "yield expression cannot be used within an annotation") + check_syntax_error(self, prelude + "(x): (y := 3)", "named expression cannot be used within an annotation") + check_syntax_error(self, prelude + "(x): (__debug__ := 3)", "named expression cannot be used within an annotation") + check_syntax_error(self, prelude + "(x): (await 42)", "await expression cannot be used within an annotation") + check_syntax_error(self, prelude + "(x): [y async for y in x]", "asynchronous comprehension outside of an asynchronous function") + check_syntax_error(self, prelude + "(x): {y async for y in x}", "asynchronous comprehension outside of an asynchronous function") + check_syntax_error(self, prelude + "(x): {y: y async for y in x}", "asynchronous comprehension outside of an asynchronous function") + + def test_ignore_non_simple_annotations(self): + ns = run_code("class X: (y): int") + self.assertEqual(ns["X"].__annotations__, {}) + ns = run_code("class X: int.b: int") + self.assertEqual(ns["X"].__annotations__, {}) + ns = run_code("class X: int[str]: int") + self.assertEqual(ns["X"].__annotations__, {}) + + def test_generated_annotate(self): + def func(x: int): + pass + class X: + x: int + mod = build_module("x: int") + for obj in (func, X, mod): + with self.subTest(obj=obj): + annotate = obj.__annotate__ + self.assertIsInstance(annotate, types.FunctionType) + self.assertEqual(annotate.__name__, "__annotate__") + with self.assertRaises(NotImplementedError): + annotate(annotationlib.Format.FORWARDREF) + with self.assertRaises(NotImplementedError): + annotate(annotationlib.Format.STRING) + with self.assertRaises(TypeError): + annotate(None) + self.assertEqual(annotate(annotationlib.Format.VALUE), {"x": int}) + + sig = inspect.signature(annotate) + self.assertEqual(sig, inspect.Signature([ + inspect.Parameter("format", inspect.Parameter.POSITIONAL_ONLY) + ])) + + def test_comprehension_in_annotation(self): + # This crashed in an earlier version of the code + ns = run_code("x: [y for y in range(10)]") + self.assertEqual(ns["__annotate__"](1), {"x": list(range(10))}) + + def test_future_annotations(self): + code = """ + from __future__ import annotations + + def f(x: int) -> int: pass + """ + ns = run_code(code) + f = ns["f"] + self.assertIsInstance(f.__annotate__, types.FunctionType) + annos = {"x": "int", "return": "int"} + self.assertEqual(f.__annotate__(annotationlib.Format.VALUE), annos) + self.assertEqual(f.__annotations__, annos) + + def test_set_annotations(self): + function_code = textwrap.dedent(""" + def f(x: int): + pass + """) + class_code = textwrap.dedent(""" + class f: + x: int + """) + for future in (False, True): + for label, code in (("function", function_code), ("class", class_code)): + with self.subTest(future=future, label=label): + if future: + code = "from __future__ import annotations\n" + code + ns = run_code(code) + f = ns["f"] + anno = "int" if future else int + self.assertEqual(f.__annotations__, {"x": anno}) + + f.__annotations__ = {"x": str} + self.assertEqual(f.__annotations__, {"x": str}) + + def test_name_clash_with_format(self): + # this test would fail if __annotate__'s parameter was called "format" + # during symbol table construction + code = """ + class format: pass + + def f(x: format): pass + """ + ns = run_code(code) + f = ns["f"] + self.assertEqual(f.__annotations__, {"x": ns["format"]}) + + code = """ + class Outer: + class format: pass + + def meth(self, x: format): ... + """ + ns = run_code(code) + self.assertEqual(ns["Outer"].meth.__annotations__, {"x": ns["Outer"].format}) + + code = """ + def f(format): + def inner(x: format): pass + return inner + res = f("closure var") + """ + ns = run_code(code) + self.assertEqual(ns["res"].__annotations__, {"x": "closure var"}) + + code = """ + def f(x: format): + pass + """ + ns = run_code(code) + # picks up the format() builtin + self.assertEqual(ns["f"].__annotations__, {"x": format}) + + code = """ + def outer(): + def f(x: format): + pass + if False: + class format: pass + return f + f = outer() + """ + ns = run_code(code) + with self.assertRaisesRegex( + NameError, + "cannot access free variable 'format' where it is not associated with a value in enclosing scope", + ): + ns["f"].__annotations__ + + +class ConditionalAnnotationTests(unittest.TestCase): + def check_scopes(self, code, true_annos, false_annos): + for scope in ("class", "module"): + for (cond, expected) in ( + # Constants (so code might get optimized out) + (True, true_annos), (False, false_annos), + # Non-constant expressions + ("not not len", true_annos), ("not len", false_annos), + ): + with self.subTest(scope=scope, cond=cond): + code_to_run = code.format(cond=cond) + if scope == "class": + code_to_run = "class Cls:\n" + textwrap.indent(textwrap.dedent(code_to_run), " " * 4) + ns = run_code(code_to_run) + if scope == "class": + self.assertEqual(ns["Cls"].__annotations__, expected) + else: + self.assertEqual(ns["__annotate__"](annotationlib.Format.VALUE), + expected) + + def test_with(self): + code = """ + class Swallower: + def __enter__(self): + pass + + def __exit__(self, *args): + return True + + with Swallower(): + if {cond}: + about_to_raise: int + raise Exception + in_with: "with" + """ + self.check_scopes(code, {"about_to_raise": int}, {"in_with": "with"}) + + def test_simple_if(self): + code = """ + if {cond}: + in_if: "if" + else: + in_if: "else" + """ + self.check_scopes(code, {"in_if": "if"}, {"in_if": "else"}) + + def test_if_elif(self): + code = """ + if not len: + in_if: "if" + elif {cond}: + in_elif: "elif" + else: + in_else: "else" + """ + self.check_scopes( + code, + {"in_elif": "elif"}, + {"in_else": "else"} + ) + + def test_try(self): + code = """ + try: + if {cond}: + raise Exception + in_try: "try" + except Exception: + in_except: "except" + finally: + in_finally: "finally" + """ + self.check_scopes( + code, + {"in_except": "except", "in_finally": "finally"}, + {"in_try": "try", "in_finally": "finally"} + ) + + def test_try_star(self): + code = """ + try: + if {cond}: + raise Exception + in_try_star: "try" + except* Exception: + in_except_star: "except" + finally: + in_finally: "finally" + """ + self.check_scopes( + code, + {"in_except_star": "except", "in_finally": "finally"}, + {"in_try_star": "try", "in_finally": "finally"} + ) + + def test_while(self): + code = """ + while {cond}: + in_while: "while" + break + else: + in_else: "else" + """ + self.check_scopes( + code, + {"in_while": "while"}, + {"in_else": "else"} + ) + + def test_for(self): + code = """ + for _ in ([1] if {cond} else []): + in_for: "for" + else: + in_else: "else" + """ + self.check_scopes( + code, + {"in_for": "for", "in_else": "else"}, + {"in_else": "else"} + ) + + def test_match(self): + code = """ + match {cond}: + case True: + x: "true" + case False: + x: "false" + """ + self.check_scopes( + code, + {"x": "true"}, + {"x": "false"} + ) + + def test_nesting_override(self): + code = """ + if {cond}: + x: "foo" + if {cond}: + x: "bar" + """ + self.check_scopes( + code, + {"x": "bar"}, + {} + ) + + def test_nesting_outer(self): + code = """ + if {cond}: + outer_before: "outer_before" + if len: + inner_if: "inner_if" + else: + inner_else: "inner_else" + outer_after: "outer_after" + """ + self.check_scopes( + code, + {"outer_before": "outer_before", "inner_if": "inner_if", + "outer_after": "outer_after"}, + {} + ) + + def test_nesting_inner(self): + code = """ + if len: + outer_before: "outer_before" + if {cond}: + inner_if: "inner_if" + else: + inner_else: "inner_else" + outer_after: "outer_after" + """ + self.check_scopes( + code, + {"outer_before": "outer_before", "inner_if": "inner_if", + "outer_after": "outer_after"}, + {"outer_before": "outer_before", "inner_else": "inner_else", + "outer_after": "outer_after"}, + ) + + def test_non_name_annotations(self): + code = """ + before: "before" + if {cond}: + a = "x" + a[0]: int + else: + a = object() + a.b: str + after: "after" + """ + expected = {"before": "before", "after": "after"} + self.check_scopes(code, expected, expected) + + +class RegressionTests(unittest.TestCase): + # gh-132479 + @unittest.expectedFailure # TODO: RUSTPYTHON; SyntaxError: the symbol 'unique_name_6' must be present in the symbol table + def test_complex_comprehension_inlining(self): + # Test that the various repro cases from the issue don't crash + cases = [ + """ + (unique_name_0): 0 + unique_name_1: ( + 0 + for ( + 0 + for unique_name_2 in 0 + for () in (0 for unique_name_3 in unique_name_4 for unique_name_5 in name_1) + ).name_3 in {0: 0 for name_1 in unique_name_8} + if name_1 + ) + """, + """ + unique_name_0: 0 + unique_name_1: { + 0: 0 + for unique_name_2 in [0 for name_0 in unique_name_4] + if { + 0: 0 + for unique_name_5 in 0 + if name_0 + if ((name_0 for unique_name_8 in unique_name_9) for [] in 0) + } + } + """, + """ + 0[0]: {0 for name_0 in unique_name_1} + unique_name_2: { + 0: (lambda: name_0 for unique_name_4 in unique_name_5) + for unique_name_6 in () + if name_0 + } + """, + ] + for case in cases: + case = textwrap.dedent(case) + compile(case, "", "exec") + + def test_complex_comprehension_inlining_exec(self): + code = """ + unique_name_1 = unique_name_5 = [1] + name_0 = 42 + unique_name_7: {name_0 for name_0 in unique_name_1} + unique_name_2: { + 0: (lambda: name_0 for unique_name_4 in unique_name_5) + for unique_name_6 in [1] + if name_0 + } + """ + mod = build_module(code) + annos = mod.__annotations__ + self.assertEqual(annos.keys(), {"unique_name_7", "unique_name_2"}) + self.assertEqual(annos["unique_name_7"], {True}) + genexp = annos["unique_name_2"][0] + lamb = list(genexp)[0] + self.assertEqual(lamb(), 42) + + # gh-138349 + def test_module_level_annotation_plus_listcomp(self): + cases = [ + """ + def report_error(): + pass + try: + [0 for name_2 in unique_name_0 if (lambda: name_2)] + except: + pass + annotated_name: 0 + """, + """ + class Generic: + pass + try: + [0 for name_2 in unique_name_0 if (0 for unique_name_1 in unique_name_2 for unique_name_3 in name_2)] + except: + pass + annotated_name: 0 + """, + """ + class Generic: + pass + annotated_name: 0 + try: + [0 for name_2 in [[0]] for unique_name_1 in unique_name_2 if (lambda: name_2)] + except: + pass + """, + ] + for code in cases: + with self.subTest(code=code): + mod = build_module(code) + annos = mod.__annotations__ + self.assertEqual(annos, {"annotated_name": 0}) diff --git a/Lib/test/test_type_params.py b/Lib/test/test_type_params.py new file mode 100644 index 00000000000..07b4957adec --- /dev/null +++ b/Lib/test/test_type_params.py @@ -0,0 +1,1484 @@ +import annotationlib +import textwrap +import types +import unittest +import pickle +import weakref +from test.support import check_syntax_error, run_code, run_no_yield_async_fn + +from typing import Generic, NoDefault, Sequence, TypeAliasType, TypeVar, TypeVarTuple, ParamSpec, get_args + + +class TypeParamsInvalidTest(unittest.TestCase): + def test_name_collisions(self): + check_syntax_error(self, 'def func[**A, A](): ...', "duplicate type parameter 'A'") + check_syntax_error(self, 'def func[A, *A](): ...', "duplicate type parameter 'A'") + check_syntax_error(self, 'def func[*A, **A](): ...', "duplicate type parameter 'A'") + + check_syntax_error(self, 'class C[**A, A](): ...', "duplicate type parameter 'A'") + check_syntax_error(self, 'class C[A, *A](): ...', "duplicate type parameter 'A'") + check_syntax_error(self, 'class C[*A, **A](): ...', "duplicate type parameter 'A'") + + def test_name_non_collision_02(self): + ns = run_code("""def func[A](A): return A""") + func = ns["func"] + self.assertEqual(func(1), 1) + A, = func.__type_params__ + self.assertEqual(A.__name__, "A") + + def test_name_non_collision_03(self): + ns = run_code("""def func[A](*A): return A""") + func = ns["func"] + self.assertEqual(func(1), (1,)) + A, = func.__type_params__ + self.assertEqual(A.__name__, "A") + + def test_name_non_collision_04(self): + # Mangled names should not cause a conflict. + ns = run_code(""" + class ClassA: + def func[__A](self, __A): return __A + """ + ) + cls = ns["ClassA"] + self.assertEqual(cls().func(1), 1) + A, = cls.func.__type_params__ + self.assertEqual(A.__name__, "__A") + + def test_name_non_collision_05(self): + ns = run_code(""" + class ClassA: + def func[_ClassA__A](self, __A): return __A + """ + ) + cls = ns["ClassA"] + self.assertEqual(cls().func(1), 1) + A, = cls.func.__type_params__ + self.assertEqual(A.__name__, "_ClassA__A") + + def test_name_non_collision_06(self): + ns = run_code(""" + class ClassA[X]: + def func(self, X): return X + """ + ) + cls = ns["ClassA"] + self.assertEqual(cls().func(1), 1) + X, = cls.__type_params__ + self.assertEqual(X.__name__, "X") + + def test_name_non_collision_07(self): + ns = run_code(""" + class ClassA[X]: + def func(self): + X = 1 + return X + """ + ) + cls = ns["ClassA"] + self.assertEqual(cls().func(), 1) + X, = cls.__type_params__ + self.assertEqual(X.__name__, "X") + + def test_name_non_collision_08(self): + ns = run_code(""" + class ClassA[X]: + def func(self): + return [X for X in [1, 2]] + """ + ) + cls = ns["ClassA"] + self.assertEqual(cls().func(), [1, 2]) + X, = cls.__type_params__ + self.assertEqual(X.__name__, "X") + + def test_name_non_collision_9(self): + ns = run_code(""" + class ClassA[X]: + def func[X](self): + ... + """ + ) + cls = ns["ClassA"] + outer_X, = cls.__type_params__ + inner_X, = cls.func.__type_params__ + self.assertEqual(outer_X.__name__, "X") + self.assertEqual(inner_X.__name__, "X") + self.assertIsNot(outer_X, inner_X) + + def test_name_non_collision_10(self): + ns = run_code(""" + class ClassA[X]: + X: int + """ + ) + cls = ns["ClassA"] + X, = cls.__type_params__ + self.assertEqual(X.__name__, "X") + self.assertIs(cls.__annotations__["X"], int) + + def test_name_non_collision_13(self): + ns = run_code(""" + X = 1 + def outer(): + def inner[X](): + global X + X = 2 + return inner + """ + ) + self.assertEqual(ns["X"], 1) + outer = ns["outer"] + outer()() + self.assertEqual(ns["X"], 2) + + def test_disallowed_expressions(self): + check_syntax_error(self, "type X = (yield)") + check_syntax_error(self, "type X = (yield from x)") + check_syntax_error(self, "type X = (await 42)") + check_syntax_error(self, "async def f(): type X = (yield)") + check_syntax_error(self, "type X = (y := 3)") + check_syntax_error(self, "class X[T: (yield)]: pass") + check_syntax_error(self, "class X[T: (yield from x)]: pass") + check_syntax_error(self, "class X[T: (await 42)]: pass") + check_syntax_error(self, "class X[T: (y := 3)]: pass") + check_syntax_error(self, "class X[T](y := Sequence[T]): pass") + check_syntax_error(self, "def f[T](y: (x := Sequence[T])): pass") + check_syntax_error(self, "class X[T]([(x := 3) for _ in range(2)] and B): pass") + check_syntax_error(self, "def f[T: [(x := 3) for _ in range(2)]](): pass") + check_syntax_error(self, "type T = [(x := 3) for _ in range(2)]") + + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: "\(MRO\) for bases object, Generic" does not match "Unable to find mro order which keeps local precedence ordering" + def test_incorrect_mro_explicit_object(self): + with self.assertRaisesRegex(TypeError, r"\(MRO\) for bases object, Generic"): + class My[X](object): ... + + +class TypeParamsNonlocalTest(unittest.TestCase): + def test_nonlocal_disallowed_01(self): + code = """ + def outer(): + X = 1 + def inner[X](): + nonlocal X + return X + """ + check_syntax_error(self, code) + + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised + def test_nonlocal_disallowed_02(self): + code = """ + def outer2[T](): + def inner1(): + nonlocal T + """ + check_syntax_error(self, textwrap.dedent(code)) + + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised + def test_nonlocal_disallowed_03(self): + code = """ + class Cls[T]: + nonlocal T + """ + check_syntax_error(self, textwrap.dedent(code)) + + def test_nonlocal_allowed(self): + code = """ + def func[T](): + T = "func" + def inner(): + nonlocal T + T = "inner" + inner() + assert T == "inner" + """ + ns = run_code(code) + func = ns["func"] + T, = func.__type_params__ + self.assertEqual(T.__name__, "T") + + +class TypeParamsAccessTest(unittest.TestCase): + def test_class_access_01(self): + ns = run_code(""" + class ClassA[A, B](dict[A, B]): + ... + """ + ) + cls = ns["ClassA"] + A, B = cls.__type_params__ + self.assertEqual(types.get_original_bases(cls), (dict[A, B], Generic[A, B])) + + def test_class_access_02(self): + ns = run_code(""" + class MyMeta[A, B](type): ... + class ClassA[A, B](metaclass=MyMeta[A, B]): + ... + """ + ) + meta = ns["MyMeta"] + cls = ns["ClassA"] + A1, B1 = meta.__type_params__ + A2, B2 = cls.__type_params__ + self.assertIsNot(A1, A2) + self.assertIsNot(B1, B2) + self.assertIs(type(cls), meta) + + def test_class_access_03(self): + code = """ + def my_decorator(a): + ... + @my_decorator(A) + class ClassA[A, B](): + ... + """ + + with self.assertRaisesRegex(NameError, "name 'A' is not defined"): + run_code(code) + + def test_function_access_01(self): + ns = run_code(""" + def func[A, B](a: dict[A, B]): + ... + """ + ) + func = ns["func"] + A, B = func.__type_params__ + self.assertEqual(func.__annotations__["a"], dict[A, B]) + + @unittest.expectedFailure # TODO: RUSTPYTHON; SyntaxError: the symbol 'list' must be present in the symbol table + def test_function_access_02(self): + code = """ + def func[A](a = list[A]()): + ... + """ + + with self.assertRaisesRegex(NameError, "name 'A' is not defined"): + run_code(code) + + def test_function_access_03(self): + code = """ + def my_decorator(a): + ... + @my_decorator(A) + def func[A](): + ... + """ + + with self.assertRaisesRegex(NameError, "name 'A' is not defined"): + run_code(code) + + @unittest.expectedFailure # TODO: RUSTPYTHON; NameError: name 'x' is not defined + def test_method_access_01(self): + ns = run_code(""" + class ClassA: + x = int + def func[T](self, a: x, b: T): + ... + """ + ) + cls = ns["ClassA"] + self.assertIs(cls.func.__annotations__["a"], int) + T, = cls.func.__type_params__ + self.assertIs(cls.func.__annotations__["b"], T) + + def test_nested_access_01(self): + ns = run_code(""" + class ClassA[A]: + def funcB[B](self): + class ClassC[C]: + def funcD[D](self): + return lambda: (A, B, C, D) + return ClassC + """ + ) + cls = ns["ClassA"] + A, = cls.__type_params__ + B, = cls.funcB.__type_params__ + classC = cls().funcB() + C, = classC.__type_params__ + D, = classC.funcD.__type_params__ + self.assertEqual(classC().funcD()(), (A, B, C, D)) + + def test_out_of_scope_01(self): + code = """ + class ClassA[T]: ... + x = T + """ + + with self.assertRaisesRegex(NameError, "name 'T' is not defined"): + run_code(code) + + def test_out_of_scope_02(self): + code = """ + class ClassA[A]: + def funcB[B](self): ... + + x = B + """ + + with self.assertRaisesRegex(NameError, "name 'B' is not defined"): + run_code(code) + + @unittest.expectedFailure # TODO: RUSTPYTHON; NameError: name 'x' is not defined + def test_class_scope_interaction_01(self): + ns = run_code(""" + class C: + x = 1 + def method[T](self, arg: x): pass + """) + cls = ns["C"] + self.assertEqual(cls.method.__annotations__["arg"], 1) + + def test_class_scope_interaction_02(self): + ns = run_code(""" + class C: + class Base: pass + class Child[T](Base): pass + """) + cls = ns["C"] + self.assertEqual(cls.Child.__bases__, (cls.Base, Generic)) + T, = cls.Child.__type_params__ + self.assertEqual(types.get_original_bases(cls.Child), (cls.Base, Generic[T])) + + def test_class_deref(self): + ns = run_code(""" + class C[T]: + T = "class" + type Alias = T + """) + cls = ns["C"] + self.assertEqual(cls.Alias.__value__, "class") + + def test_shadowing_nonlocal(self): + ns = run_code(""" + def outer[T](): + T = "outer" + def inner(): + nonlocal T + T = "inner" + return T + return lambda: T, inner + """) + outer = ns["outer"] + T, = outer.__type_params__ + self.assertEqual(T.__name__, "T") + getter, inner = outer() + self.assertEqual(getter(), "outer") + self.assertEqual(inner(), "inner") + self.assertEqual(getter(), "inner") + + def test_reference_previous_typevar(self): + def func[S, T: Sequence[S]](): + pass + + S, T = func.__type_params__ + self.assertEqual(T.__bound__, Sequence[S]) + + def test_super(self): + class Base: + def meth(self): + return "base" + + class Child(Base): + # Having int in the annotation ensures the class gets cells for both + # __class__ and __classdict__ + def meth[T](self, arg: int) -> T: + return super().meth() + "child" + + c = Child() + self.assertEqual(c.meth(1), "basechild") + + def test_type_alias_containing_lambda(self): + type Alias[T] = lambda: T + T, = Alias.__type_params__ + self.assertIs(Alias.__value__(), T) + + def test_class_base_containing_lambda(self): + # Test that scopes nested inside hidden functions work correctly + outer_var = "outer" + class Base[T]: ... + class Child[T](Base[lambda: (int, outer_var, T)]): ... + base, _ = types.get_original_bases(Child) + func, = get_args(base) + T, = Child.__type_params__ + self.assertEqual(func(), (int, "outer", T)) + + def test_comprehension_01(self): + type Alias[T: ([T for T in (T, [1])[1]], T)] = [T for T in T.__name__] + self.assertEqual(Alias.__value__, ["T"]) + T, = Alias.__type_params__ + self.assertEqual(T.__constraints__, ([1], T)) + + def test_comprehension_02(self): + type Alias[T: [lambda: T for T in (T, [1])[1]]] = [lambda: T for T in T.__name__] + func, = Alias.__value__ + self.assertEqual(func(), "T") + T, = Alias.__type_params__ + func, = T.__bound__ + self.assertEqual(func(), 1) + + def test_comprehension_03(self): + def F[T: [lambda: T for T in (T, [1])[1]]](): return [lambda: T for T in T.__name__] + func, = F() + self.assertEqual(func(), "T") + T, = F.__type_params__ + func, = T.__bound__ + self.assertEqual(func(), 1) + + def test_gen_exp_in_nested_class(self): + code = """ + from test.test_type_params import make_base + + class C[T]: + T = "class" + class Inner(make_base(T for _ in (1,)), make_base(T)): + pass + """ + C = run_code(code)["C"] + T, = C.__type_params__ + base1, base2 = C.Inner.__bases__ + self.assertEqual(list(base1.__arg__), [T]) + self.assertEqual(base2.__arg__, "class") + + def test_gen_exp_in_nested_generic_class(self): + code = """ + from test.test_type_params import make_base + + class C[T]: + T = "class" + class Inner[U](make_base(T for _ in (1,)), make_base(T)): + pass + """ + ns = run_code(code) + inner = ns["C"].Inner + base1, base2, _ = inner.__bases__ + self.assertEqual(list(base1.__arg__), [ns["C"].__type_params__[0]]) + self.assertEqual(base2.__arg__, "class") + + def test_listcomp_in_nested_class(self): + code = """ + from test.test_type_params import make_base + + class C[T]: + T = "class" + class Inner(make_base([T for _ in (1,)]), make_base(T)): + pass + """ + C = run_code(code)["C"] + T, = C.__type_params__ + base1, base2 = C.Inner.__bases__ + self.assertEqual(base1.__arg__, [T]) + self.assertEqual(base2.__arg__, "class") + + def test_listcomp_in_nested_generic_class(self): + code = """ + from test.test_type_params import make_base + + class C[T]: + T = "class" + class Inner[U](make_base([T for _ in (1,)]), make_base(T)): + pass + """ + ns = run_code(code) + inner = ns["C"].Inner + base1, base2, _ = inner.__bases__ + self.assertEqual(base1.__arg__, [ns["C"].__type_params__[0]]) + self.assertEqual(base2.__arg__, "class") + + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: ~T != 'class' + def test_gen_exp_in_generic_method(self): + code = """ + class C[T]: + T = "class" + def meth[U](x: (T for _ in (1,)), y: T): + pass + """ + ns = run_code(code) + meth = ns["C"].meth + self.assertEqual(list(meth.__annotations__["x"]), [ns["C"].__type_params__[0]]) + self.assertEqual(meth.__annotations__["y"], "class") + + def test_nested_scope_in_generic_alias(self): + code = """ + T = "global" + class C: + T = "class" + {} + """ + cases = [ + "type Alias[T] = (T for _ in (1,))", + "type Alias = (T for _ in (1,))", + "type Alias[T] = [T for _ in (1,)]", + "type Alias = [T for _ in (1,)]", + ] + for case in cases: + with self.subTest(case=case): + ns = run_code(code.format(case)) + alias = ns["C"].Alias + value = list(alias.__value__)[0] + if alias.__type_params__: + self.assertIs(value, alias.__type_params__[0]) + else: + self.assertEqual(value, "global") + + def test_lambda_in_alias_in_class(self): + code = """ + T = "global" + class C: + T = "class" + type Alias = lambda: T + """ + C = run_code(code)["C"] + self.assertEqual(C.Alias.__value__(), "global") + + def test_lambda_in_alias_in_generic_class(self): + code = """ + class C[T]: + T = "class" + type Alias = lambda: T + """ + C = run_code(code)["C"] + self.assertIs(C.Alias.__value__(), C.__type_params__[0]) + + def test_lambda_in_generic_alias_in_class(self): + # A lambda nested in the alias cannot see the class scope, but can see + # a surrounding annotation scope. + code = """ + T = U = "global" + class C: + T = "class" + U = "class" + type Alias[T] = lambda: (T, U) + """ + C = run_code(code)["C"] + T, U = C.Alias.__value__() + self.assertIs(T, C.Alias.__type_params__[0]) + self.assertEqual(U, "global") + + def test_lambda_in_generic_alias_in_generic_class(self): + # A lambda nested in the alias cannot see the class scope, but can see + # a surrounding annotation scope. + code = """ + class C[T, U]: + T = "class" + U = "class" + type Alias[T] = lambda: (T, U) + """ + C = run_code(code)["C"] + T, U = C.Alias.__value__() + self.assertIs(T, C.Alias.__type_params__[0]) + self.assertIs(U, C.__type_params__[1]) + + def test_type_special_case(self): + # https://github.com/python/cpython/issues/119011 + self.assertEqual(type.__type_params__, ()) + self.assertEqual(object.__type_params__, ()) + + +def make_base(arg): + class Base: + __arg__ = arg + return Base + + +def global_generic_func[T](): + pass + +class GlobalGenericClass[T]: + pass + + +class TypeParamsLazyEvaluationTest(unittest.TestCase): + def test_qualname(self): + class Foo[T]: + pass + + def func[T](): + pass + + self.assertEqual(Foo.__qualname__, "TypeParamsLazyEvaluationTest.test_qualname..Foo") + self.assertEqual(func.__qualname__, "TypeParamsLazyEvaluationTest.test_qualname..func") + self.assertEqual(global_generic_func.__qualname__, "global_generic_func") + self.assertEqual(GlobalGenericClass.__qualname__, "GlobalGenericClass") + + def test_recursive_class(self): + class Foo[T: Foo, U: (Foo, Foo)]: + pass + + type_params = Foo.__type_params__ + self.assertEqual(len(type_params), 2) + self.assertEqual(type_params[0].__name__, "T") + self.assertIs(type_params[0].__bound__, Foo) + self.assertEqual(type_params[0].__constraints__, ()) + self.assertIs(type_params[0].__default__, NoDefault) + + self.assertEqual(type_params[1].__name__, "U") + self.assertIs(type_params[1].__bound__, None) + self.assertEqual(type_params[1].__constraints__, (Foo, Foo)) + self.assertIs(type_params[1].__default__, NoDefault) + + def test_evaluation_error(self): + class Foo[T: Undefined, U: (Undefined,)]: + pass + + type_params = Foo.__type_params__ + with self.assertRaises(NameError): + type_params[0].__bound__ + self.assertEqual(type_params[0].__constraints__, ()) + self.assertIs(type_params[1].__bound__, None) + self.assertIs(type_params[0].__default__, NoDefault) + self.assertIs(type_params[1].__default__, NoDefault) + with self.assertRaises(NameError): + type_params[1].__constraints__ + + Undefined = "defined" + self.assertEqual(type_params[0].__bound__, "defined") + self.assertEqual(type_params[0].__constraints__, ()) + + self.assertIs(type_params[1].__bound__, None) + self.assertEqual(type_params[1].__constraints__, ("defined",)) + + +class TypeParamsClassScopeTest(unittest.TestCase): + def test_alias(self): + class X: + T = int + type U = T + self.assertIs(X.U.__value__, int) + + ns = run_code(""" + glb = "global" + class X: + cls = "class" + type U = (glb, cls) + """) + cls = ns["X"] + self.assertEqual(cls.U.__value__, ("global", "class")) + + def test_bound(self): + class X: + T = int + def foo[U: T](self): ... + self.assertIs(X.foo.__type_params__[0].__bound__, int) + + ns = run_code(""" + glb = "global" + class X: + cls = "class" + def foo[T: glb, U: cls](self): ... + """) + cls = ns["X"] + T, U = cls.foo.__type_params__ + self.assertEqual(T.__bound__, "global") + self.assertEqual(U.__bound__, "class") + + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: is not + def test_modified_later(self): + class X: + T = int + def foo[U: T](self): ... + type Alias = T + X.T = float + self.assertIs(X.foo.__type_params__[0].__bound__, float) + self.assertIs(X.Alias.__value__, float) + + @unittest.expectedFailure # TODO: RUSTPYTHON; + global + def test_binding_uses_global(self): + ns = run_code(""" + x = "global" + def outer(): + x = "nonlocal" + class Cls: + type Alias = x + val = Alias.__value__ + def meth[T: x](self, arg: x): ... + bound = meth.__type_params__[0].__bound__ + annotation = meth.__annotations__["arg"] + x = "class" + return Cls + """) + cls = ns["outer"]() + self.assertEqual(cls.val, "global") + self.assertEqual(cls.bound, "global") + self.assertEqual(cls.annotation, "global") + + def test_no_binding_uses_nonlocal(self): + ns = run_code(""" + x = "global" + def outer(): + x = "nonlocal" + class Cls: + type Alias = x + val = Alias.__value__ + def meth[T: x](self, arg: x): ... + bound = meth.__type_params__[0].__bound__ + return Cls + """) + cls = ns["outer"]() + self.assertEqual(cls.val, "nonlocal") + self.assertEqual(cls.bound, "nonlocal") + self.assertEqual(cls.meth.__annotations__["arg"], "nonlocal") + + @unittest.expectedFailure # TODO: RUSTPYTHON; + global + def test_explicit_global(self): + ns = run_code(""" + x = "global" + def outer(): + x = "nonlocal" + class Cls: + global x + type Alias = x + Cls.x = "class" + return Cls + """) + cls = ns["outer"]() + self.assertEqual(cls.Alias.__value__, "global") + + def test_explicit_global_with_no_static_bound(self): + ns = run_code(""" + def outer(): + class Cls: + global x + type Alias = x + Cls.x = "class" + return Cls + """) + ns["x"] = "global" + cls = ns["outer"]() + self.assertEqual(cls.Alias.__value__, "global") + + @unittest.expectedFailure # TODO: RUSTPYTHON; + global from class + def test_explicit_global_with_assignment(self): + ns = run_code(""" + x = "global" + def outer(): + x = "nonlocal" + class Cls: + global x + type Alias = x + x = "global from class" + Cls.x = "class" + return Cls + """) + cls = ns["outer"]() + self.assertEqual(cls.Alias.__value__, "global from class") + + def test_explicit_nonlocal(self): + ns = run_code(""" + x = "global" + def outer(): + x = "nonlocal" + class Cls: + nonlocal x + type Alias = x + x = "class" + return Cls + """) + cls = ns["outer"]() + self.assertEqual(cls.Alias.__value__, "class") + + def test_nested_free(self): + ns = run_code(""" + def f(): + T = str + class C: + T = int + class D[U](T): + x = T + return C + """) + C = ns["f"]() + self.assertIn(int, C.D.__bases__) + self.assertIs(C.D.x, str) + + +class DynamicClassTest(unittest.TestCase): + def _set_type_params(self, ns, params): + ns['__type_params__'] = params + + def test_types_new_class_with_callback(self): + T = TypeVar('T', infer_variance=True) + Klass = types.new_class('Klass', (Generic[T],), {}, + lambda ns: self._set_type_params(ns, (T,))) + + self.assertEqual(Klass.__bases__, (Generic,)) + self.assertEqual(Klass.__orig_bases__, (Generic[T],)) + self.assertEqual(Klass.__type_params__, (T,)) + self.assertEqual(Klass.__parameters__, (T,)) + + def test_types_new_class_no_callback(self): + T = TypeVar('T', infer_variance=True) + Klass = types.new_class('Klass', (Generic[T],), {}) + + self.assertEqual(Klass.__bases__, (Generic,)) + self.assertEqual(Klass.__orig_bases__, (Generic[T],)) + self.assertEqual(Klass.__type_params__, ()) # must be explicitly set + self.assertEqual(Klass.__parameters__, (T,)) + + +class TypeParamsManglingTest(unittest.TestCase): + def test_mangling(self): + class Foo[__T]: + param = __T + def meth[__U](self, arg: __T, arg2: __U): + return (__T, __U) + type Alias[__V] = (__T, __V) + + T = Foo.__type_params__[0] + self.assertEqual(T.__name__, "__T") + U = Foo.meth.__type_params__[0] + self.assertEqual(U.__name__, "__U") + V = Foo.Alias.__type_params__[0] + self.assertEqual(V.__name__, "__V") + + anno = Foo.meth.__annotations__ + self.assertIs(anno["arg"], T) + self.assertIs(anno["arg2"], U) + self.assertEqual(Foo().meth(1, 2), (T, U)) + + self.assertEqual(Foo.Alias.__value__, (T, V)) + + def test_no_leaky_mangling_in_module(self): + ns = run_code(""" + __before = "before" + class X[T]: pass + __after = "after" + """) + self.assertEqual(ns["__before"], "before") + self.assertEqual(ns["__after"], "after") + + def test_no_leaky_mangling_in_function(self): + ns = run_code(""" + def f(): + class X[T]: pass + _X_foo = 2 + __foo = 1 + assert locals()['__foo'] == 1 + return __foo + """) + self.assertEqual(ns["f"](), 1) + + def test_no_leaky_mangling_in_class(self): + ns = run_code(""" + class Outer: + __before = "before" + class Inner[T]: + __x = "inner" + __after = "after" + """) + Outer = ns["Outer"] + self.assertEqual(Outer._Outer__before, "before") + self.assertEqual(Outer.Inner._Inner__x, "inner") + self.assertEqual(Outer._Outer__after, "after") + + @unittest.expectedFailure # TODO: RUSTPYTHON; NameError: name '_Derived__Base' is not defined + def test_no_mangling_in_bases(self): + ns = run_code(""" + class __Base: + def __init_subclass__(self, **kwargs): + self.kwargs = kwargs + + class Derived[T](__Base, __kwarg=1): + pass + """) + Derived = ns["Derived"] + self.assertEqual(Derived.__bases__, (ns["__Base"], Generic)) + self.assertEqual(Derived.kwargs, {"__kwarg": 1}) + + @unittest.expectedFailure # TODO: RUSTPYTHON; SyntaxError: the symbol '_Y__X' must be present in the symbol table + def test_no_mangling_in_nested_scopes(self): + ns = run_code(""" + from test.test_type_params import make_base + + class __X: + pass + + class Y[T: __X]( + make_base(lambda: __X), + # doubly nested scope + make_base(lambda: (lambda: __X)), + # list comprehension + make_base([__X for _ in (1,)]), + # genexp + make_base(__X for _ in (1,)), + ): + pass + """) + Y = ns["Y"] + T, = Y.__type_params__ + self.assertIs(T.__bound__, ns["__X"]) + base0 = Y.__bases__[0] + self.assertIs(base0.__arg__(), ns["__X"]) + base1 = Y.__bases__[1] + self.assertIs(base1.__arg__()(), ns["__X"]) + base2 = Y.__bases__[2] + self.assertEqual(base2.__arg__, [ns["__X"]]) + base3 = Y.__bases__[3] + self.assertEqual(list(base3.__arg__), [ns["__X"]]) + + @unittest.expectedFailure # TODO: RUSTPYTHON; SyntaxError: the symbol '_Foo__T' must be present in the symbol table + def test_type_params_are_mangled(self): + ns = run_code(""" + from test.test_type_params import make_base + + class Foo[__T, __U: __T](make_base(__T), make_base(lambda: __T)): + param = __T + """) + Foo = ns["Foo"] + T, U = Foo.__type_params__ + self.assertEqual(T.__name__, "__T") + self.assertEqual(U.__name__, "__U") + self.assertIs(U.__bound__, T) + self.assertIs(Foo.param, T) + + base1, base2, *_ = Foo.__bases__ + self.assertIs(base1.__arg__, T) + self.assertIs(base2.__arg__(), T) + + +class TypeParamsComplexCallsTest(unittest.TestCase): + def test_defaults(self): + # Generic functions with both defaults and kwdefaults trigger a specific code path + # in the compiler. + def func[T](a: T = "a", *, b: T = "b"): + return (a, b) + + T, = func.__type_params__ + self.assertIs(func.__annotations__["a"], T) + self.assertIs(func.__annotations__["b"], T) + self.assertEqual(func(), ("a", "b")) + self.assertEqual(func(1), (1, "b")) + self.assertEqual(func(b=2), ("a", 2)) + + def test_complex_base(self): + class Base: + def __init_subclass__(cls, **kwargs) -> None: + cls.kwargs = kwargs + + kwargs = {"c": 3} + # Base classes with **kwargs trigger a different code path in the compiler. + class C[T](Base, a=1, b=2, **kwargs): + pass + + T, = C.__type_params__ + self.assertEqual(T.__name__, "T") + self.assertEqual(C.kwargs, {"a": 1, "b": 2, "c": 3}) + self.assertEqual(C.__bases__, (Base, Generic)) + + bases = (Base,) + class C2[T](*bases, **kwargs): + pass + + T, = C2.__type_params__ + self.assertEqual(T.__name__, "T") + self.assertEqual(C2.kwargs, {"c": 3}) + self.assertEqual(C2.__bases__, (Base, Generic)) + + def test_starargs_base(self): + class C1[T](*()): pass + + T, = C1.__type_params__ + self.assertEqual(T.__name__, "T") + self.assertEqual(C1.__bases__, (Generic,)) + + class Base: pass + bases = [Base] + class C2[T](*bases): pass + + T, = C2.__type_params__ + self.assertEqual(T.__name__, "T") + self.assertEqual(C2.__bases__, (Base, Generic)) + + +class TypeParamsTraditionalTypeVarsTest(unittest.TestCase): + def test_traditional_01(self): + code = """ + from typing import Generic + class ClassA[T](Generic[T]): ... + """ + + with self.assertRaisesRegex(TypeError, r"Cannot inherit from Generic\[...\] multiple times."): + run_code(code) + + def test_traditional_02(self): + from typing import TypeVar + S = TypeVar("S") + with self.assertRaises(TypeError): + class ClassA[T](dict[T, S]): ... + + def test_traditional_03(self): + # This does not generate a runtime error, but it should be + # flagged as an error by type checkers. + from typing import TypeVar + S = TypeVar("S") + def func[T](a: T, b: S) -> T | S: + return a + + +class TypeParamsTypeVarTest(unittest.TestCase): + def test_typevar_01(self): + def func1[A: str, B: str | int, C: (int, str)](): + return (A, B, C) + + a, b, c = func1() + + self.assertIsInstance(a, TypeVar) + self.assertEqual(a.__bound__, str) + self.assertTrue(a.__infer_variance__) + self.assertFalse(a.__covariant__) + self.assertFalse(a.__contravariant__) + + self.assertIsInstance(b, TypeVar) + self.assertEqual(b.__bound__, str | int) + self.assertTrue(b.__infer_variance__) + self.assertFalse(b.__covariant__) + self.assertFalse(b.__contravariant__) + + self.assertIsInstance(c, TypeVar) + self.assertEqual(c.__bound__, None) + self.assertEqual(c.__constraints__, (int, str)) + self.assertTrue(c.__infer_variance__) + self.assertFalse(c.__covariant__) + self.assertFalse(c.__contravariant__) + + def test_typevar_generator(self): + def get_generator[A](): + def generator1[C](): + yield C + + def generator2[B](): + yield A + yield B + yield from generator1() + return generator2 + + gen = get_generator() + + a, b, c = [x for x in gen()] + + self.assertIsInstance(a, TypeVar) + self.assertEqual(a.__name__, "A") + self.assertIsInstance(b, TypeVar) + self.assertEqual(b.__name__, "B") + self.assertIsInstance(c, TypeVar) + self.assertEqual(c.__name__, "C") + + def test_typevar_coroutine(self): + def get_coroutine[A](): + async def coroutine[B](): + return (A, B) + return coroutine + + co = get_coroutine() + + a, b = run_no_yield_async_fn(co) + + self.assertIsInstance(a, TypeVar) + self.assertEqual(a.__name__, "A") + self.assertIsInstance(b, TypeVar) + self.assertEqual(b.__name__, "B") + + +class TypeParamsTypeVarTupleTest(unittest.TestCase): + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: "cannot use bound with TypeVarTuple" does not match "invalid syntax (, line 1)" + def test_typevartuple_01(self): + code = """def func1[*A: str](): pass""" + check_syntax_error(self, code, "cannot use bound with TypeVarTuple") + code = """def func1[*A: (int, str)](): pass""" + check_syntax_error(self, code, "cannot use constraints with TypeVarTuple") + code = """class X[*A: str]: pass""" + check_syntax_error(self, code, "cannot use bound with TypeVarTuple") + code = """class X[*A: (int, str)]: pass""" + check_syntax_error(self, code, "cannot use constraints with TypeVarTuple") + code = """type X[*A: str] = int""" + check_syntax_error(self, code, "cannot use bound with TypeVarTuple") + code = """type X[*A: (int, str)] = int""" + check_syntax_error(self, code, "cannot use constraints with TypeVarTuple") + + def test_typevartuple_02(self): + def func1[*A](): + return A + + a = func1() + self.assertIsInstance(a, TypeVarTuple) + + +class TypeParamsTypeVarParamSpecTest(unittest.TestCase): + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: "cannot use bound with ParamSpec" does not match "invalid syntax (, line 1)" + def test_paramspec_01(self): + code = """def func1[**A: str](): pass""" + check_syntax_error(self, code, "cannot use bound with ParamSpec") + code = """def func1[**A: (int, str)](): pass""" + check_syntax_error(self, code, "cannot use constraints with ParamSpec") + code = """class X[**A: str]: pass""" + check_syntax_error(self, code, "cannot use bound with ParamSpec") + code = """class X[**A: (int, str)]: pass""" + check_syntax_error(self, code, "cannot use constraints with ParamSpec") + code = """type X[**A: str] = int""" + check_syntax_error(self, code, "cannot use bound with ParamSpec") + code = """type X[**A: (int, str)] = int""" + check_syntax_error(self, code, "cannot use constraints with ParamSpec") + + def test_paramspec_02(self): + def func1[**A](): + return A + + a = func1() + self.assertIsInstance(a, ParamSpec) + self.assertTrue(a.__infer_variance__) + self.assertFalse(a.__covariant__) + self.assertFalse(a.__contravariant__) + + +class TypeParamsTypeParamsDunder(unittest.TestCase): + def test_typeparams_dunder_class_01(self): + class Outer[A, B]: + class Inner[C, D]: + @staticmethod + def get_typeparams(): + return A, B, C, D + + a, b, c, d = Outer.Inner.get_typeparams() + self.assertEqual(Outer.__type_params__, (a, b)) + self.assertEqual(Outer.Inner.__type_params__, (c, d)) + + self.assertEqual(Outer.__parameters__, (a, b)) + self.assertEqual(Outer.Inner.__parameters__, (c, d)) + + def test_typeparams_dunder_class_02(self): + class ClassA: + pass + + self.assertEqual(ClassA.__type_params__, ()) + + def test_typeparams_dunder_class_03(self): + code = """ + class ClassA[A](): + pass + ClassA.__type_params__ = () + params = ClassA.__type_params__ + """ + + ns = run_code(code) + self.assertEqual(ns["params"], ()) + + def test_typeparams_dunder_function_01(self): + def outer[A, B](): + def inner[C, D](): + return A, B, C, D + + return inner + + inner = outer() + a, b, c, d = inner() + self.assertEqual(outer.__type_params__, (a, b)) + self.assertEqual(inner.__type_params__, (c, d)) + + def test_typeparams_dunder_function_02(self): + def func1(): + pass + + self.assertEqual(func1.__type_params__, ()) + + def test_typeparams_dunder_function_03(self): + code = """ + def func[A](): + pass + func.__type_params__ = () + """ + + ns = run_code(code) + self.assertEqual(ns["func"].__type_params__, ()) + + + +# All these type aliases are used for pickling tests: +T = TypeVar('T') +def func1[X](x: X) -> X: ... +def func2[X, Y](x: X | Y) -> X | Y: ... +def func3[X, *Y, **Z](x: X, y: tuple[*Y], z: Z) -> X: ... +def func4[X: int, Y: (bytes, str)](x: X, y: Y) -> X | Y: ... + +class Class1[X]: ... +class Class2[X, Y]: ... +class Class3[X, *Y, **Z]: ... +class Class4[X: int, Y: (bytes, str)]: ... + + +class TypeParamsPickleTest(unittest.TestCase): + def test_pickling_functions(self): + things_to_test = [ + func1, + func2, + func3, + func4, + ] + for thing in things_to_test: + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(thing=thing, proto=proto): + pickled = pickle.dumps(thing, protocol=proto) + self.assertEqual(pickle.loads(pickled), thing) + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_pickling_classes(self): + things_to_test = [ + Class1, + Class1[int], + Class1[T], + + Class2, + Class2[int, T], + Class2[T, int], + Class2[int, str], + + Class3, + Class3[int, T, str, bytes, [float, object, T]], + + Class4, + Class4[int, bytes], + Class4[T, bytes], + Class4[int, T], + Class4[T, T], + ] + for thing in things_to_test: + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(thing=thing, proto=proto): + pickled = pickle.dumps(thing, protocol=proto) + self.assertEqual(pickle.loads(pickled), thing) + + for klass in things_to_test: + real_class = getattr(klass, '__origin__', klass) + thing = klass() + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(thing=thing, proto=proto): + pickled = pickle.dumps(thing, protocol=proto) + # These instances are not equal, + # but class check is good enough: + self.assertIsInstance(pickle.loads(pickled), real_class) + + +class TypeParamsWeakRefTest(unittest.TestCase): + def test_weakrefs(self): + T = TypeVar('T') + P = ParamSpec('P') + class OldStyle(Generic[T]): + pass + + class NewStyle[T]: + pass + + cases = [ + T, + TypeVar('T', bound=int), + P, + P.args, + P.kwargs, + TypeVarTuple('Ts'), + OldStyle, + OldStyle[int], + OldStyle(), + NewStyle, + NewStyle[int], + NewStyle(), + Generic[T], + ] + for case in cases: + with self.subTest(case=case): + weakref.ref(case) + + +class TypeParamsRuntimeTest(unittest.TestCase): + def test_name_error(self): + # gh-109118: This crashed the interpreter due to a refcounting bug + code = """ + class name_2[name_5]: + class name_4[name_5](name_0): + pass + """ + with self.assertRaises(NameError): + run_code(code) + + # Crashed with a slightly different stack trace + code = """ + class name_2[name_5]: + class name_4[name_5: name_5](name_0): + pass + """ + with self.assertRaises(NameError): + run_code(code) + + def test_broken_class_namespace(self): + code = """ + class WeirdMapping(dict): + def __missing__(self, key): + if key == "T": + raise RuntimeError + raise KeyError(key) + + class Meta(type): + def __prepare__(name, bases): + return WeirdMapping() + + class MyClass[V](metaclass=Meta): + class Inner[U](T): + pass + """ + with self.assertRaises(RuntimeError): + run_code(code) + + +class DefaultsTest(unittest.TestCase): + def test_defaults_on_func(self): + ns = run_code(""" + def func[T=int, **U=float, *V=None](): + pass + """) + + T, U, V = ns["func"].__type_params__ + self.assertIs(T.__default__, int) + self.assertIs(U.__default__, float) + self.assertIs(V.__default__, None) + + def test_defaults_on_class(self): + ns = run_code(""" + class C[T=int, **U=float, *V=None]: + pass + """) + + T, U, V = ns["C"].__type_params__ + self.assertIs(T.__default__, int) + self.assertIs(U.__default__, float) + self.assertIs(V.__default__, None) + + def test_defaults_on_type_alias(self): + ns = run_code(""" + type Alias[T = int, **U = float, *V = None] = int + """) + + T, U, V = ns["Alias"].__type_params__ + self.assertIs(T.__default__, int) + self.assertIs(U.__default__, float) + self.assertIs(V.__default__, None) + + def test_starred_invalid(self): + check_syntax_error(self, "type Alias[T = *int] = int") + check_syntax_error(self, "type Alias[**P = *int] = int") + + def test_starred_typevartuple(self): + ns = run_code(""" + default = tuple[int, str] + type Alias[*Ts = *default] = Ts + """) + + Ts, = ns["Alias"].__type_params__ + self.assertEqual(Ts.__default__, next(iter(ns["default"]))) + + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: SyntaxError not raised + def test_nondefault_after_default(self): + check_syntax_error(self, "def func[T=int, U](): pass", "non-default type parameter 'U' follows default type parameter") + check_syntax_error(self, "class C[T=int, U]: pass", "non-default type parameter 'U' follows default type parameter") + check_syntax_error(self, "type A[T=int, U] = int", "non-default type parameter 'U' follows default type parameter") + + @unittest.expectedFailure # TODO: RUSTPYTHON; + defined + def test_lazy_evaluation(self): + ns = run_code(""" + type Alias[T = Undefined, *U = Undefined, **V = Undefined] = int + """) + + T, U, V = ns["Alias"].__type_params__ + + with self.assertRaises(NameError): + T.__default__ + with self.assertRaises(NameError): + U.__default__ + with self.assertRaises(NameError): + V.__default__ + + ns["Undefined"] = "defined" + self.assertEqual(T.__default__, "defined") + self.assertEqual(U.__default__, "defined") + self.assertEqual(V.__default__, "defined") + + # Now it is cached + ns["Undefined"] = "redefined" + self.assertEqual(T.__default__, "defined") + self.assertEqual(U.__default__, "defined") + self.assertEqual(V.__default__, "defined") + + def test_symtable_key_regression_default(self): + # Test against the bugs that would happen if we used .default_ + # as the key in the symtable. + ns = run_code(""" + type X[T = [T for T in [T]]] = T + """) + + T, = ns["X"].__type_params__ + self.assertEqual(T.__default__, [T]) + + def test_symtable_key_regression_name(self): + # Test against the bugs that would happen if we used .name + # as the key in the symtable. + ns = run_code(""" + type X1[T = A] = T + type X2[T = B] = T + A = "A" + B = "B" + """) + + self.assertEqual(ns["X1"].__type_params__[0].__default__, "A") + self.assertEqual(ns["X2"].__type_params__[0].__default__, "B") + + +class TestEvaluateFunctions(unittest.TestCase): + @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'TypeAliasType' object has no attribute 'evaluate_value' + def test_general(self): + type Alias = int + Alias2 = TypeAliasType("Alias2", int) + def f[T: int = int, **P = int, *Ts = int](): pass + T, P, Ts = f.__type_params__ + T2 = TypeVar("T2", bound=int, default=int) + P2 = ParamSpec("P2", default=int) + Ts2 = TypeVarTuple("Ts2", default=int) + cases = [ + Alias.evaluate_value, + Alias2.evaluate_value, + T.evaluate_bound, + T.evaluate_default, + P.evaluate_default, + Ts.evaluate_default, + T2.evaluate_bound, + T2.evaluate_default, + P2.evaluate_default, + Ts2.evaluate_default, + ] + for case in cases: + with self.subTest(case=case): + self.assertIs(case(1), int) + self.assertIs(annotationlib.call_evaluate_function(case, annotationlib.Format.VALUE), int) + self.assertIs(annotationlib.call_evaluate_function(case, annotationlib.Format.FORWARDREF), int) + self.assertEqual(annotationlib.call_evaluate_function(case, annotationlib.Format.STRING), 'int') + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_constraints(self): + def f[T: (int, str)](): pass + T, = f.__type_params__ + T2 = TypeVar("T2", int, str) + for case in [T, T2]: + with self.subTest(case=case): + self.assertEqual(case.evaluate_constraints(1), (int, str)) + self.assertEqual(annotationlib.call_evaluate_function(case.evaluate_constraints, annotationlib.Format.VALUE), (int, str)) + self.assertEqual(annotationlib.call_evaluate_function(case.evaluate_constraints, annotationlib.Format.FORWARDREF), (int, str)) + self.assertEqual(annotationlib.call_evaluate_function(case.evaluate_constraints, annotationlib.Format.STRING), '(int, str)') + + @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'TypeVar' object has no attribute 'evaluate_bound' + def test_const_evaluator(self): + T = TypeVar("T", bound=int) + self.assertEqual(repr(T.evaluate_bound), ">") + + ConstEvaluator = type(T.evaluate_bound) + + with self.assertRaisesRegex(TypeError, r"cannot create '_typing\._ConstEvaluator' instances"): + ConstEvaluator() # This used to segfault. + with self.assertRaisesRegex(TypeError, r"cannot set 'attribute' attribute of immutable type '_typing\._ConstEvaluator'"): + ConstEvaluator.attribute = 1 diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 3372312c6d8..a37d64946f8 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -986,7 +986,6 @@ class C(Generic[T]): pass ) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_two_parameters(self): T1 = TypeVar('T1') T2 = TypeVar('T2') @@ -4210,7 +4209,7 @@ class P(Protocol): Alias2 = typing.Union[P, typing.Iterable] self.assertEqual(Alias, Alias2) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Generic() takes no arguments def test_protocols_pickleable(self): global P, CP # pickle wants to reference the class by name T = TypeVar('T') @@ -5288,7 +5287,7 @@ def test_all_repr_eq_any(self): self.assertNotEqual(repr(base), '') self.assertEqual(base, base) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Generic() takes no arguments def test_pickle(self): global C # pickle wants to reference the class by name T = TypeVar('T') @@ -5975,7 +5974,6 @@ def test_final_unmodified(self): def func(x): ... self.assertIs(func, final(func)) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_dunder_final(self): @final def func(): ... @@ -7417,7 +7415,6 @@ def test_partial_evaluation(self): list[EqualToForwardRef('A')], ) - @unittest.expectedFailure # TODO: RUSTPYTHON; ImportError: cannot import name 'fwdref_module' def test_with_module(self): from test.typinganndata import fwdref_module @@ -8440,7 +8437,7 @@ class Bar(NamedTuple): self.assertIsInstance(bar.attr, Vanilla) self.assertEqual(bar.attr.name, "attr") - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON; + Error calling __set_name__ on 'Annoying' instance attr in 'NamedTupleClass' def test_setname_raises_the_same_as_on_other_classes(self): class CustomException(BaseException): pass @@ -8907,14 +8904,12 @@ class NewGeneric[T](TypedDict): # The TypedDict constructor is not itself a TypedDict self.assertIs(is_typeddict(TypedDict), False) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_get_type_hints(self): self.assertEqual( get_type_hints(Bar), {'a': typing.Optional[int], 'b': int} ) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_get_type_hints_generic(self): self.assertEqual( get_type_hints(BarGeneric), @@ -9070,7 +9065,6 @@ class WithImplicitAny(B): with self.assertRaises(TypeError): WithImplicitAny[str] - @unittest.expectedFailure # TODO: RUSTPYTHON def test_non_generic_subscript(self): # For backward compatibility, subscription works # on arbitrary TypedDict types. @@ -9198,7 +9192,6 @@ class Child(Base): self.assertEqual(Child.__readonly_keys__, frozenset()) self.assertEqual(Child.__mutable_keys__, frozenset({'a'})) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_combine_qualifiers(self): class AllTheThings(TypedDict): a: Annotated[Required[ReadOnly[int]], "why not"] @@ -9440,7 +9433,7 @@ def test_repr(self): self.assertEqual(repr(Match[str]), 'typing.Match[str]') self.assertEqual(repr(Match[bytes]), 'typing.Match[bytes]') - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: "type 're\.Match' is not an acceptable base type" does not match "type '_sre.Match' is not an acceptable base type" def test_cannot_subclass(self): with self.assertRaisesRegex( TypeError, @@ -10852,7 +10845,6 @@ def test_no_call(self): with self.assertRaises(TypeError): NoDefault() - @unittest.expectedFailure # TODO: RUSTPYTHON def test_no_attributes(self): with self.assertRaises(AttributeError): NoDefault.foo = 3 diff --git a/Lib/test/typinganndata/fwdref_module.py b/Lib/test/typinganndata/fwdref_module.py new file mode 100644 index 00000000000..7347a7a4245 --- /dev/null +++ b/Lib/test/typinganndata/fwdref_module.py @@ -0,0 +1,6 @@ +from typing import ForwardRef + +MyList = list[int] +MyDict = dict[str, 'MyList'] + +fw = ForwardRef('MyDict', module=__name__) diff --git a/Lib/test/typinganndata/partialexecution/__init__.py b/Lib/test/typinganndata/partialexecution/__init__.py new file mode 100644 index 00000000000..c39074ea84b --- /dev/null +++ b/Lib/test/typinganndata/partialexecution/__init__.py @@ -0,0 +1 @@ +from . import a diff --git a/Lib/test/typinganndata/partialexecution/a.py b/Lib/test/typinganndata/partialexecution/a.py new file mode 100644 index 00000000000..ed0b8dcbd55 --- /dev/null +++ b/Lib/test/typinganndata/partialexecution/a.py @@ -0,0 +1,5 @@ +v1: int + +from . import b + +v2: int diff --git a/Lib/test/typinganndata/partialexecution/b.py b/Lib/test/typinganndata/partialexecution/b.py new file mode 100644 index 00000000000..36b8d2e52a3 --- /dev/null +++ b/Lib/test/typinganndata/partialexecution/b.py @@ -0,0 +1,3 @@ +from . import a + +annos = a.__annotations__ diff --git a/Lib/typing.py b/Lib/typing.py index 92b78defd11..380211183a4 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -1545,9 +1545,9 @@ def __init__(self, origin, nparams, *, inst=True, name=None, defaults=()): self._nparams = nparams self._defaults = defaults if origin.__module__ == 'builtins': - self.__doc__ = f'A generic version of {origin.__qualname__}.' + self.__doc__ = f'Deprecated alias to {origin.__qualname__}.' else: - self.__doc__ = f'A generic version of {origin.__module__}.{origin.__qualname__}.' + self.__doc__ = f'Deprecated alias to {origin.__module__}.{origin.__qualname__}.' @_tp_cache def __getitem__(self, params): diff --git a/crates/codegen/src/compile.rs b/crates/codegen/src/compile.rs index 909dc3fd4ed..826848ff271 100644 --- a/crates/codegen/src/compile.rs +++ b/crates/codegen/src/compile.rs @@ -1071,7 +1071,7 @@ impl Compiler { ), CompilerScope::Annotation => ( bytecode::CodeFlags::NEWLOCALS | bytecode::CodeFlags::OPTIMIZED, - 0, + 1, // format is positional-only 1, // annotation scope takes one argument (format) 0, ), @@ -1232,17 +1232,15 @@ impl Compiler { /// Enter annotation scope using the symbol table's annotation_block /// Returns false if no annotation_block exists - fn enter_annotation_scope(&mut self, func_name: &str) -> CompileResult { + fn enter_annotation_scope(&mut self, _func_name: &str) -> CompileResult { if !self.push_annotation_symbol_table() { return Ok(false); } let key = self.symbol_table_stack.len() - 1; let lineno = self.get_source_line_number().get(); - let annotate_name = format!(""); - self.enter_scope( - &annotate_name, + "__annotate__", CompilerScope::Annotation, key, lineno.to_u32(), @@ -1886,15 +1884,17 @@ impl Compiler { // Special handling for class scope implicit cell variables // These are treated as Cell even if not explicitly marked in symbol table - // Only for LOAD operations - explicit stores like `__class__ = property(...)` - // should use STORE_NAME to store in class namespace dict + // __class__ and __classdict__: only LOAD uses Cell (stores go to class namespace) + // __conditional_annotations__: both LOAD and STORE use Cell (it's a mutable set + // that the annotation scope accesses through the closure) let symbol_scope = { let current_table = self.current_symbol_table(); if current_table.typ == CompilerScope::Class - && usage == NameUsage::Load - && (name == "__class__" - || name == "__classdict__" - || name == "__conditional_annotations__") + && ((usage == NameUsage::Load + && (name == "__class__" + || name == "__classdict__" + || name == "__conditional_annotations__")) + || (name == "__conditional_annotations__" && usage == NameUsage::Store)) { Some(SymbolScope::Cell) } else { @@ -2437,27 +2437,79 @@ impl Compiler { }); if let Some(type_params) = type_params { - // For TypeAlias, we need to use push_symbol_table to properly handle the TypeAlias scope + // Outer scope for TypeParams self.push_symbol_table()?; + let key = self.symbol_table_stack.len() - 1; + let lineno = self.get_source_line_number().get().to_u32(); + let scope_name = format!(""); + self.enter_scope(&scope_name, CompilerScope::TypeParams, key, lineno)?; + + // TypeParams scope is function-like + let prev_ctx = self.ctx; + self.ctx = CompileContext { + loop_data: None, + in_class: prev_ctx.in_class, + func: FunctionContext::Function, + in_async_scope: false, + }; - // Compile type params and push to stack + // Compile type params inside the scope self.compile_type_params(type_params)?; - // Stack now has [name, type_params_tuple] + // Stack: [type_params_tuple] - // Compile value expression (can now see T1, T2) + // Inner closure for lazy value evaluation + self.push_symbol_table()?; + let inner_key = self.symbol_table_stack.len() - 1; + self.enter_scope("TypeAlias", CompilerScope::TypeParams, inner_key, lineno)?; self.compile_expression(value)?; - // Stack: [name, type_params_tuple, value] + emit!(self, Instruction::ReturnValue); + let value_code = self.exit_scope(); + self.make_closure(value_code, bytecode::MakeFunctionFlags::empty())?; + // Stack: [type_params_tuple, value_closure] + + // Swap so unpack_sequence reverse gives correct order + emit!(self, Instruction::Swap { index: 2_u32 }); + // Stack: [value_closure, type_params_tuple] + + // Build tuple and return from TypeParams scope + emit!(self, Instruction::BuildTuple { size: 2 }); + emit!(self, Instruction::ReturnValue); + + let code = self.exit_scope(); + self.ctx = prev_ctx; + self.make_closure(code, bytecode::MakeFunctionFlags::empty())?; + emit!(self, Instruction::PushNull); + emit!(self, Instruction::Call { nargs: 0 }); - // Pop the TypeAlias scope - self.pop_symbol_table(); + // Unpack: (value_closure, type_params_tuple) + // UnpackSequence reverses → stack: [name, type_params_tuple, value_closure] + emit!(self, Instruction::UnpackSequence { size: 2 }); } else { // Push None for type_params self.emit_load_const(ConstantData::None); // Stack: [name, None] - // Compile value expression + // Create a closure for lazy evaluation of the value + self.push_symbol_table()?; + let key = self.symbol_table_stack.len() - 1; + let lineno = self.get_source_line_number().get().to_u32(); + self.enter_scope("TypeAlias", CompilerScope::TypeParams, key, lineno)?; + + let prev_ctx = self.ctx; + self.ctx = CompileContext { + loop_data: None, + in_class: prev_ctx.in_class, + func: FunctionContext::Function, + in_async_scope: false, + }; + self.compile_expression(value)?; - // Stack: [name, None, value] + emit!(self, Instruction::ReturnValue); + + let code = self.exit_scope(); + self.ctx = prev_ctx; + self.make_closure(code, bytecode::MakeFunctionFlags::empty())?; + // Stack: [name, None, closure] } // Build tuple of 3 elements and call intrinsic @@ -2583,6 +2635,15 @@ impl Compiler { // Enter scope with the type parameter name self.enter_scope(name, CompilerScope::TypeParams, key, lineno)?; + // TypeParams scope is function-like + let prev_ctx = self.ctx; + self.ctx = CompileContext { + loop_data: None, + in_class: prev_ctx.in_class, + func: FunctionContext::Function, + in_async_scope: false, + }; + // Compile the expression if allow_starred && matches!(expr, ast::Expr::Starred(_)) { if let ast::Expr::Starred(starred) = expr { @@ -2598,14 +2659,10 @@ impl Compiler { // Exit scope and create closure let code = self.exit_scope(); - // Note: exit_scope already calls pop_symbol_table, so we don't need to call it again + self.ctx = prev_ctx; - // Create type params function with closure + // Create closure for lazy evaluation self.make_closure(code, bytecode::MakeFunctionFlags::empty())?; - emit!(self, Instruction::PushNull); - - // Call the function immediately - emit!(self, Instruction::Call { nargs: 0 }); Ok(()) } @@ -3844,16 +3901,8 @@ impl Compiler { // Check if we have conditional annotations let has_conditional = self.current_symbol_table().has_conditional_annotations; - // Get parent scope type and name BEFORE pushing annotation symbol table + // Get parent scope type BEFORE pushing annotation symbol table let parent_scope_type = self.current_symbol_table().typ; - let parent_name = self - .symbol_table_stack - .last() - .map(|t| t.name.as_str()) - .unwrap_or("module") - .to_owned(); - let scope_name = format!(""); - // Try to push annotation symbol table from current scope if !self.push_current_annotation_symbol_table() { return Ok(false); @@ -3862,7 +3911,12 @@ impl Compiler { // Enter annotation scope for code generation let key = self.symbol_table_stack.len() - 1; let lineno = self.get_source_line_number().get(); - self.enter_scope(&scope_name, CompilerScope::Annotation, key, lineno.to_u32())?; + self.enter_scope( + "__annotate__", + CompilerScope::Annotation, + key, + lineno.to_u32(), + )?; // Add 'format' parameter to varnames self.current_code_info() @@ -3991,6 +4045,9 @@ impl Compiler { let is_generic = type_params.is_some(); let mut num_typeparam_args = 0; + // Save context before entering TypeParams scope + let saved_ctx = self.ctx; + if is_generic { // Count args to pass to type params scope if funcflags.contains(bytecode::MakeFunctionFlags::DEFAULTS) { @@ -4010,6 +4067,14 @@ impl Compiler { type_params_name, )?; + // TypeParams scope is function-like + self.ctx = CompileContext { + loop_data: None, + in_class: saved_ctx.in_class, + func: FunctionContext::Function, + in_async_scope: false, + }; + // Add parameter names to varnames for the type params scope // These will be passed as arguments when the closure is called let current_info = self.current_code_info(); @@ -4068,6 +4133,7 @@ impl Compiler { // Exit type params scope and create closure let type_params_code = self.exit_scope(); + self.ctx = saved_ctx; // Make closure for type params code self.make_closure(type_params_code, bytecode::MakeFunctionFlags::empty())?; @@ -4316,12 +4382,26 @@ impl Compiler { Self::find_ann(body) || Self::find_ann(orelse) } ast::Stmt::With(ast::StmtWith { body, .. }) => Self::find_ann(body), + ast::Stmt::Match(ast::StmtMatch { cases, .. }) => { + cases.iter().any(|case| Self::find_ann(&case.body)) + } ast::Stmt::Try(ast::StmtTry { body, + handlers, orelse, finalbody, .. - }) => Self::find_ann(body) || Self::find_ann(orelse) || Self::find_ann(finalbody), + }) => { + Self::find_ann(body) + || handlers.iter().any(|h| { + let ast::ExceptHandler::ExceptHandler( + ast::ExceptHandlerExceptHandler { body, .. }, + ) = h; + Self::find_ann(body) + }) + || Self::find_ann(orelse) + || Self::find_ann(finalbody) + } _ => false, }; if res { @@ -4458,6 +4538,9 @@ impl Compiler { let is_generic = type_params.is_some(); let firstlineno = self.get_source_line_number().get().to_u32(); + // Save context before entering any scopes + let saved_ctx = self.ctx; + // Step 1: If generic, enter type params scope and compile type params if is_generic { let type_params_name = format!(""); @@ -4472,6 +4555,14 @@ impl Compiler { // Set private name for name mangling self.code_stack.last_mut().unwrap().private = Some(name.to_owned()); + // TypeParams scope is function-like + self.ctx = CompileContext { + loop_data: None, + in_class: saved_ctx.in_class, + func: FunctionContext::Function, + in_async_scope: false, + }; + // Compile type parameters and store as .type_params self.compile_type_params(type_params.unwrap())?; let dot_type_params = self.name(".type_params"); @@ -4622,6 +4713,7 @@ impl Compiler { // Exit type params scope and wrap in function let type_params_code = self.exit_scope(); + self.ctx = saved_ctx; // Execute the type params function self.make_closure(type_params_code, bytecode::MakeFunctionFlags::empty())?; @@ -6180,8 +6272,7 @@ impl Compiler { // Only add to __conditional_annotations__ set if actually conditional if is_conditional { - let cond_annotations_name = self.name("__conditional_annotations__"); - emit!(self, Instruction::LoadName(cond_annotations_name)); + self.load_name("__conditional_annotations__")?; self.emit_load_const(ConstantData::Integer { value: annotation_index.into(), }); diff --git a/crates/codegen/src/symboltable.rs b/crates/codegen/src/symboltable.rs index 4862c25b8b6..f324f48d507 100644 --- a/crates/codegen/src/symboltable.rs +++ b/crates/codegen/src/symboltable.rs @@ -396,12 +396,17 @@ impl SymbolTableAnalyzer { // we need to pass class symbols to the annotation scope let is_class = symbol_table.typ == CompilerScope::Class; - // Clone class symbols if needed for annotation scope (to avoid borrow conflict) - let class_symbols_for_ann = if is_class - && annotation_block - .as_ref() - .is_some_and(|b| b.can_see_class_scope) - { + // Clone class symbols if needed for child scopes with can_see_class_scope + let needs_class_symbols = (is_class + && (sub_tables.iter().any(|st| st.can_see_class_scope) + || annotation_block + .as_ref() + .is_some_and(|b| b.can_see_class_scope))) + || (!is_class + && class_entry.is_some() + && sub_tables.iter().any(|st| st.can_see_class_scope)); + + let class_symbols_clone = if is_class && needs_class_symbols { Some(symbols.clone()) } else { None @@ -412,15 +417,32 @@ impl SymbolTableAnalyzer { let inner_scope = unsafe { &mut *(list as *mut _ as *mut Self) }; // Analyze sub scopes and collect their free variables for sub_table in sub_tables.iter_mut() { - // Sub-scopes (functions, nested classes) don't inherit class_entry - let child_free = inner_scope.analyze_symbol_table(sub_table, None)?; + // Pass class_entry to sub-scopes that can see the class scope + let child_class_entry = if sub_table.can_see_class_scope { + if is_class { + class_symbols_clone.as_ref() + } else { + class_entry + } + } else { + None + }; + let child_free = inner_scope.analyze_symbol_table(sub_table, child_class_entry)?; // Propagate child's free variables to this scope newfree.extend(child_free); } // PEP 649: Analyze annotation block if present if let Some(annotation_table) = annotation_block { // Pass class symbols to annotation scope if can_see_class_scope - let ann_class_entry = class_symbols_for_ann.as_ref().or(class_entry); + let ann_class_entry = if annotation_table.can_see_class_scope { + if is_class { + class_symbols_clone.as_ref() + } else { + class_entry + } + } else { + None + }; let child_free = inner_scope.analyze_symbol_table(annotation_table, ann_class_entry)?; // Propagate annotation's free variables to this scope @@ -535,24 +557,21 @@ impl SymbolTableAnalyzer { // all is well } SymbolScope::Unknown => { - // PEP 649: Check class_entry first (like analyze_name) - // If name is bound in enclosing class, mark as GlobalImplicit - if let Some(class_symbols) = class_entry - && let Some(class_sym) = class_symbols.get(&symbol.name) - { - // DEF_BOUND && !DEF_NONLOCAL -> GLOBAL_IMPLICIT - if class_sym.is_bound() && class_sym.scope != SymbolScope::Free { - symbol.scope = SymbolScope::GlobalImplicit; - return Ok(()); - } - } - // Try hard to figure out what the scope of this symbol is. let scope = if symbol.is_bound() { self.found_in_inner_scope(sub_tables, &symbol.name, st_typ) .unwrap_or(SymbolScope::Local) } else if let Some(scope) = self.found_in_outer_scope(&symbol.name, st_typ) { + // If found in enclosing scope (function/TypeParams), use that scope + } else if let Some(class_symbols) = class_entry + && let Some(class_sym) = class_symbols.get(&symbol.name) + && class_sym.is_bound() + && class_sym.scope != SymbolScope::Free + { + // If name is bound in enclosing class, use GlobalImplicit + // so it can be accessed via __classdict__ + SymbolScope::GlobalImplicit } else if self.tables.is_empty() { // Don't make assumptions when we don't know. SymbolScope::Unknown @@ -1054,7 +1073,13 @@ impl SymbolTableBuilder { if is_conditional && !self.tables.last().unwrap().has_conditional_annotations { self.tables.last_mut().unwrap().has_conditional_annotations = true; - // Register __conditional_annotations__ symbol in the scope (USE flag, not DEF) + // Register __conditional_annotations__ as both Assigned and Used so that + // it becomes a Cell variable in class scope (children reference it as Free) + self.register_name( + "__conditional_annotations__", + SymbolUsage::Assigned, + annotation.range(), + )?; self.register_name( "__conditional_annotations__", SymbolUsage::Used, @@ -1441,16 +1466,35 @@ impl SymbolTableBuilder { }) => { let was_in_type_alias = self.in_type_alias; self.in_type_alias = true; + // Check before entering any sub-scopes + let in_class = self + .tables + .last() + .is_some_and(|t| t.typ == CompilerScope::Class); + let is_generic = type_params.is_some(); if let Some(type_params) = type_params { self.enter_type_param_block( "TypeAlias", self.line_index_start(type_params.range), )?; self.scan_type_params(type_params)?; - self.scan_expression(value, ExpressionContext::Load)?; + } + // Value scope for lazy evaluation + self.enter_scope( + "TypeAlias", + CompilerScope::TypeParams, + self.line_index_start(value.range()), + ); + if in_class { + if let Some(table) = self.tables.last_mut() { + table.can_see_class_scope = true; + } + self.register_name("__classdict__", SymbolUsage::Used, TextRange::default())?; + } + self.scan_expression(value, ExpressionContext::Load)?; + self.leave_scope(); + if is_generic { self.leave_scope(); - } else { - self.scan_expression(value, ExpressionContext::Load)?; } self.in_type_alias = was_in_type_alias; self.scan_expression(name, ExpressionContext::Store)?; @@ -1943,11 +1987,16 @@ impl SymbolTableBuilder { ) -> SymbolTableResult { // Enter a new TypeParams scope for the bound/default expression // This allows the expression to access outer scope symbols + let in_class = self.tables.last().is_some_and(|t| t.can_see_class_scope); let line_number = self.line_index_start(expr.range()); self.enter_scope(scope_name, CompilerScope::TypeParams, line_number); - // Note: In CPython, can_see_class_scope is preserved in the new scope - // In RustPython, this is handled through the scope hierarchy + if in_class { + if let Some(table) = self.tables.last_mut() { + table.can_see_class_scope = true; + } + self.register_name("__classdict__", SymbolUsage::Used, TextRange::default())?; + } // Set scope_info for better error messages let old_scope_info = self.scope_info; diff --git a/crates/derive-impl/src/pymodule.rs b/crates/derive-impl/src/pymodule.rs index 705f155b282..ed86d142cef 100644 --- a/crates/derive-impl/src/pymodule.rs +++ b/crates/derive-impl/src/pymodule.rs @@ -700,7 +700,18 @@ impl ModuleItem for ClassItem { }; let class_new = quote_spanned!(ident.span() => let new_class = <#ident as ::rustpython_vm::class::PyClassImpl>::make_class(ctx); - new_class.set_attr(rustpython_vm::identifier!(ctx, __module__), vm.new_pyobj(#module_name)); + // Only set __module__ string if the class doesn't already have a + // getset descriptor for __module__ (which provides instance-level + // module resolution, e.g. TypeAliasType) + { + let module_key = rustpython_vm::identifier!(ctx, __module__); + let has_module_getset = new_class.attributes.read() + .get(module_key) + .is_some_and(|v| v.downcastable::()); + if !has_module_getset { + new_class.set_attr(module_key, vm.new_pyobj(#module_name)); + } + } ); (class_name, class_new) }; @@ -778,7 +789,15 @@ impl ModuleItem for StructSequenceItem { // Generate the class creation code let class_new = quote_spanned!(pytype_ident.span() => let new_class = <#pytype_ident as ::rustpython_vm::class::PyClassImpl>::make_class(ctx); - new_class.set_attr(rustpython_vm::identifier!(ctx, __module__), vm.new_pyobj(#module_name)); + { + let module_key = rustpython_vm::identifier!(ctx, __module__); + let has_module_getset = new_class.attributes.read() + .get(module_key) + .is_some_and(|v| v.downcastable::()); + if !has_module_getset { + new_class.set_attr(module_key, vm.new_pyobj(#module_name)); + } + } ); // Handle py_attrs for custom names, or use class_name as default diff --git a/crates/vm/src/builtins/function.rs b/crates/vm/src/builtins/function.rs index 1a6d4520083..9541e968ab6 100644 --- a/crates/vm/src/builtins/function.rs +++ b/crates/vm/src/builtins/function.rs @@ -718,19 +718,23 @@ impl PyFunction { value: PySetterValue>, vm: &VirtualMachine, ) -> PyResult<()> { - let annotations = match value { + match value { PySetterValue::Assign(Some(value)) => { let annotations = value.downcast::().map_err(|_| { vm.new_type_error("__annotations__ must be set to a dict object") })?; - Some(annotations) + *self.annotations.lock() = Some(annotations); + *self.annotate.lock() = None; } - PySetterValue::Assign(None) | PySetterValue::Delete => None, - }; - *self.annotations.lock() = annotations; - - // Clear __annotate__ when __annotations__ is set - *self.annotate.lock() = None; + PySetterValue::Assign(None) => { + *self.annotations.lock() = None; + *self.annotate.lock() = None; + } + PySetterValue::Delete => { + // del only clears cached annotations; __annotate__ is preserved + *self.annotations.lock() = None; + } + } Ok(()) } diff --git a/crates/vm/src/builtins/genericalias.rs b/crates/vm/src/builtins/genericalias.rs index 21034e08f0e..b30d8586331 100644 --- a/crates/vm/src/builtins/genericalias.rs +++ b/crates/vm/src/builtins/genericalias.rs @@ -1,27 +1,26 @@ -// spell-checker:ignore iparam +// spell-checker:ignore iparam gaiterobject use std::sync::LazyLock; use super::type_; use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, atomic_func, - builtins::{PyList, PyStr, PyTuple, PyTupleRef, PyType, PyTypeRef}, + builtins::{PyList, PyStr, PyTuple, PyTupleRef, PyType}, class::PyClassImpl, common::hash, convert::ToPyObject, function::{FuncArgs, PyComparisonValue}, protocol::{PyMappingMethods, PyNumberMethods}, types::{ - AsMapping, AsNumber, Callable, Comparable, Constructor, GetAttr, Hashable, Iterable, - PyComparisonOp, Representable, + AsMapping, AsNumber, Callable, Comparable, Constructor, GetAttr, Hashable, IterNext, + Iterable, PyComparisonOp, Representable, }, }; use alloc::fmt; -// attr_exceptions -static ATTR_EXCEPTIONS: [&str; 12] = [ +// Attributes that are looked up on the GenericAlias itself, not on __origin__ +static ATTR_EXCEPTIONS: [&str; 9] = [ "__class__", - "__bases__", "__origin__", "__args__", "__unpacked__", @@ -30,13 +29,14 @@ static ATTR_EXCEPTIONS: [&str; 12] = [ "__mro_entries__", "__reduce_ex__", // needed so we don't look up object.__reduce_ex__ "__reduce__", - "__copy__", - "__deepcopy__", ]; +// Attributes that are blocked from being looked up on __origin__ +static ATTR_BLOCKED: [&str; 3] = ["__bases__", "__copy__", "__deepcopy__"]; + #[pyclass(module = "types", name = "GenericAlias")] pub struct PyGenericAlias { - origin: PyTypeRef, + origin: PyObjectRef, args: PyTupleRef, parameters: PyTupleRef, starred: bool, // for __unpacked__ attribute @@ -62,7 +62,7 @@ impl Constructor for PyGenericAlias { if !args.kwargs.is_empty() { return Err(vm.new_type_error("GenericAlias() takes no keyword arguments")); } - let (origin, arguments): (_, PyObjectRef) = args.bind(vm)?; + let (origin, arguments): (PyObjectRef, PyObjectRef) = args.bind(vm)?; let args = if let Ok(tuple) = arguments.try_to_ref::(vm) { tuple.to_owned() } else { @@ -87,10 +87,15 @@ impl Constructor for PyGenericAlias { flags(BASETYPE) )] impl PyGenericAlias { - pub fn new(origin: PyTypeRef, args: PyTupleRef, starred: bool, vm: &VirtualMachine) -> Self { + pub fn new( + origin: impl Into, + args: PyTupleRef, + starred: bool, + vm: &VirtualMachine, + ) -> Self { let parameters = make_parameters(&args, vm); Self { - origin, + origin: origin.into(), args, parameters, starred, @@ -98,7 +103,11 @@ impl PyGenericAlias { } /// Create a GenericAlias from an origin and PyObjectRef arguments (helper for compatibility) - pub fn from_args(origin: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> Self { + pub fn from_args( + origin: impl Into, + args: PyObjectRef, + vm: &VirtualMachine, + ) -> Self { let args = if let Ok(tuple) = args.try_to_ref::(vm) { tuple.to_owned() } else { @@ -138,15 +147,35 @@ impl PyGenericAlias { } } + fn repr_arg(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { + // ParamSpec args can be lists - format their items with repr_item + if obj.class().is(vm.ctx.types.list_type) { + let list = obj.downcast_ref::().unwrap(); + let len = list.borrow_vec().len(); + let mut parts = Vec::with_capacity(len); + // Use indexed access so list mutation during repr causes IndexError + for i in 0..len { + let item = + list.borrow_vec().get(i).cloned().ok_or_else(|| { + vm.new_index_error("list index out of range".to_owned()) + })?; + parts.push(repr_item(item, vm)?); + } + Ok(format!("[{}]", parts.join(", "))) + } else { + repr_item(obj, vm) + } + } + let repr_str = format!( "{}[{}]", - repr_item(self.origin.clone().into(), vm)?, + repr_item(self.origin.clone(), vm)?, if self.args.is_empty() { "()".to_owned() } else { self.args .iter() - .map(|o| repr_item(o.clone(), vm)) + .map(|o| repr_arg(o.clone(), vm)) .collect::>>()? .join(", ") } @@ -172,7 +201,7 @@ impl PyGenericAlias { #[pygetset] fn __origin__(&self) -> PyObjectRef { - self.origin.clone().into() + self.origin.clone() } #[pygetset] @@ -182,7 +211,7 @@ impl PyGenericAlias { #[pygetset] fn __typing_unpacked_tuple_args__(&self, vm: &VirtualMachine) -> PyObjectRef { - if self.starred && self.origin.is(vm.ctx.types.tuple_type) { + if self.starred && self.origin.is(vm.ctx.types.tuple_type.as_object()) { self.args.clone().into() } else { vm.ctx.none() @@ -213,11 +242,29 @@ impl PyGenericAlias { } #[pymethod] - fn __reduce__(zelf: &Py, vm: &VirtualMachine) -> (PyTypeRef, (PyTypeRef, PyTupleRef)) { - ( - vm.ctx.types.generic_alias_type.to_owned(), - (zelf.origin.clone(), zelf.args.clone()), - ) + fn __reduce__(zelf: &Py, vm: &VirtualMachine) -> PyResult { + if zelf.starred { + // (next, (iter(GenericAlias(origin, args)),)) + let next_fn = vm.builtins.get_attr("next", vm)?; + let non_starred = Self::new(zelf.origin.clone(), zelf.args.clone(), false, vm); + let iter_obj = PyGenericAliasIterator { + obj: crate::common::lock::PyMutex::new(Some(non_starred.into_pyobject(vm))), + } + .into_pyobject(vm); + Ok(PyTuple::new_ref( + vec![next_fn, PyTuple::new_ref(vec![iter_obj], &vm.ctx).into()], + &vm.ctx, + )) + } else { + Ok(PyTuple::new_ref( + vec![ + vm.ctx.types.generic_alias_type.to_owned().into(), + PyTuple::new_ref(vec![zelf.origin.clone(), zelf.args.clone().into()], &vm.ctx) + .into(), + ], + &vm.ctx, + )) + } } #[pymethod] @@ -245,8 +292,11 @@ impl PyGenericAlias { } pub(crate) fn make_parameters(args: &Py, vm: &VirtualMachine) -> PyTupleRef { + make_parameters_from_slice(args.as_slice(), vm) +} + +fn make_parameters_from_slice(args: &[PyObjectRef], vm: &VirtualMachine) -> PyTupleRef { let mut parameters: Vec = Vec::with_capacity(args.len()); - let mut iparam = 0; for arg in args { // We don't want __parameters__ descriptor of a bare Python class. @@ -256,37 +306,34 @@ pub(crate) fn make_parameters(args: &Py, vm: &VirtualMachine) -> PyTupl // Check for __typing_subst__ attribute if arg.get_attr(identifier!(vm, __typing_subst__), vm).is_ok() { - // Use tuple_add equivalent logic if tuple_index(¶meters, arg).is_none() { - if iparam >= parameters.len() { - parameters.resize(iparam + 1, vm.ctx.none()); - } - parameters[iparam] = arg.clone(); - iparam += 1; + parameters.push(arg.clone()); } } else if let Ok(subparams) = arg.get_attr(identifier!(vm, __parameters__), vm) && let Ok(sub_params) = subparams.try_to_ref::(vm) { - let len2 = sub_params.len(); - // Resize if needed - if iparam + len2 > parameters.len() { - parameters.resize(iparam + len2, vm.ctx.none()); - } for sub_param in sub_params { - // Use tuple_add equivalent logic - if tuple_index(¶meters[..iparam], sub_param).is_none() { - if iparam >= parameters.len() { - parameters.resize(iparam + 1, vm.ctx.none()); - } - parameters[iparam] = sub_param.clone(); - iparam += 1; + if tuple_index(¶meters, sub_param).is_none() { + parameters.push(sub_param.clone()); + } + } + } else if arg.try_to_ref::(vm).is_ok() || arg.try_to_ref::(vm).is_ok() { + // Recursively extract parameters from lists/tuples (ParamSpec args) + let items: Vec = if let Ok(t) = arg.try_to_ref::(vm) { + t.as_slice().to_vec() + } else { + let list = arg.downcast_ref::().unwrap(); + list.borrow_vec().to_vec() + }; + let sub = make_parameters_from_slice(&items, vm); + for sub_param in sub.iter() { + if tuple_index(¶meters, sub_param).is_none() { + parameters.push(sub_param.clone()); } } } } - // Resize to actual size - parameters.truncate(iparam); PyTuple::new_ref(parameters, &vm.ctx) } @@ -433,7 +480,7 @@ pub fn subs_parameters( let arg_items = if let Ok(tuple) = item.try_to_ref::(vm) { tuple.as_slice().to_vec() } else { - vec![item] + vec![item.clone()] }; let n_items = arg_items.len(); @@ -457,32 +504,55 @@ pub fn subs_parameters( continue; } - // Check if this is an unpacked TypeVarTuple's _is_unpacked_typevartuple + // Recursively substitute params in lists/tuples + let is_list = arg.try_to_ref::(vm).is_ok(); + if arg.try_to_ref::(vm).is_ok() || is_list { + let sub_items: Vec = if let Ok(t) = arg.try_to_ref::(vm) { + t.as_slice().to_vec() + } else { + arg.downcast_ref::().unwrap().borrow_vec().to_vec() + }; + let sub_tuple = PyTuple::new_ref(sub_items, &vm.ctx); + let sub_result = subs_parameters( + alias.clone(), + sub_tuple, + parameters.clone(), + item.clone(), + vm, + )?; + let substituted: PyObjectRef = if is_list { + // Convert tuple back to list + PyList::from(sub_result.as_slice().to_vec()) + .into_ref(&vm.ctx) + .into() + } else { + sub_result.into() + }; + new_args.push(substituted); + continue; + } + + // Check if this is an unpacked TypeVarTuple let unpack = is_unpacked_typevartuple(arg, vm)?; - // Try __typing_subst__ method first, + // Try __typing_subst__ method first let substituted_arg = if let Ok(subst) = arg.get_attr(identifier!(vm, __typing_subst__), vm) { - // Find parameter index's tuple_index if let Some(iparam) = tuple_index(parameters.as_slice(), arg) { subst.call((arg_items[iparam].clone(),), vm)? } else { - // This shouldn't happen in well-formed generics but handle gracefully subs_tvars(arg.clone(), ¶meters, &arg_items, vm)? } } else { - // Use subs_tvars for objects with __parameters__ subs_tvars(arg.clone(), ¶meters, &arg_items, vm)? }; if unpack { - // Handle unpacked TypeVarTuple's tuple_extend if let Ok(tuple) = substituted_arg.try_to_ref::(vm) { for elem in tuple { new_args.push(elem.clone()); } } else { - // This shouldn't happen but handle gracefully new_args.push(substituted_arg); } } else { @@ -519,7 +589,7 @@ impl AsNumber for PyGenericAlias { impl Callable for PyGenericAlias { type Args = FuncArgs; fn call(zelf: &Py, args: FuncArgs, vm: &VirtualMachine) -> PyResult { - PyType::call(&zelf.origin, args, vm).map(|obj| { + zelf.origin.call(args, vm).map(|obj| { if let Err(exc) = obj.set_attr(identifier!(vm, __orig_class__), zelf.to_owned(), vm) && !exc.fast_isinstance(vm.ctx.exceptions.attribute_error) && !exc.fast_isinstance(vm.ctx.exceptions.type_error) @@ -540,17 +610,17 @@ impl Comparable for PyGenericAlias { ) -> PyResult { op.eq_only(|| { let other = class_or_notimplemented!(Self, other); + if zelf.starred != other.starred { + return Ok(PyComparisonValue::Implemented(false)); + } Ok(PyComparisonValue::Implemented( - if !zelf.__origin__().rich_compare_bool( - &other.__origin__(), - PyComparisonOp::Eq, - vm, - )? { - false - } else { - zelf.__args__() - .rich_compare_bool(&other.__args__(), PyComparisonOp::Eq, vm)? - }, + zelf.__origin__() + .rich_compare_bool(&other.__origin__(), PyComparisonOp::Eq, vm)? + && zelf.__args__().rich_compare_bool( + &other.__args__(), + PyComparisonOp::Eq, + vm, + )?, )) }) } @@ -559,14 +629,20 @@ impl Comparable for PyGenericAlias { impl Hashable for PyGenericAlias { #[inline] fn hash(zelf: &Py, vm: &VirtualMachine) -> PyResult { - Ok(zelf.origin.as_object().hash(vm)? ^ zelf.args.as_object().hash(vm)?) + Ok(zelf.origin.hash(vm)? ^ zelf.args.as_object().hash(vm)?) } } impl GetAttr for PyGenericAlias { fn getattro(zelf: &Py, attr: &Py, vm: &VirtualMachine) -> PyResult { + let attr_str = attr.as_str(); for exc in &ATTR_EXCEPTIONS { - if *(*exc) == attr.to_string() { + if *exc == attr_str { + return zelf.as_object().generic_getattr(attr, vm); + } + } + for blocked in &ATTR_BLOCKED { + if *blocked == attr_str { return zelf.as_object().generic_getattr(attr, vm); } } @@ -582,27 +658,65 @@ impl Representable for PyGenericAlias { } impl Iterable for PyGenericAlias { - // ga_iter - // spell-checker:ignore gaiterobject - // TODO: gaiterobject fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { - // CPython's ga_iter creates an iterator that yields one starred GenericAlias - // we don't have gaiterobject yet + Ok(PyGenericAliasIterator { + obj: crate::common::lock::PyMutex::new(Some(zelf.into())), + } + .into_pyobject(vm)) + } +} - let starred_alias = Self::new( - zelf.origin.clone(), - zelf.args.clone(), - true, // starred - vm, - ); - let starred_ref = PyRef::new_ref( - starred_alias, - vm.ctx.types.generic_alias_type.to_owned(), - None, - ); - let items = vec![starred_ref.into()]; - let iter_tuple = PyTuple::new_ref(items, &vm.ctx); - Ok(iter_tuple.to_pyobject(vm).get_iter(vm)?.into()) +// gaiterobject - yields one starred GenericAlias then exhausts +#[pyclass(module = "types", name = "generic_alias_iterator")] +#[derive(Debug, PyPayload)] +pub struct PyGenericAliasIterator { + obj: crate::common::lock::PyMutex>, +} + +#[pyclass(with(Representable, Iterable, IterNext))] +impl PyGenericAliasIterator { + #[pymethod] + fn __reduce__(&self, vm: &VirtualMachine) -> PyResult { + let iter_fn = vm.builtins.get_attr("iter", vm)?; + let guard = self.obj.lock(); + let arg: PyObjectRef = if let Some(ref obj) = *guard { + // Not yet exhausted: (iter, (obj,)) + PyTuple::new_ref(vec![obj.clone()], &vm.ctx).into() + } else { + // Exhausted: (iter, ((),)) + let empty = PyTuple::new_ref(vec![], &vm.ctx); + PyTuple::new_ref(vec![empty.into()], &vm.ctx).into() + }; + Ok(PyTuple::new_ref(vec![iter_fn, arg], &vm.ctx)) + } +} + +impl Representable for PyGenericAliasIterator { + fn repr_str(_zelf: &Py, _vm: &VirtualMachine) -> PyResult { + Ok("".to_owned()) + } +} + +impl Iterable for PyGenericAliasIterator { + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyResult { + Ok(zelf.into()) + } +} + +impl crate::types::IterNext for PyGenericAliasIterator { + fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { + use crate::protocol::PyIterReturn; + let mut guard = zelf.obj.lock(); + let obj = match guard.take() { + Some(obj) => obj, + None => return Ok(PyIterReturn::StopIteration(None)), + }; + // Create a starred GenericAlias from the original + let alias = obj.downcast_ref::().ok_or_else(|| { + vm.new_type_error("generic_alias_iterator expected GenericAlias".to_owned()) + })?; + let starred = PyGenericAlias::new(alias.origin.clone(), alias.args.clone(), true, vm); + Ok(PyIterReturn::Return(starred.into_pyobject(vm))) } } @@ -628,6 +742,6 @@ pub fn subscript_generic(type_params: PyObjectRef, vm: &VirtualMachine) -> PyRes } pub fn init(context: &Context) { - let generic_alias_type = &context.types.generic_alias_type; - PyGenericAlias::extend_class(context, generic_alias_type); + PyGenericAlias::extend_class(context, context.types.generic_alias_type); + PyGenericAliasIterator::extend_class(context, context.types.generic_alias_iterator_type); } diff --git a/crates/vm/src/builtins/list.rs b/crates/vm/src/builtins/list.rs index c6b9869a0ab..7203110c7da 100644 --- a/crates/vm/src/builtins/list.rs +++ b/crates/vm/src/builtins/list.rs @@ -511,7 +511,11 @@ impl Representable for PyList { let s = if zelf.__len__() == 0 { "[]".to_owned() } else if let Some(_guard) = ReprGuard::enter(vm, zelf.as_object()) { - collection_repr(None, "[", "]", zelf.borrow_vec().iter(), vm)? + // Clone elements before calling repr to release the read lock. + // Element repr may mutate the list (e.g., list.clear()), which + // needs a write lock and would deadlock if read lock is held. + let elements: Vec = zelf.borrow_vec().to_vec(); + collection_repr(None, "[", "]", elements.iter(), vm)? } else { "[...]".to_owned() }; diff --git a/crates/vm/src/builtins/type.rs b/crates/vm/src/builtins/type.rs index 74de2c12eb4..e2852ee7dc5 100644 --- a/crates/vm/src/builtins/type.rs +++ b/crates/vm/src/builtins/type.rs @@ -927,10 +927,11 @@ impl PyType { } let mut attrs = self.attributes.write(); - // Store to __annotate_func__ + // Clear cached annotations only when setting to a new callable + if !vm.is_none(&value) { + attrs.swap_remove(identifier!(vm, __annotations_cache__)); + } attrs.insert(identifier!(vm, __annotate_func__), value.clone()); - // Always clear cached annotations when __annotate__ is updated - attrs.swap_remove(identifier!(vm, __annotations_cache__)); Ok(()) } @@ -999,7 +1000,11 @@ impl PyType { } #[pygetset(setter)] - fn set___annotations__(&self, value: Option, vm: &VirtualMachine) -> PyResult<()> { + fn set___annotations__( + &self, + value: crate::function::PySetterValue, + vm: &VirtualMachine, + ) -> PyResult<()> { if self.slots.flags.has_feature(PyTypeFlags::IMMUTABLETYPE) { return Err(vm.new_type_error(format!( "cannot set '__annotations__' attribute of immutable type '{}'", @@ -1008,33 +1013,40 @@ impl PyType { } let mut attrs = self.attributes.write(); - // conditional update based on __annotations__ presence let has_annotations = attrs.contains_key(identifier!(vm, __annotations__)); - if has_annotations { - // If __annotations__ is in dict, update it - if let Some(value) = value { - attrs.insert(identifier!(vm, __annotations__), value); - } else if attrs - .swap_remove(identifier!(vm, __annotations__)) - .is_none() - { - return Err(vm.new_attribute_error("__annotations__".to_owned())); + match value { + crate::function::PySetterValue::Assign(value) => { + // SET path: store the value (including None) + let key = if has_annotations { + identifier!(vm, __annotations__) + } else { + identifier!(vm, __annotations_cache__) + }; + attrs.insert(key, value); + if has_annotations { + attrs.swap_remove(identifier!(vm, __annotations_cache__)); + } } - // Also clear __annotations_cache__ - attrs.swap_remove(identifier!(vm, __annotations_cache__)); - } else { - // Otherwise update only __annotations_cache__ - if let Some(value) = value { - attrs.insert(identifier!(vm, __annotations_cache__), value); - } else if attrs - .swap_remove(identifier!(vm, __annotations_cache__)) - .is_none() - { - return Err(vm.new_attribute_error("__annotations__".to_owned())); + crate::function::PySetterValue::Delete => { + // DELETE path: remove the key + let removed = if has_annotations { + attrs + .swap_remove(identifier!(vm, __annotations__)) + .is_some() + } else { + attrs + .swap_remove(identifier!(vm, __annotations_cache__)) + .is_some() + }; + if !removed { + return Err(vm.new_attribute_error("__annotations__".to_owned())); + } + if has_annotations { + attrs.swap_remove(identifier!(vm, __annotations_cache__)); + } } } - // Always clear __annotate_func__ and __annotate__ attrs.swap_remove(identifier!(vm, __annotate_func__)); attrs.swap_remove(identifier!(vm, __annotate__)); @@ -1055,7 +1067,15 @@ impl PyType { Some(found) } }) - .unwrap_or_else(|| vm.ctx.new_str(ascii!("builtins")).into()) + .unwrap_or_else(|| { + // For non-heap types, extract module from tp_name (e.g. "typing.TypeAliasType" -> "typing") + let slot_name = self.slot_name(); + if let Some((module, _)) = slot_name.rsplit_once('.') { + vm.ctx.intern_str(module).to_object() + } else { + vm.ctx.new_str(ascii!("builtins")).into() + } + }) } #[pygetset(setter)] @@ -1838,6 +1858,13 @@ impl SetAttr for PyType { ) -> PyResult<()> { // TODO: pass PyRefExact instead of &str let attr_name = vm.ctx.intern_str(attr_name.as_str()); + if zelf.slots.flags.has_feature(PyTypeFlags::IMMUTABLETYPE) { + return Err(vm.new_type_error(format!( + "cannot set '{}' attribute of immutable type '{}'", + attr_name, + zelf.slot_name() + ))); + } if let Some(attr) = zelf.get_class_attr(attr_name) { let descr_set = attr.class().slots.descr_set.load(); if let Some(descriptor) = descr_set { diff --git a/crates/vm/src/class.rs b/crates/vm/src/class.rs index a71baf070cd..3075b59f0bb 100644 --- a/crates/vm/src/class.rs +++ b/crates/vm/src/class.rs @@ -165,10 +165,17 @@ pub trait PyClassImpl: PyClassDef { } } if let Some(module_name) = Self::MODULE_NAME { - class.set_attr( - identifier!(ctx, __module__), - ctx.new_str(module_name).into(), - ); + let module_key = identifier!(ctx, __module__); + // Don't overwrite a getset descriptor for __module__ (e.g. TypeAliasType + // has an instance-level __module__ getset that should not be replaced) + let has_getset = class + .attributes + .read() + .get(module_key) + .is_some_and(|v| v.downcastable::()); + if !has_getset { + class.set_attr(module_key, ctx.new_str(module_name).into()); + } } // Don't add __new__ attribute if slot_new is inherited from object diff --git a/crates/vm/src/frame.rs b/crates/vm/src/frame.rs index df726b5f684..f197effb7e0 100644 --- a/crates/vm/src/frame.rs +++ b/crates/vm/src/frame.rs @@ -671,7 +671,7 @@ impl ExecutingFrame<'_> { } else { let name = self.code.freevars[i - self.code.cellvars.len()]; vm.new_name_error( - format!("free variable '{name}' referenced before assignment in enclosing scope"), + format!("cannot access free variable '{name}' where it is not associated with a value in enclosing scope"), name.to_owned(), ) } @@ -3118,7 +3118,7 @@ impl ExecutingFrame<'_> { let name = tuple.as_slice()[0].clone(); let type_params_obj = tuple.as_slice()[1].clone(); - let value = tuple.as_slice()[2].clone(); + let compute_value = tuple.as_slice()[2].clone(); let type_params: PyTupleRef = if vm.is_none(&type_params_obj) { vm.ctx.empty_tuple.clone() @@ -3131,7 +3131,7 @@ impl ExecutingFrame<'_> { let name = name.downcast::().map_err(|_| { vm.new_type_error("TypeAliasType name must be a string".to_owned()) })?; - let type_alias = typing::TypeAliasType::new(name, type_params, value); + let type_alias = typing::TypeAliasType::new(name, type_params, compute_value); Ok(type_alias.into_ref(&vm.ctx).into()) } bytecode::IntrinsicFunction1::ListToTuple => { diff --git a/crates/vm/src/stdlib/ast/pyast.rs b/crates/vm/src/stdlib/ast/pyast.rs index 58f049aee40..14245a56c09 100644 --- a/crates/vm/src/stdlib/ast/pyast.rs +++ b/crates/vm/src/stdlib/ast/pyast.rs @@ -77,6 +77,13 @@ macro_rules! impl_base_node { #[extend_class] fn extend_class(ctx: &Context, class: &'static Py) { + // AST types are mutable (heap types in CPython, not IMMUTABLETYPE) + // Safety: called during type initialization before any concurrent access + unsafe { + let flags = &class.slots.flags as *const crate::types::PyTypeFlags + as *mut crate::types::PyTypeFlags; + (*flags).remove(crate::types::PyTypeFlags::IMMUTABLETYPE); + } class.set_attr( identifier!(ctx, _attributes), ctx.empty_tuple.clone().into(), @@ -100,6 +107,13 @@ macro_rules! impl_base_node { #[extend_class] fn extend_class_with_fields(ctx: &Context, class: &'static Py) { + // AST types are mutable (heap types in CPython, not IMMUTABLETYPE) + // Safety: called during type initialization before any concurrent access + unsafe { + let flags = &class.slots.flags as *const crate::types::PyTypeFlags + as *mut crate::types::PyTypeFlags; + (*flags).remove(crate::types::PyTypeFlags::IMMUTABLETYPE); + } class.set_attr( identifier!(ctx, _fields), ctx.new_tuple(vec![ @@ -530,6 +544,13 @@ pub(crate) struct NodeExprConstant(NodeExpr); impl NodeExprConstant { #[extend_class] fn extend_class_with_fields(ctx: &Context, class: &'static Py) { + // AST types are mutable (heap types, not IMMUTABLETYPE) + // Safety: called during type initialization before any concurrent access + unsafe { + let flags = &class.slots.flags as *const crate::types::PyTypeFlags + as *mut crate::types::PyTypeFlags; + (*flags).remove(crate::types::PyTypeFlags::IMMUTABLETYPE); + } class.set_attr( identifier!(ctx, _fields), ctx.new_tuple(vec![ diff --git a/crates/vm/src/stdlib/ast/python.rs b/crates/vm/src/stdlib/ast/python.rs index 539420f27c8..0de6f45b912 100644 --- a/crates/vm/src/stdlib/ast/python.rs +++ b/crates/vm/src/stdlib/ast/python.rs @@ -25,6 +25,13 @@ pub(crate) mod _ast { impl NodeAst { #[extend_class] fn extend_class(ctx: &Context, class: &'static Py) { + // AST types are mutable (heap types, not IMMUTABLETYPE) + // Safety: called during type initialization before any concurrent access + unsafe { + let flags = &class.slots.flags as *const crate::types::PyTypeFlags + as *mut crate::types::PyTypeFlags; + (*flags).remove(crate::types::PyTypeFlags::IMMUTABLETYPE); + } let empty_tuple = ctx.empty_tuple.clone(); class.set_str_attr("_fields", empty_tuple.clone(), ctx); class.set_str_attr("_attributes", empty_tuple.clone(), ctx); diff --git a/crates/vm/src/stdlib/typevar.rs b/crates/vm/src/stdlib/typevar.rs index d1be1118a2e..ac8aeac3636 100644 --- a/crates/vm/src/stdlib/typevar.rs +++ b/crates/vm/src/stdlib/typevar.rs @@ -391,7 +391,7 @@ pub(crate) mod typevar { evaluate_default: PyMutex::new(vm.ctx.none()), covariant: false, contravariant: false, - infer_variance: false, + infer_variance: true, } } } @@ -631,7 +631,7 @@ pub(crate) mod typevar { evaluate_default: PyMutex::new(vm.ctx.none()), covariant: false, contravariant: false, - infer_variance: false, + infer_variance: true, } } } diff --git a/crates/vm/src/stdlib/typing.rs b/crates/vm/src/stdlib/typing.rs index 94b014c62fa..59884d1ec9c 100644 --- a/crates/vm/src/stdlib/typing.rs +++ b/crates/vm/src/stdlib/typing.rs @@ -1,4 +1,4 @@ -// spell-checker:ignore typevarobject funcobj +// spell-checker:ignore typevarobject funcobj typevartuples use crate::{ Context, PyResult, VirtualMachine, builtins::pystr::AsPyStr, class::PyClassImpl, function::IntoFuncArgs, @@ -29,12 +29,13 @@ pub fn call_typing_func_object<'a>( #[pymodule(name = "_typing", with(super::typevar::typevar))] pub(crate) mod decl { use crate::{ - Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, - builtins::{PyStrRef, PyTupleRef, PyType, PyTypeRef, type_}, + AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func, + builtins::{PyGenericAlias, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, type_}, function::FuncArgs, - protocol::PyNumberMethods, - types::{AsNumber, Constructor, Representable}, + protocol::{PyMappingMethods, PyNumberMethods}, + types::{AsMapping, AsNumber, Constructor, Iterable, Representable}, }; + use std::sync::LazyLock; #[pyfunction] pub(crate) fn _idfunc(args: FuncArgs, _vm: &VirtualMachine) -> PyObjectRef { @@ -84,23 +85,47 @@ pub(crate) mod decl { } #[pyattr] - #[pyclass(name)] + #[pyclass(name, module = "typing")] #[derive(Debug, PyPayload)] - #[allow(dead_code)] pub(crate) struct TypeAliasType { name: PyStrRef, type_params: PyTupleRef, - value: PyObjectRef, - // compute_value: PyObjectRef, - // module: PyObjectRef, + compute_value: PyObjectRef, + cached_value: crate::common::lock::PyMutex>, + module: Option, + is_lazy: bool, } - #[pyclass(with(Constructor, Representable, AsNumber), flags(BASETYPE))] + #[pyclass( + with(Constructor, Representable, AsMapping, AsNumber, Iterable), + flags(IMMUTABLETYPE) + )] impl TypeAliasType { - pub const fn new(name: PyStrRef, type_params: PyTupleRef, value: PyObjectRef) -> Self { + /// Create from intrinsic: compute_value is a callable that returns the value + pub fn new(name: PyStrRef, type_params: PyTupleRef, compute_value: PyObjectRef) -> Self { Self { name, type_params, - value, + compute_value, + cached_value: crate::common::lock::PyMutex::new(None), + module: None, + is_lazy: true, + } + } + + /// Create with an eagerly evaluated value (used by constructor) + fn new_eager( + name: PyStrRef, + type_params: PyTupleRef, + value: PyObjectRef, + module: Option, + ) -> Self { + Self { + name, + type_params, + compute_value: value.clone(), + cached_value: crate::common::lock::PyMutex::new(Some(value)), + module, + is_lazy: false, } } @@ -110,55 +135,200 @@ pub(crate) mod decl { } #[pygetset] - fn __value__(&self) -> PyObjectRef { - self.value.clone() + fn __value__(&self, vm: &VirtualMachine) -> PyResult { + let cached = self.cached_value.lock().clone(); + if let Some(value) = cached { + return Ok(value); + } + let value = self.compute_value.call((), vm)?; + *self.cached_value.lock() = Some(value.clone()); + Ok(value) } #[pygetset] fn __type_params__(&self) -> PyTupleRef { self.type_params.clone() } + + #[pygetset] + fn __parameters__(&self, vm: &VirtualMachine) -> PyResult { + // TypeVarTuples must be unpacked in __parameters__ + unpack_typevartuples(&self.type_params, vm).map(|t| t.into()) + } + + #[pygetset] + fn __module__(&self, vm: &VirtualMachine) -> PyObjectRef { + if let Some(ref module) = self.module { + return module.clone(); + } + // Fall back to compute_value's __module__ (like PyFunction_GetModule) + if let Ok(module) = self.compute_value.get_attr("__module__", vm) { + return module; + } + vm.ctx.none() + } + + fn __getitem__(zelf: PyRef, args: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if zelf.type_params.is_empty() { + return Err( + vm.new_type_error("Only generic type aliases are subscriptable".to_owned()) + ); + } + let args_tuple = if let Ok(tuple) = args.try_to_ref::(vm) { + tuple.to_owned() + } else { + PyTuple::new_ref(vec![args], &vm.ctx) + }; + let origin: PyObjectRef = zelf.as_object().to_owned(); + Ok(PyGenericAlias::new(origin, args_tuple, false, vm).into_pyobject(vm)) + } + + #[pymethod] + fn __reduce__(zelf: &Py, _vm: &VirtualMachine) -> PyObjectRef { + zelf.name.clone().into() + } + + #[pymethod] + fn __typing_unpacked_tuple_args__(&self, vm: &VirtualMachine) -> PyObjectRef { + vm.ctx.none() + } + + /// Returns the evaluator for the alias value. + #[pygetset] + fn evaluate_value(&self, vm: &VirtualMachine) -> PyResult { + if self.is_lazy { + // Lazy path: return the compute function directly + return Ok(self.compute_value.clone()); + } + // Eager path: wrap value in a ConstEvaluator + let value = self.compute_value.clone(); + Ok(vm + .new_function("_ConstEvaluator", move |_args: FuncArgs| -> PyResult { + Ok(value.clone()) + }) + .into()) + } + + /// Check type_params ordering: non-default params must precede default params. + /// Uses __default__ attribute to check if a type param has a default value, + /// comparing against typing.NoDefault sentinel (like get_type_param_default). + fn check_type_params( + type_params: &PyTupleRef, + vm: &VirtualMachine, + ) -> PyResult> { + if type_params.is_empty() { + return Ok(None); + } + let no_default = &vm.ctx.typing_no_default; + let mut default_seen = false; + for param in type_params.iter() { + let dflt = param.get_attr("__default__", vm).map_err(|_| { + vm.new_type_error(format!( + "Expected a type param, got {}", + param + .repr(vm) + .map(|s| s.to_string()) + .unwrap_or_else(|_| "?".to_owned()) + )) + })?; + let is_no_default = dflt.is(no_default); + if is_no_default { + if default_seen { + return Err(vm.new_type_error(format!( + "non-default type parameter '{}' follows default type parameter", + param.repr(vm)? + ))); + } + } else { + default_seen = true; + } + } + Ok(Some(type_params.clone())) + } } impl Constructor for TypeAliasType { type Args = FuncArgs; fn py_new(_cls: &Py, args: Self::Args, vm: &VirtualMachine) -> PyResult { - // TypeAliasType(name, value, *, type_params=None) - if args.args.len() < 2 { - return Err(vm.new_type_error(format!( - "TypeAliasType() missing {} required positional argument{}: {}", - 2 - args.args.len(), - if 2 - args.args.len() == 1 { "" } else { "s" }, - if args.args.is_empty() { - "'name' and 'value'" - } else { - "'value'" - } - ))); + // typealias(name, value, *, type_params=()) + // name and value are positional-or-keyword; type_params is keyword-only. + + // Reject unexpected keyword arguments + for key in args.kwargs.keys() { + if key != "name" && key != "value" && key != "type_params" { + return Err(vm.new_type_error(format!( + "typealias() got an unexpected keyword argument '{key}'" + ))); + } } + + // Reject too many positional arguments if args.args.len() > 2 { return Err(vm.new_type_error(format!( - "TypeAliasType() takes 2 positional arguments but {} were given", + "typealias() takes exactly 2 positional arguments ({} given)", args.args.len() ))); } - let name = args.args[0] - .clone() - .downcast::() - .map_err(|_| vm.new_type_error("TypeAliasType name must be a string".to_owned()))?; - let value = args.args[1].clone(); + // Resolve name: positional[0] or kwarg + let name = if !args.args.is_empty() { + if args.kwargs.contains_key("name") { + return Err(vm.new_type_error( + "argument for typealias() given by name ('name') and position (1)" + .to_owned(), + )); + } + args.args[0].clone() + } else { + args.kwargs.get("name").cloned().ok_or_else(|| { + vm.new_type_error( + "typealias() missing required argument 'name' (pos 1)".to_owned(), + ) + })? + }; + + // Resolve value: positional[1] or kwarg + let value = if args.args.len() >= 2 { + if args.kwargs.contains_key("value") { + return Err(vm.new_type_error( + "argument for typealias() given by name ('value') and position (2)" + .to_owned(), + )); + } + args.args[1].clone() + } else { + args.kwargs.get("value").cloned().ok_or_else(|| { + vm.new_type_error( + "typealias() missing required argument 'value' (pos 2)".to_owned(), + ) + })? + }; + + let name = name.downcast::().map_err(|obj| { + vm.new_type_error(format!( + "typealias() argument 'name' must be str, not {}", + obj.class().name() + )) + })?; let type_params = if let Some(tp) = args.kwargs.get("type_params") { - tp.clone() + let tp = tp + .clone() .downcast::() - .map_err(|_| vm.new_type_error("type_params must be a tuple".to_owned()))? + .map_err(|_| vm.new_type_error("type_params must be a tuple".to_owned()))?; + Self::check_type_params(&tp, vm)?; + tp } else { vm.ctx.empty_tuple.clone() }; - Ok(Self::new(name, type_params, value)) + // Get caller's module name from frame globals, like typevar.rs caller() + let module = vm + .current_frame() + .and_then(|f| f.globals.get_item("__name__", vm).ok()); + + Ok(Self::new_eager(name, type_params, value, module)) } } @@ -168,6 +338,19 @@ pub(crate) mod decl { } } + impl AsMapping for TypeAliasType { + fn as_mapping() -> &'static PyMappingMethods { + static AS_MAPPING: LazyLock = LazyLock::new(|| PyMappingMethods { + subscript: atomic_func!(|mapping, needle, vm| { + let zelf = TypeAliasType::mapping_downcast(mapping); + TypeAliasType::__getitem__(zelf.to_owned(), needle.to_owned(), vm) + }), + ..PyMappingMethods::NOT_IMPLEMENTED + }); + &AS_MAPPING + } + } + impl AsNumber for TypeAliasType { fn as_number() -> &'static PyNumberMethods { static AS_NUMBER: PyNumberMethods = PyNumberMethods { @@ -178,6 +361,41 @@ pub(crate) mod decl { } } + impl Iterable for TypeAliasType { + fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { + // Import typing.Unpack and return iter((Unpack[self],)) + let typing = vm.import("typing", 0)?; + let unpack = typing.get_attr("Unpack", vm)?; + let zelf_obj: PyObjectRef = zelf.into(); + let unpacked = vm.call_method(&unpack, "__getitem__", (zelf_obj,))?; + let tuple = PyTuple::new_ref(vec![unpacked], &vm.ctx); + Ok(tuple.as_object().get_iter(vm)?.into()) + } + } + + /// Wrap TypeVarTuples in Unpack[], matching unpack_typevartuples() + fn unpack_typevartuples(type_params: &PyTupleRef, vm: &VirtualMachine) -> PyResult { + let has_tvt = type_params + .iter() + .any(|p| p.downcastable::()); + if !has_tvt { + return Ok(type_params.clone()); + } + let typing = vm.import("typing", 0)?; + let unpack_cls = typing.get_attr("Unpack", vm)?; + let new_params: Vec = type_params + .iter() + .map(|p| { + if p.downcastable::() { + vm.call_method(&unpack_cls, "__getitem__", (p.clone(),)) + } else { + Ok(p.clone()) + } + }) + .collect::>()?; + Ok(PyTuple::new_ref(new_params, &vm.ctx)) + } + pub(crate) fn module_exec( vm: &VirtualMachine, module: &Py, diff --git a/crates/vm/src/types/zoo.rs b/crates/vm/src/types/zoo.rs index 2b60e37f316..07754d08340 100644 --- a/crates/vm/src/types/zoo.rs +++ b/crates/vm/src/types/zoo.rs @@ -93,6 +93,7 @@ pub struct TypeZoo { pub typing_no_default_type: &'static Py, pub not_implemented_type: &'static Py, pub generic_alias_type: &'static Py, + pub generic_alias_iterator_type: &'static Py, pub union_type: &'static Py, pub interpolation_type: &'static Py, pub template_type: &'static Py, @@ -200,6 +201,7 @@ impl TypeZoo { typing_no_default_type: crate::stdlib::typing::NoDefault::init_builtin_type(), not_implemented_type: singletons::PyNotImplemented::init_builtin_type(), generic_alias_type: genericalias::PyGenericAlias::init_builtin_type(), + generic_alias_iterator_type: genericalias::PyGenericAliasIterator::init_builtin_type(), union_type: union_::PyUnion::init_builtin_type(), interpolation_type: interpolation::PyInterpolation::init_builtin_type(), template_type: template::PyTemplate::init_builtin_type(),