asyncio: Fix from Anthony Baire for CPython issue 19566 (replaces earlier fix).

This commit is contained in:
Guido van Rossum 2013-11-13 15:50:08 -08:00
parent be3c2fea35
commit 2bcae708d8
3 changed files with 60 additions and 42 deletions

View File

@ -440,10 +440,13 @@ def remove_child_handler(self, pid):
raise NotImplementedError() raise NotImplementedError()
def set_loop(self, loop): def attach_loop(self, loop):
"""Reattach the watcher to another event loop. """Attach the watcher to an event loop.
Note: loop may be None If the watcher was previously attached to an event loop, then it is
first detached before attaching to the new loop.
Note: loop may be None.
""" """
raise NotImplementedError() raise NotImplementedError()
@ -467,15 +470,11 @@ def __exit__(self, a, b, c):
class BaseChildWatcher(AbstractChildWatcher): class BaseChildWatcher(AbstractChildWatcher):
def __init__(self, loop): def __init__(self):
self._loop = None self._loop = None
self._callbacks = {}
self.set_loop(loop)
def close(self): def close(self):
self.set_loop(None) self.attach_loop(None)
self._callbacks.clear()
def _do_waitpid(self, expected_pid): def _do_waitpid(self, expected_pid):
raise NotImplementedError() raise NotImplementedError()
@ -483,7 +482,7 @@ def _do_waitpid(self, expected_pid):
def _do_waitpid_all(self): def _do_waitpid_all(self):
raise NotImplementedError() raise NotImplementedError()
def set_loop(self, loop): def attach_loop(self, loop):
assert loop is None or isinstance(loop, events.AbstractEventLoop) assert loop is None or isinstance(loop, events.AbstractEventLoop)
if self._loop is not None: if self._loop is not None:
@ -497,13 +496,6 @@ def set_loop(self, loop):
# during the switch. # during the switch.
self._do_waitpid_all() self._do_waitpid_all()
def remove_child_handler(self, pid):
try:
del self._callbacks[pid]
return True
except KeyError:
return False
def _sig_chld(self): def _sig_chld(self):
try: try:
self._do_waitpid_all() self._do_waitpid_all()
@ -535,6 +527,14 @@ class SafeChildWatcher(BaseChildWatcher):
big number of children (O(n) each time SIGCHLD is raised) big number of children (O(n) each time SIGCHLD is raised)
""" """
def __init__(self):
super().__init__()
self._callbacks = {}
def close(self):
self._callbacks.clear()
super().close()
def __enter__(self): def __enter__(self):
return self return self
@ -547,6 +547,13 @@ def add_child_handler(self, pid, callback, *args):
# Prevent a race condition in case the child is already terminated. # Prevent a race condition in case the child is already terminated.
self._do_waitpid(pid) self._do_waitpid(pid)
def remove_child_handler(self, pid):
try:
del self._callbacks[pid]
return True
except KeyError:
return False
def _do_waitpid_all(self): def _do_waitpid_all(self):
for pid in list(self._callbacks): for pid in list(self._callbacks):
@ -592,17 +599,17 @@ class FastChildWatcher(BaseChildWatcher):
There is no noticeable overhead when handling a big number of children There is no noticeable overhead when handling a big number of children
(O(1) each time a child terminates). (O(1) each time a child terminates).
""" """
def __init__(self, loop): def __init__(self):
super().__init__()
self._callbacks = {}
self._lock = threading.Lock() self._lock = threading.Lock()
self._zombies = {} self._zombies = {}
self._forks = 0 self._forks = 0
# Call base class constructor last because it calls back into
# the subclass (set_loop() calls _do_waitpid()).
super().__init__(loop)
def close(self): def close(self):
super().close() self._callbacks.clear()
self._zombies.clear() self._zombies.clear()
super().close()
def __enter__(self): def __enter__(self):
with self._lock: with self._lock:
@ -643,6 +650,13 @@ def add_child_handler(self, pid, callback, *args):
else: else:
callback(pid, returncode, *args) callback(pid, returncode, *args)
def remove_child_handler(self, pid):
try:
del self._callbacks[pid]
return True
except KeyError:
return False
def _do_waitpid_all(self): def _do_waitpid_all(self):
# Because of signal coalescing, we must keep calling waitpid() as # Because of signal coalescing, we must keep calling waitpid() as
# long as we're able to reap a child. # long as we're able to reap a child.
@ -687,25 +701,24 @@ def __init__(self):
def _init_watcher(self): def _init_watcher(self):
with events._lock: with events._lock:
if self._watcher is None: # pragma: no branch if self._watcher is None: # pragma: no branch
self._watcher = SafeChildWatcher()
if isinstance(threading.current_thread(), if isinstance(threading.current_thread(),
threading._MainThread): threading._MainThread):
self._watcher = SafeChildWatcher(self._local._loop) self._watcher.attach_loop(self._local._loop)
else:
self._watcher = SafeChildWatcher(None)
def set_event_loop(self, loop): def set_event_loop(self, loop):
"""Set the event loop. """Set the event loop.
As a side effect, if a child watcher was set before, then calling As a side effect, if a child watcher was set before, then calling
.set_event_loop() from the main thread will call .set_loop(loop) on the .set_event_loop() from the main thread will call .attach_loop(loop) on
child watcher. the child watcher.
""" """
super().set_event_loop(loop) super().set_event_loop(loop)
if self._watcher is not None and \ if self._watcher is not None and \
isinstance(threading.current_thread(), threading._MainThread): isinstance(threading.current_thread(), threading._MainThread):
self._watcher.set_loop(loop) self._watcher.attach_loop(loop)
def get_child_watcher(self): def get_child_watcher(self):
"""Get the child watcher """Get the child watcher

View File

@ -1311,7 +1311,9 @@ def test_create_datagram_endpoint(self):
class UnixEventLoopTestsMixin(EventLoopTestsMixin): class UnixEventLoopTestsMixin(EventLoopTestsMixin):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
events.set_child_watcher(unix_events.SafeChildWatcher(self.loop)) watcher = unix_events.SafeChildWatcher()
watcher.attach_loop(self.loop)
events.set_child_watcher(watcher)
def tearDown(self): def tearDown(self):
events.set_child_watcher(None) events.set_child_watcher(None)

View File

@ -687,7 +687,7 @@ def test_not_implemented(self):
self.assertRaises( self.assertRaises(
NotImplementedError, watcher.remove_child_handler, f) NotImplementedError, watcher.remove_child_handler, f)
self.assertRaises( self.assertRaises(
NotImplementedError, watcher.set_loop, f) NotImplementedError, watcher.attach_loop, f)
self.assertRaises( self.assertRaises(
NotImplementedError, watcher.close) NotImplementedError, watcher.close)
self.assertRaises( self.assertRaises(
@ -700,7 +700,7 @@ class BaseChildWatcherTests(unittest.TestCase):
def test_not_implemented(self): def test_not_implemented(self):
f = unittest.mock.Mock() f = unittest.mock.Mock()
watcher = unix_events.BaseChildWatcher(None) watcher = unix_events.BaseChildWatcher()
self.assertRaises( self.assertRaises(
NotImplementedError, watcher._do_waitpid, f) NotImplementedError, watcher._do_waitpid, f)
@ -720,11 +720,14 @@ def setUp(self):
with unittest.mock.patch.object( with unittest.mock.patch.object(
self.loop, "add_signal_handler") as self.m_add_signal_handler: self.loop, "add_signal_handler") as self.m_add_signal_handler:
self.watcher = self.create_watcher(self.loop) self.watcher = self.create_watcher()
self.watcher.attach_loop(self.loop)
def tearDown(self): def cleanup():
ChildWatcherTestsMixin.instance = None ChildWatcherTestsMixin.instance = None
self.addCleanup(cleanup)
def waitpid(pid, flags): def waitpid(pid, flags):
self = ChildWatcherTestsMixin.instance self = ChildWatcherTestsMixin.instance
if isinstance(self.watcher, unix_events.SafeChildWatcher) or pid != -1: if isinstance(self.watcher, unix_events.SafeChildWatcher) or pid != -1:
@ -1334,7 +1337,7 @@ def test_set_loop(
self.loop, self.loop,
"add_signal_handler") as m_new_add_signal_handler: "add_signal_handler") as m_new_add_signal_handler:
self.watcher.set_loop(self.loop) self.watcher.attach_loop(self.loop)
m_old_remove_signal_handler.assert_called_once_with( m_old_remove_signal_handler.assert_called_once_with(
signal.SIGCHLD) signal.SIGCHLD)
@ -1375,7 +1378,7 @@ def test_set_loop_race_condition(
with unittest.mock.patch.object( with unittest.mock.patch.object(
old_loop, "remove_signal_handler") as m_remove_signal_handler: old_loop, "remove_signal_handler") as m_remove_signal_handler:
self.watcher.set_loop(None) self.watcher.attach_loop(None)
m_remove_signal_handler.assert_called_once_with( m_remove_signal_handler.assert_called_once_with(
signal.SIGCHLD) signal.SIGCHLD)
@ -1395,7 +1398,7 @@ def test_set_loop_race_condition(
with unittest.mock.patch.object( with unittest.mock.patch.object(
self.loop, "add_signal_handler") as m_add_signal_handler: self.loop, "add_signal_handler") as m_add_signal_handler:
self.watcher.set_loop(self.loop) self.watcher.attach_loop(self.loop)
m_add_signal_handler.assert_called_once_with( m_add_signal_handler.assert_called_once_with(
signal.SIGCHLD, self.watcher._sig_chld) signal.SIGCHLD, self.watcher._sig_chld)
@ -1457,13 +1460,13 @@ def test_close(
class SafeChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): class SafeChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase):
def create_watcher(self, loop): def create_watcher(self):
return unix_events.SafeChildWatcher(loop) return unix_events.SafeChildWatcher()
class FastChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): class FastChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase):
def create_watcher(self, loop): def create_watcher(self):
return unix_events.FastChildWatcher(loop) return unix_events.FastChildWatcher()
class PolicyTests(unittest.TestCase): class PolicyTests(unittest.TestCase):
@ -1485,7 +1488,7 @@ def test_get_child_watcher(self):
def test_get_child_watcher_after_set(self): def test_get_child_watcher_after_set(self):
policy = self.create_policy() policy = self.create_policy()
watcher = unix_events.FastChildWatcher(None) watcher = unix_events.FastChildWatcher()
policy.set_child_watcher(watcher) policy.set_child_watcher(watcher)
self.assertIs(policy._watcher, watcher) self.assertIs(policy._watcher, watcher)