Taskgroup tweaks (GH-31559)

Now uses .cancel()/.uncancel(), for even fewer broken edge cases.
This commit is contained in:
Tin Tvrtković 2022-02-26 17:18:48 +01:00 committed by GitHub
parent 41ddcd3f40
commit edbee56d69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 33 deletions

View File

@ -66,31 +66,28 @@ async def __aexit__(self, et, exc, tb):
self._base_error is None): self._base_error is None):
self._base_error = exc self._base_error = exc
if et is exceptions.CancelledError: if et is not None:
if self._parent_cancel_requested:
# Only if we did request task to cancel ourselves
# we mark it as no longer cancelled.
self._parent_task.uncancel()
else:
propagate_cancellation_error = et
if et is not None and not self._aborting:
# Our parent task is being cancelled:
#
# async with TaskGroup() as g:
# g.create_task(...)
# await ... # <- CancelledError
#
if et is exceptions.CancelledError: if et is exceptions.CancelledError:
propagate_cancellation_error = et if self._parent_cancel_requested and not self._parent_task.uncancel():
# Do nothing, i.e. swallow the error.
pass
else:
propagate_cancellation_error = exc
# or there's an exception in "async with": if not self._aborting:
# # Our parent task is being cancelled:
# async with TaskGroup() as g: #
# g.create_task(...) # async with TaskGroup() as g:
# 1 / 0 # g.create_task(...)
# # await ... # <- CancelledError
self._abort() #
# or there's an exception in "async with":
#
# async with TaskGroup() as g:
# g.create_task(...)
# 1 / 0
#
self._abort()
# We use while-loop here because "self._on_completed_fut" # We use while-loop here because "self._on_completed_fut"
# can be cancelled multiple times if our parent task # can be cancelled multiple times if our parent task
@ -118,7 +115,6 @@ async def __aexit__(self, et, exc, tb):
self._on_completed_fut = None self._on_completed_fut = None
assert self._unfinished_tasks == 0 assert self._unfinished_tasks == 0
self._on_completed_fut = None # no longer needed
if self._base_error is not None: if self._base_error is not None:
raise self._base_error raise self._base_error
@ -199,8 +195,7 @@ def _on_task_done(self, task):
}) })
return return
self._abort() if not self._aborting and not self._parent_cancel_requested:
if not self._parent_task.cancelling():
# If parent task *is not* being cancelled, it means that we want # If parent task *is not* being cancelled, it means that we want
# to manually cancel it to abort whatever is being run right now # to manually cancel it to abort whatever is being run right now
# in the TaskGroup. But we want to mark parent task as # in the TaskGroup. But we want to mark parent task as
@ -219,5 +214,6 @@ def _on_task_done(self, task):
# pass # pass
# await something_else # this line has to be called # await something_else # this line has to be called
# # after TaskGroup is finished. # # after TaskGroup is finished.
self._abort()
self._parent_cancel_requested = True self._parent_cancel_requested = True
self._parent_task.cancel() self._parent_task.cancel()

View File

@ -120,7 +120,11 @@ async def runner():
self.assertTrue(t2_cancel) self.assertTrue(t2_cancel)
self.assertTrue(t2.cancelled()) self.assertTrue(t2.cancelled())
async def test_taskgroup_05(self): async def test_cancel_children_on_child_error(self):
"""
When a child task raises an error, the rest of the children
are cancelled and the errors are gathered into an EG.
"""
NUM = 0 NUM = 0
t2_cancel = False t2_cancel = False
@ -165,7 +169,7 @@ async def runner():
self.assertTrue(t2_cancel) self.assertTrue(t2_cancel)
self.assertTrue(runner_cancel) self.assertTrue(runner_cancel)
async def test_taskgroup_06(self): async def test_cancellation(self):
NUM = 0 NUM = 0
@ -186,10 +190,12 @@ async def runner():
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
self.assertFalse(r.done()) self.assertFalse(r.done())
r.cancel() r.cancel("test")
with self.assertRaises(asyncio.CancelledError): with self.assertRaises(asyncio.CancelledError) as cm:
await r await r
self.assertEqual(cm.exception.args, ('test',))
self.assertEqual(NUM, 5) self.assertEqual(NUM, 5)
async def test_taskgroup_07(self): async def test_taskgroup_07(self):
@ -226,7 +232,7 @@ async def runner():
self.assertEqual(NUM, 15) self.assertEqual(NUM, 15)
async def test_taskgroup_08(self): async def test_cancellation_in_body(self):
async def foo(): async def foo():
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
@ -246,10 +252,12 @@ async def runner():
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
self.assertFalse(r.done()) self.assertFalse(r.done())
r.cancel() r.cancel("test")
with self.assertRaises(asyncio.CancelledError): with self.assertRaises(asyncio.CancelledError) as cm:
await r await r
self.assertEqual(cm.exception.args, ('test',))
async def test_taskgroup_09(self): async def test_taskgroup_09(self):
t1 = t2 = None t1 = t2 = None
@ -699,3 +707,7 @@ async def coro():
async with taskgroups.TaskGroup() as g: async with taskgroups.TaskGroup() as g:
t = g.create_task(coro(), name="yolo") t = g.create_task(coro(), name="yolo")
self.assertEqual(t.get_name(), "yolo") self.assertEqual(t.get_name(), "yolo")
if __name__ == "__main__":
unittest.main()