bpo-38415: Allow using @asynccontextmanager-made ctx managers as decorators (GH-16667)

This commit is contained in:
Jason Fried 2021-09-23 14:36:03 -07:00 committed by GitHub
parent af90b5498b
commit 86b833badd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 87 additions and 0 deletions

View File

@ -191,6 +191,14 @@ class _AsyncGeneratorContextManager(
):
"""Helper for @asynccontextmanager decorator."""
def __call__(self, func):
@wraps(func)
async def inner(*args, **kwds):
async with self.__class__(self.func, self.args, self.kwds):
return await func(*args, **kwds)
return inner
async def __aenter__(self):
# do not keep args and kwds alive unnecessarily
# they are only needed for recreation, which is not possible anymore

View File

@ -318,6 +318,82 @@ async def recursive():
self.assertEqual(ncols, 10)
self.assertEqual(depth, 0)
@_async_test
async def test_decorator(self):
entered = False
@asynccontextmanager
async def context():
nonlocal entered
entered = True
yield
entered = False
@context()
async def test():
self.assertTrue(entered)
self.assertFalse(entered)
await test()
self.assertFalse(entered)
@_async_test
async def test_decorator_with_exception(self):
entered = False
@asynccontextmanager
async def context():
nonlocal entered
try:
entered = True
yield
finally:
entered = False
@context()
async def test():
self.assertTrue(entered)
raise NameError('foo')
self.assertFalse(entered)
with self.assertRaisesRegex(NameError, 'foo'):
await test()
self.assertFalse(entered)
@_async_test
async def test_decorating_method(self):
@asynccontextmanager
async def context():
yield
class Test(object):
@context()
async def method(self, a, b, c=None):
self.a = a
self.b = b
self.c = c
# these tests are for argument passing when used as a decorator
test = Test()
await test.method(1, 2)
self.assertEqual(test.a, 1)
self.assertEqual(test.b, 2)
self.assertEqual(test.c, None)
test = Test()
await test.method('a', 'b', 'c')
self.assertEqual(test.a, 'a')
self.assertEqual(test.b, 'b')
self.assertEqual(test.c, 'c')
test = Test()
await test.method(a=1, b=2)
self.assertEqual(test.a, 1)
self.assertEqual(test.b, 2)
class AclosingTestCase(unittest.TestCase):

View File

@ -0,0 +1,3 @@
Added missing behavior to :func:`contextlib.asynccontextmanager` to match
:func:`contextlib.contextmanager` so decorated functions can themselves be
decorators.