mirror of https://github.com/python/cpython.git
[3.11] gh-111085: Fix invalid state handling in TaskGroup and Timeout (GH-111111) (GH-111172)
asyncio.TaskGroup and asyncio.Timeout classes now raise proper RuntimeError
if they are improperly used.
* When they are used without entering the context manager.
* When they are used after finishing.
* When the context manager is entered more than once (simultaneously or
sequentially).
* If there is no current task when entering the context manager.
They now remain in a consistent state after an exception is thrown,
so subsequent operations can be performed correctly (if they are allowed).
(cherry picked from commit 6c23635f2b
)
Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
This commit is contained in:
parent
cf28c61c73
commit
cf777399a9
|
@ -54,16 +54,14 @@ def __repr__(self):
|
|||
async def __aenter__(self):
|
||||
if self._entered:
|
||||
raise RuntimeError(
|
||||
f"TaskGroup {self!r} has been already entered")
|
||||
self._entered = True
|
||||
|
||||
f"TaskGroup {self!r} has already been entered")
|
||||
if self._loop is None:
|
||||
self._loop = events.get_running_loop()
|
||||
|
||||
self._parent_task = tasks.current_task(self._loop)
|
||||
if self._parent_task is None:
|
||||
raise RuntimeError(
|
||||
f'TaskGroup {self!r} cannot determine the parent task')
|
||||
self._entered = True
|
||||
|
||||
return self
|
||||
|
||||
|
|
|
@ -49,8 +49,9 @@ def when(self) -> Optional[float]:
|
|||
|
||||
def reschedule(self, when: Optional[float]) -> None:
|
||||
"""Reschedule the timeout."""
|
||||
assert self._state is not _State.CREATED
|
||||
if self._state is not _State.ENTERED:
|
||||
if self._state is _State.CREATED:
|
||||
raise RuntimeError("Timeout has not been entered")
|
||||
raise RuntimeError(
|
||||
f"Cannot change state of {self._state.value} Timeout",
|
||||
)
|
||||
|
@ -82,11 +83,14 @@ def __repr__(self) -> str:
|
|||
return f"<Timeout [{self._state.value}]{info_str}>"
|
||||
|
||||
async def __aenter__(self) -> "Timeout":
|
||||
self._state = _State.ENTERED
|
||||
self._task = tasks.current_task()
|
||||
self._cancelling = self._task.cancelling()
|
||||
if self._task is None:
|
||||
if self._state is not _State.CREATED:
|
||||
raise RuntimeError("Timeout has already been entered")
|
||||
task = tasks.current_task()
|
||||
if task is None:
|
||||
raise RuntimeError("Timeout should be used inside a task")
|
||||
self._state = _State.ENTERED
|
||||
self._task = task
|
||||
self._cancelling = self._task.cancelling()
|
||||
self.reschedule(self._when)
|
||||
return self
|
||||
|
||||
|
|
|
@ -8,6 +8,8 @@
|
|||
from asyncio import taskgroups
|
||||
import unittest
|
||||
|
||||
from test.test_asyncio.utils import await_without_task
|
||||
|
||||
|
||||
# To prevent a warning "test altered the execution environment"
|
||||
def tearDownModule():
|
||||
|
@ -779,6 +781,49 @@ async def main():
|
|||
|
||||
await asyncio.create_task(main())
|
||||
|
||||
async def test_taskgroup_already_entered(self):
|
||||
tg = taskgroups.TaskGroup()
|
||||
async with tg:
|
||||
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
|
||||
async with tg:
|
||||
pass
|
||||
|
||||
async def test_taskgroup_double_enter(self):
|
||||
tg = taskgroups.TaskGroup()
|
||||
async with tg:
|
||||
pass
|
||||
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
|
||||
async with tg:
|
||||
pass
|
||||
|
||||
async def test_taskgroup_finished(self):
|
||||
tg = taskgroups.TaskGroup()
|
||||
async with tg:
|
||||
pass
|
||||
coro = asyncio.sleep(0)
|
||||
with self.assertRaisesRegex(RuntimeError, "is finished"):
|
||||
tg.create_task(coro)
|
||||
# We still have to await coro to avoid a warning
|
||||
await coro
|
||||
|
||||
async def test_taskgroup_not_entered(self):
|
||||
tg = taskgroups.TaskGroup()
|
||||
coro = asyncio.sleep(0)
|
||||
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
|
||||
tg.create_task(coro)
|
||||
# We still have to await coro to avoid a warning
|
||||
await coro
|
||||
|
||||
async def test_taskgroup_without_parent_task(self):
|
||||
tg = taskgroups.TaskGroup()
|
||||
with self.assertRaisesRegex(RuntimeError, "parent task"):
|
||||
await await_without_task(tg.__aenter__())
|
||||
coro = asyncio.sleep(0)
|
||||
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
|
||||
tg.create_task(coro)
|
||||
# We still have to await coro to avoid a warning
|
||||
await coro
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -6,11 +6,12 @@
|
|||
import asyncio
|
||||
from asyncio import tasks
|
||||
|
||||
from test.test_asyncio.utils import await_without_task
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class TimeoutTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_timeout_basic(self):
|
||||
|
@ -258,6 +259,51 @@ async def test_timeout_exception_cause (self):
|
|||
cause = exc.exception.__cause__
|
||||
assert isinstance(cause, asyncio.CancelledError)
|
||||
|
||||
async def test_timeout_already_entered(self):
|
||||
async with asyncio.timeout(0.01) as cm:
|
||||
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
|
||||
async with cm:
|
||||
pass
|
||||
|
||||
async def test_timeout_double_enter(self):
|
||||
async with asyncio.timeout(0.01) as cm:
|
||||
pass
|
||||
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
|
||||
async with cm:
|
||||
pass
|
||||
|
||||
async def test_timeout_finished(self):
|
||||
async with asyncio.timeout(0.01) as cm:
|
||||
pass
|
||||
with self.assertRaisesRegex(RuntimeError, "finished"):
|
||||
cm.reschedule(0.02)
|
||||
|
||||
async def test_timeout_expired(self):
|
||||
with self.assertRaises(TimeoutError):
|
||||
async with asyncio.timeout(0.01) as cm:
|
||||
await asyncio.sleep(1)
|
||||
with self.assertRaisesRegex(RuntimeError, "expired"):
|
||||
cm.reschedule(0.02)
|
||||
|
||||
async def test_timeout_expiring(self):
|
||||
async with asyncio.timeout(0.01) as cm:
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await asyncio.sleep(1)
|
||||
with self.assertRaisesRegex(RuntimeError, "expiring"):
|
||||
cm.reschedule(0.02)
|
||||
|
||||
async def test_timeout_not_entered(self):
|
||||
cm = asyncio.timeout(0.01)
|
||||
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
|
||||
cm.reschedule(0.02)
|
||||
|
||||
async def test_timeout_without_task(self):
|
||||
cm = asyncio.timeout(0.01)
|
||||
with self.assertRaisesRegex(RuntimeError, "task"):
|
||||
await await_without_task(cm.__aenter__())
|
||||
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
|
||||
cm.reschedule(0.02)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -612,3 +612,18 @@ def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
|
|||
sock.family = family
|
||||
sock.gettimeout.return_value = 0.0
|
||||
return sock
|
||||
|
||||
|
||||
async def await_without_task(coro):
|
||||
exc = None
|
||||
def func():
|
||||
try:
|
||||
for _ in coro.__await__():
|
||||
pass
|
||||
except BaseException as err:
|
||||
nonlocal exc
|
||||
exc = err
|
||||
asyncio.get_running_loop().call_soon(func)
|
||||
await asyncio.sleep(0)
|
||||
if exc is not None:
|
||||
raise exc
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
Fix invalid state handling in :class:`asyncio.TaskGroup` and
|
||||
:class:`asyncio.Timeout`. They now raise proper RuntimeError if they are
|
||||
improperly used and are left in consistent state after this.
|
Loading…
Reference in New Issue