[3.11] gh-111085: Fix invalid state handling in TaskGroup and Timeout (GH-111111) (GH-111172)

asyncio.TaskGroup and asyncio.Timeout classes now raise proper RuntimeError
if they are improperly used.

* When they are used without entering the context manager.
* When they are used after finishing.
* When the context manager is entered more than once (simultaneously or
  sequentially).
* If there is no current task when entering the context manager.

They now remain in a consistent state after an exception is thrown,
so subsequent operations can be performed correctly (if they are allowed).

(cherry picked from commit 6c23635f2b)

Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
This commit is contained in:
Miss Islington (bot) 2023-10-21 21:40:07 +02:00 committed by GitHub
parent cf28c61c73
commit cf777399a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 121 additions and 10 deletions

View File

@ -54,16 +54,14 @@ def __repr__(self):
async def __aenter__(self):
if self._entered:
raise RuntimeError(
f"TaskGroup {self!r} has been already entered")
self._entered = True
f"TaskGroup {self!r} has already been entered")
if self._loop is None:
self._loop = events.get_running_loop()
self._parent_task = tasks.current_task(self._loop)
if self._parent_task is None:
raise RuntimeError(
f'TaskGroup {self!r} cannot determine the parent task')
self._entered = True
return self

View File

@ -49,8 +49,9 @@ def when(self) -> Optional[float]:
def reschedule(self, when: Optional[float]) -> None:
"""Reschedule the timeout."""
assert self._state is not _State.CREATED
if self._state is not _State.ENTERED:
if self._state is _State.CREATED:
raise RuntimeError("Timeout has not been entered")
raise RuntimeError(
f"Cannot change state of {self._state.value} Timeout",
)
@ -82,11 +83,14 @@ def __repr__(self) -> str:
return f"<Timeout [{self._state.value}]{info_str}>"
async def __aenter__(self) -> "Timeout":
self._state = _State.ENTERED
self._task = tasks.current_task()
self._cancelling = self._task.cancelling()
if self._task is None:
if self._state is not _State.CREATED:
raise RuntimeError("Timeout has already been entered")
task = tasks.current_task()
if task is None:
raise RuntimeError("Timeout should be used inside a task")
self._state = _State.ENTERED
self._task = task
self._cancelling = self._task.cancelling()
self.reschedule(self._when)
return self

View File

@ -8,6 +8,8 @@
from asyncio import taskgroups
import unittest
from test.test_asyncio.utils import await_without_task
# To prevent a warning "test altered the execution environment"
def tearDownModule():
@ -779,6 +781,49 @@ async def main():
await asyncio.create_task(main())
async def test_taskgroup_already_entered(self):
tg = taskgroups.TaskGroup()
async with tg:
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with tg:
pass
async def test_taskgroup_double_enter(self):
tg = taskgroups.TaskGroup()
async with tg:
pass
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with tg:
pass
async def test_taskgroup_finished(self):
tg = taskgroups.TaskGroup()
async with tg:
pass
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "is finished"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro
async def test_taskgroup_not_entered(self):
tg = taskgroups.TaskGroup()
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro
async def test_taskgroup_without_parent_task(self):
tg = taskgroups.TaskGroup()
with self.assertRaisesRegex(RuntimeError, "parent task"):
await await_without_task(tg.__aenter__())
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro
if __name__ == "__main__":
unittest.main()

View File

@ -6,11 +6,12 @@
import asyncio
from asyncio import tasks
from test.test_asyncio.utils import await_without_task
def tearDownModule():
asyncio.set_event_loop_policy(None)
class TimeoutTests(unittest.IsolatedAsyncioTestCase):
async def test_timeout_basic(self):
@ -258,6 +259,51 @@ async def test_timeout_exception_cause (self):
cause = exc.exception.__cause__
assert isinstance(cause, asyncio.CancelledError)
async def test_timeout_already_entered(self):
async with asyncio.timeout(0.01) as cm:
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with cm:
pass
async def test_timeout_double_enter(self):
async with asyncio.timeout(0.01) as cm:
pass
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with cm:
pass
async def test_timeout_finished(self):
async with asyncio.timeout(0.01) as cm:
pass
with self.assertRaisesRegex(RuntimeError, "finished"):
cm.reschedule(0.02)
async def test_timeout_expired(self):
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01) as cm:
await asyncio.sleep(1)
with self.assertRaisesRegex(RuntimeError, "expired"):
cm.reschedule(0.02)
async def test_timeout_expiring(self):
async with asyncio.timeout(0.01) as cm:
with self.assertRaises(asyncio.CancelledError):
await asyncio.sleep(1)
with self.assertRaisesRegex(RuntimeError, "expiring"):
cm.reschedule(0.02)
async def test_timeout_not_entered(self):
cm = asyncio.timeout(0.01)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
cm.reschedule(0.02)
async def test_timeout_without_task(self):
cm = asyncio.timeout(0.01)
with self.assertRaisesRegex(RuntimeError, "task"):
await await_without_task(cm.__aenter__())
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
cm.reschedule(0.02)
if __name__ == '__main__':
unittest.main()

View File

@ -612,3 +612,18 @@ def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
sock.family = family
sock.gettimeout.return_value = 0.0
return sock
async def await_without_task(coro):
exc = None
def func():
try:
for _ in coro.__await__():
pass
except BaseException as err:
nonlocal exc
exc = err
asyncio.get_running_loop().call_soon(func)
await asyncio.sleep(0)
if exc is not None:
raise exc

View File

@ -0,0 +1,3 @@
Fix invalid state handling in :class:`asyncio.TaskGroup` and
:class:`asyncio.Timeout`. They now raise proper RuntimeError if they are
improperly used and are left in consistent state after this.