mirror of https://github.com/python/cpython.git
Taskgroup tweaks (GH-31559)
Now uses .cancel()/.uncancel(), for even fewer broken edge cases.
This commit is contained in:
parent
41ddcd3f40
commit
edbee56d69
|
@ -66,31 +66,28 @@ async def __aexit__(self, et, exc, tb):
|
||||||
self._base_error is None):
|
self._base_error is None):
|
||||||
self._base_error = exc
|
self._base_error = exc
|
||||||
|
|
||||||
if et is exceptions.CancelledError:
|
if et is not None:
|
||||||
if self._parent_cancel_requested:
|
|
||||||
# Only if we did request task to cancel ourselves
|
|
||||||
# we mark it as no longer cancelled.
|
|
||||||
self._parent_task.uncancel()
|
|
||||||
else:
|
|
||||||
propagate_cancellation_error = et
|
|
||||||
|
|
||||||
if et is not None and not self._aborting:
|
|
||||||
# Our parent task is being cancelled:
|
|
||||||
#
|
|
||||||
# async with TaskGroup() as g:
|
|
||||||
# g.create_task(...)
|
|
||||||
# await ... # <- CancelledError
|
|
||||||
#
|
|
||||||
if et is exceptions.CancelledError:
|
if et is exceptions.CancelledError:
|
||||||
propagate_cancellation_error = et
|
if self._parent_cancel_requested and not self._parent_task.uncancel():
|
||||||
|
# Do nothing, i.e. swallow the error.
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
propagate_cancellation_error = exc
|
||||||
|
|
||||||
# or there's an exception in "async with":
|
if not self._aborting:
|
||||||
#
|
# Our parent task is being cancelled:
|
||||||
# async with TaskGroup() as g:
|
#
|
||||||
# g.create_task(...)
|
# async with TaskGroup() as g:
|
||||||
# 1 / 0
|
# g.create_task(...)
|
||||||
#
|
# await ... # <- CancelledError
|
||||||
self._abort()
|
#
|
||||||
|
# or there's an exception in "async with":
|
||||||
|
#
|
||||||
|
# async with TaskGroup() as g:
|
||||||
|
# g.create_task(...)
|
||||||
|
# 1 / 0
|
||||||
|
#
|
||||||
|
self._abort()
|
||||||
|
|
||||||
# We use while-loop here because "self._on_completed_fut"
|
# We use while-loop here because "self._on_completed_fut"
|
||||||
# can be cancelled multiple times if our parent task
|
# can be cancelled multiple times if our parent task
|
||||||
|
@ -118,7 +115,6 @@ async def __aexit__(self, et, exc, tb):
|
||||||
self._on_completed_fut = None
|
self._on_completed_fut = None
|
||||||
|
|
||||||
assert self._unfinished_tasks == 0
|
assert self._unfinished_tasks == 0
|
||||||
self._on_completed_fut = None # no longer needed
|
|
||||||
|
|
||||||
if self._base_error is not None:
|
if self._base_error is not None:
|
||||||
raise self._base_error
|
raise self._base_error
|
||||||
|
@ -199,8 +195,7 @@ def _on_task_done(self, task):
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|
||||||
self._abort()
|
if not self._aborting and not self._parent_cancel_requested:
|
||||||
if not self._parent_task.cancelling():
|
|
||||||
# If parent task *is not* being cancelled, it means that we want
|
# If parent task *is not* being cancelled, it means that we want
|
||||||
# to manually cancel it to abort whatever is being run right now
|
# to manually cancel it to abort whatever is being run right now
|
||||||
# in the TaskGroup. But we want to mark parent task as
|
# in the TaskGroup. But we want to mark parent task as
|
||||||
|
@ -219,5 +214,6 @@ def _on_task_done(self, task):
|
||||||
# pass
|
# pass
|
||||||
# await something_else # this line has to be called
|
# await something_else # this line has to be called
|
||||||
# # after TaskGroup is finished.
|
# # after TaskGroup is finished.
|
||||||
|
self._abort()
|
||||||
self._parent_cancel_requested = True
|
self._parent_cancel_requested = True
|
||||||
self._parent_task.cancel()
|
self._parent_task.cancel()
|
||||||
|
|
|
@ -120,7 +120,11 @@ async def runner():
|
||||||
self.assertTrue(t2_cancel)
|
self.assertTrue(t2_cancel)
|
||||||
self.assertTrue(t2.cancelled())
|
self.assertTrue(t2.cancelled())
|
||||||
|
|
||||||
async def test_taskgroup_05(self):
|
async def test_cancel_children_on_child_error(self):
|
||||||
|
"""
|
||||||
|
When a child task raises an error, the rest of the children
|
||||||
|
are cancelled and the errors are gathered into an EG.
|
||||||
|
"""
|
||||||
|
|
||||||
NUM = 0
|
NUM = 0
|
||||||
t2_cancel = False
|
t2_cancel = False
|
||||||
|
@ -165,7 +169,7 @@ async def runner():
|
||||||
self.assertTrue(t2_cancel)
|
self.assertTrue(t2_cancel)
|
||||||
self.assertTrue(runner_cancel)
|
self.assertTrue(runner_cancel)
|
||||||
|
|
||||||
async def test_taskgroup_06(self):
|
async def test_cancellation(self):
|
||||||
|
|
||||||
NUM = 0
|
NUM = 0
|
||||||
|
|
||||||
|
@ -186,10 +190,12 @@ async def runner():
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
self.assertFalse(r.done())
|
self.assertFalse(r.done())
|
||||||
r.cancel()
|
r.cancel("test")
|
||||||
with self.assertRaises(asyncio.CancelledError):
|
with self.assertRaises(asyncio.CancelledError) as cm:
|
||||||
await r
|
await r
|
||||||
|
|
||||||
|
self.assertEqual(cm.exception.args, ('test',))
|
||||||
|
|
||||||
self.assertEqual(NUM, 5)
|
self.assertEqual(NUM, 5)
|
||||||
|
|
||||||
async def test_taskgroup_07(self):
|
async def test_taskgroup_07(self):
|
||||||
|
@ -226,7 +232,7 @@ async def runner():
|
||||||
|
|
||||||
self.assertEqual(NUM, 15)
|
self.assertEqual(NUM, 15)
|
||||||
|
|
||||||
async def test_taskgroup_08(self):
|
async def test_cancellation_in_body(self):
|
||||||
|
|
||||||
async def foo():
|
async def foo():
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
@ -246,10 +252,12 @@ async def runner():
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
self.assertFalse(r.done())
|
self.assertFalse(r.done())
|
||||||
r.cancel()
|
r.cancel("test")
|
||||||
with self.assertRaises(asyncio.CancelledError):
|
with self.assertRaises(asyncio.CancelledError) as cm:
|
||||||
await r
|
await r
|
||||||
|
|
||||||
|
self.assertEqual(cm.exception.args, ('test',))
|
||||||
|
|
||||||
async def test_taskgroup_09(self):
|
async def test_taskgroup_09(self):
|
||||||
|
|
||||||
t1 = t2 = None
|
t1 = t2 = None
|
||||||
|
@ -699,3 +707,7 @@ async def coro():
|
||||||
async with taskgroups.TaskGroup() as g:
|
async with taskgroups.TaskGroup() as g:
|
||||||
t = g.create_task(coro(), name="yolo")
|
t = g.create_task(coro(), name="yolo")
|
||||||
self.assertEqual(t.get_name(), "yolo")
|
self.assertEqual(t.get_name(), "yolo")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue