Skip to content

Commit f274c78

Browse files
authored
Make ExitStack, AbstractContextManager and AsyncAbstractContextManager generic in return type of __exit__ (#11048)
1 parent b3bfdad commit f274c78

5 files changed

Lines changed: 42 additions & 28 deletions

File tree

stdlib/contextlib.pyi

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,32 +31,33 @@ if sys.version_info >= (3, 11):
3131
_T = TypeVar("_T")
3232
_T_co = TypeVar("_T_co", covariant=True)
3333
_T_io = TypeVar("_T_io", bound=IO[str] | None)
34+
_ExitT_co = TypeVar("_ExitT_co", covariant=True, bound=bool | None, default=bool | None)
3435
_F = TypeVar("_F", bound=Callable[..., Any])
3536
_P = ParamSpec("_P")
3637

3738
_ExitFunc: TypeAlias = Callable[[type[BaseException] | None, BaseException | None, TracebackType | None], bool | None]
38-
_CM_EF = TypeVar("_CM_EF", bound=AbstractContextManager[Any] | _ExitFunc)
39+
_CM_EF = TypeVar("_CM_EF", bound=AbstractContextManager[Any, Any] | _ExitFunc)
3940

4041
@runtime_checkable
41-
class AbstractContextManager(Protocol[_T_co]):
42+
class AbstractContextManager(Protocol[_T_co, _ExitT_co]):
4243
def __enter__(self) -> _T_co: ...
4344
@abstractmethod
4445
def __exit__(
4546
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, /
46-
) -> bool | None: ...
47+
) -> _ExitT_co: ...
4748

4849
@runtime_checkable
49-
class AbstractAsyncContextManager(Protocol[_T_co]):
50+
class AbstractAsyncContextManager(Protocol[_T_co, _ExitT_co]):
5051
async def __aenter__(self) -> _T_co: ...
5152
@abstractmethod
5253
async def __aexit__(
5354
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, /
54-
) -> bool | None: ...
55+
) -> _ExitT_co: ...
5556

5657
class ContextDecorator:
5758
def __call__(self, func: _F) -> _F: ...
5859

59-
class _GeneratorContextManager(AbstractContextManager[_T_co], ContextDecorator):
60+
class _GeneratorContextManager(AbstractContextManager[_T_co, bool | None], ContextDecorator):
6061
# __init__ and all instance attributes are actually inherited from _GeneratorContextManagerBase
6162
# _GeneratorContextManagerBase is more trouble than it's worth to include in the stub; see #6676
6263
def __init__(self, func: Callable[..., Iterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ...
@@ -81,7 +82,7 @@ if sys.version_info >= (3, 10):
8182
class AsyncContextDecorator:
8283
def __call__(self, func: _AF) -> _AF: ...
8384

84-
class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co], AsyncContextDecorator):
85+
class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co, bool | None], AsyncContextDecorator):
8586
# __init__ and these attributes are actually defined in the base class _GeneratorContextManagerBase,
8687
# which is more trouble than it's worth to include in the stub (see #6676)
8788
def __init__(self, func: Callable[..., AsyncIterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ...
@@ -94,7 +95,7 @@ if sys.version_info >= (3, 10):
9495
) -> bool | None: ...
9596

9697
else:
97-
class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co]):
98+
class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co, bool | None]):
9899
def __init__(self, func: Callable[..., AsyncIterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ...
99100
gen: AsyncGenerator[_T_co, Any]
100101
func: Callable[..., AsyncGenerator[_T_co, Any]]
@@ -111,7 +112,7 @@ class _SupportsClose(Protocol):
111112

112113
_SupportsCloseT = TypeVar("_SupportsCloseT", bound=_SupportsClose)
113114

114-
class closing(AbstractContextManager[_SupportsCloseT]):
115+
class closing(AbstractContextManager[_SupportsCloseT, None]):
115116
def __init__(self, thing: _SupportsCloseT) -> None: ...
116117
def __exit__(self, *exc_info: Unused) -> None: ...
117118

@@ -121,17 +122,17 @@ if sys.version_info >= (3, 10):
121122

122123
_SupportsAcloseT = TypeVar("_SupportsAcloseT", bound=_SupportsAclose)
123124

124-
class aclosing(AbstractAsyncContextManager[_SupportsAcloseT]):
125+
class aclosing(AbstractAsyncContextManager[_SupportsAcloseT, None]):
125126
def __init__(self, thing: _SupportsAcloseT) -> None: ...
126127
async def __aexit__(self, *exc_info: Unused) -> None: ...
127128

128-
class suppress(AbstractContextManager[None]):
129+
class suppress(AbstractContextManager[None, bool]):
129130
def __init__(self, *exceptions: type[BaseException]) -> None: ...
130131
def __exit__(
131132
self, exctype: type[BaseException] | None, excinst: BaseException | None, exctb: TracebackType | None
132133
) -> bool: ...
133134

134-
class _RedirectStream(AbstractContextManager[_T_io]):
135+
class _RedirectStream(AbstractContextManager[_T_io, None]):
135136
def __init__(self, new_target: _T_io) -> None: ...
136137
def __exit__(
137138
self, exctype: type[BaseException] | None, excinst: BaseException | None, exctb: TracebackType | None
@@ -142,27 +143,27 @@ class redirect_stderr(_RedirectStream[_T_io]): ...
142143

143144
# In reality this is a subclass of `AbstractContextManager`;
144145
# see #7961 for why we don't do that in the stub
145-
class ExitStack(metaclass=abc.ABCMeta):
146-
def enter_context(self, cm: AbstractContextManager[_T]) -> _T: ...
146+
class ExitStack(Generic[_ExitT_co], metaclass=abc.ABCMeta):
147+
def enter_context(self, cm: AbstractContextManager[_T, _ExitT_co]) -> _T: ...
147148
def push(self, exit: _CM_EF) -> _CM_EF: ...
148149
def callback(self, callback: Callable[_P, _T], /, *args: _P.args, **kwds: _P.kwargs) -> Callable[_P, _T]: ...
149150
def pop_all(self) -> Self: ...
150151
def close(self) -> None: ...
151152
def __enter__(self) -> Self: ...
152153
def __exit__(
153154
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, /
154-
) -> bool: ...
155+
) -> _ExitT_co: ...
155156

156157
_ExitCoroFunc: TypeAlias = Callable[
157158
[type[BaseException] | None, BaseException | None, TracebackType | None], Awaitable[bool | None]
158159
]
159-
_ACM_EF = TypeVar("_ACM_EF", bound=AbstractAsyncContextManager[Any] | _ExitCoroFunc)
160+
_ACM_EF = TypeVar("_ACM_EF", bound=AbstractAsyncContextManager[Any, Any] | _ExitCoroFunc)
160161

161162
# In reality this is a subclass of `AbstractAsyncContextManager`;
162163
# see #7961 for why we don't do that in the stub
163-
class AsyncExitStack(metaclass=abc.ABCMeta):
164-
def enter_context(self, cm: AbstractContextManager[_T]) -> _T: ...
165-
async def enter_async_context(self, cm: AbstractAsyncContextManager[_T]) -> _T: ...
164+
class AsyncExitStack(Generic[_ExitT_co], metaclass=abc.ABCMeta):
165+
def enter_context(self, cm: AbstractContextManager[_T, _ExitT_co]) -> _T: ...
166+
async def enter_async_context(self, cm: AbstractAsyncContextManager[_T, _ExitT_co]) -> _T: ...
166167
def push(self, exit: _CM_EF) -> _CM_EF: ...
167168
def push_async_exit(self, exit: _ACM_EF) -> _ACM_EF: ...
168169
def callback(self, callback: Callable[_P, _T], /, *args: _P.args, **kwds: _P.kwargs) -> Callable[_P, _T]: ...
@@ -177,7 +178,7 @@ class AsyncExitStack(metaclass=abc.ABCMeta):
177178
) -> bool: ...
178179

179180
if sys.version_info >= (3, 10):
180-
class nullcontext(AbstractContextManager[_T], AbstractAsyncContextManager[_T]):
181+
class nullcontext(AbstractContextManager[_T, None], AbstractAsyncContextManager[_T, None]):
181182
enter_result: _T
182183
@overload
183184
def __init__(self: nullcontext[None], enter_result: None = None) -> None: ...
@@ -189,7 +190,7 @@ if sys.version_info >= (3, 10):
189190
async def __aexit__(self, *exctype: Unused) -> None: ...
190191

191192
else:
192-
class nullcontext(AbstractContextManager[_T]):
193+
class nullcontext(AbstractContextManager[_T, None]):
193194
enter_result: _T
194195
@overload
195196
def __init__(self: nullcontext[None], enter_result: None = None) -> None: ...
@@ -201,7 +202,7 @@ else:
201202
if sys.version_info >= (3, 11):
202203
_T_fd_or_any_path = TypeVar("_T_fd_or_any_path", bound=FileDescriptorOrPath)
203204

204-
class chdir(AbstractContextManager[None], Generic[_T_fd_or_any_path]):
205+
class chdir(AbstractContextManager[None, None], Generic[_T_fd_or_any_path]):
205206
path: _T_fd_or_any_path
206207
def __init__(self, path: _T_fd_or_any_path) -> None: ...
207208
def __enter__(self) -> None: ...

stdlib/multiprocessing/synchronize.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class Barrier(threading.Barrier):
1414
self, parties: int, action: Callable[[], object] | None = None, timeout: float | None = None, *ctx: BaseContext
1515
) -> None: ...
1616

17-
class Condition(AbstractContextManager[bool]):
17+
class Condition(AbstractContextManager[bool, None]):
1818
def __init__(self, lock: _LockLike | None = None, *, ctx: BaseContext) -> None: ...
1919
def notify(self, n: int = 1) -> None: ...
2020
def notify_all(self) -> None: ...
@@ -34,7 +34,7 @@ class Event:
3434
def wait(self, timeout: float | None = None) -> bool: ...
3535

3636
# Not part of public API
37-
class SemLock(AbstractContextManager[bool]):
37+
class SemLock(AbstractContextManager[bool, None]):
3838
def acquire(self, block: bool = ..., timeout: float | None = ...) -> bool: ...
3939
def release(self) -> None: ...
4040
def __exit__(

stdlib/os/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,7 @@ def replace(
794794
) -> None: ...
795795
def rmdir(path: StrOrBytesPath, *, dir_fd: int | None = None) -> None: ...
796796

797-
class _ScandirIterator(Iterator[DirEntry[AnyStr]], AbstractContextManager[_ScandirIterator[AnyStr]]):
797+
class _ScandirIterator(Iterator[DirEntry[AnyStr]], AbstractContextManager[_ScandirIterator[AnyStr], None]):
798798
def __next__(self) -> DirEntry[AnyStr]: ...
799799
def __exit__(self, *args: Unused) -> None: ...
800800
def close(self) -> None: ...

stdlib/typing.pyi

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,6 @@ if sys.version_info >= (3, 11):
129129
if sys.version_info >= (3, 12):
130130
__all__ += ["TypeAliasType", "override"]
131131

132-
ContextManager = AbstractContextManager
133-
AsyncContextManager = AbstractAsyncContextManager
134-
135132
Any = object()
136133

137134
def final(f: _T) -> _T: ...
@@ -431,6 +428,18 @@ class Generator(Iterator[_YieldT_co], Generic[_YieldT_co, _SendT_contra, _Return
431428
@property
432429
def gi_yieldfrom(self) -> Generator[Any, Any, Any] | None: ...
433430

431+
# NOTE: Technically we would like this to be able to accept a second parameter as well, just
432+
# like it's counterpart in contextlib, however `typing._SpecialGenericAlias` enforces the
433+
# correct number of arguments at runtime, so we would be hiding runtime errors.
434+
@runtime_checkable
435+
class ContextManager(AbstractContextManager[_T_co, bool | None], Protocol[_T_co]): ...
436+
437+
# NOTE: Technically we would like this to be able to accept a second parameter as well, just
438+
# like it's counterpart in contextlib, however `typing._SpecialGenericAlias` enforces the
439+
# correct number of arguments at runtime, so we would be hiding runtime errors.
440+
@runtime_checkable
441+
class AsyncContextManager(AbstractAsyncContextManager[_T_co, bool | None], Protocol[_T_co]): ...
442+
434443
@runtime_checkable
435444
class Awaitable(Protocol[_T_co]):
436445
@abstractmethod

tests/stubtest_allowlists/py3_common.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,10 @@ typing(_extensions)?\.TextIO\.errors
533533
typing(_extensions)?\.TextIO\.line_buffering
534534
typing(_extensions)?\.TextIO\.newlines
535535

536+
# These are typing._SpecialGenericAlias at runtime, which is not a real type, but it
537+
# behaves like one in most cases
538+
typing(_extensions)?\.(Async)?ContextManager
539+
536540
types.MethodType.__closure__ # read-only but not actually a property; stubtest thinks it doesn't exist.
537541
types.MethodType.__defaults__ # read-only but not actually a property; stubtest thinks it doesn't exist.
538542
types.ModuleType.__dict__ # read-only but not actually a property; stubtest thinks it's a mutable attribute.

0 commit comments

Comments
 (0)