mirror of https://github.com/python/cpython.git
bpo-28369: Enhance transport socket check in add_reader/writer (#4365)
This commit is contained in:
parent
f76231f89a
commit
ce12629c84
|
@ -246,8 +246,16 @@ def _accept_connection2(self, protocol_factory, conn, extra,
|
|||
self.call_exception_handler(context)
|
||||
|
||||
def _ensure_fd_no_transport(self, fd):
|
||||
fileno = fd
|
||||
if not isinstance(fileno, int):
|
||||
try:
|
||||
fileno = int(fileno.fileno())
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
# This code matches selectors._fileobj_to_fd function.
|
||||
raise ValueError("Invalid file object: "
|
||||
"{!r}".format(fd)) from None
|
||||
try:
|
||||
transport = self._transports[fd]
|
||||
transport = self._transports[fileno]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
|
|
|
@ -361,6 +361,13 @@ def assert_writer(self, fd, callback, *args):
|
|||
handle._args, args)
|
||||
|
||||
def _ensure_fd_no_transport(self, fd):
|
||||
if not isinstance(fd, int):
|
||||
try:
|
||||
fd = int(fd.fileno())
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
# This code matches selectors._fileobj_to_fd function.
|
||||
raise ValueError("Invalid file object: "
|
||||
"{!r}".format(fd)) from None
|
||||
try:
|
||||
transport = self._transports[fd]
|
||||
except KeyError:
|
||||
|
|
|
@ -1616,5 +1616,75 @@ def test_child_watcher_replace_mainloop_existing(self):
|
|||
new_loop.close()
|
||||
|
||||
|
||||
class TestFunctional(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
def tearDown(self):
|
||||
self.loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
def test_add_reader_invalid_argument(self):
|
||||
def assert_raises():
|
||||
return self.assertRaisesRegex(ValueError, r'Invalid file object')
|
||||
|
||||
cb = lambda: None
|
||||
|
||||
with assert_raises():
|
||||
self.loop.add_reader(object(), cb)
|
||||
with assert_raises():
|
||||
self.loop.add_writer(object(), cb)
|
||||
|
||||
with assert_raises():
|
||||
self.loop.remove_reader(object())
|
||||
with assert_raises():
|
||||
self.loop.remove_writer(object())
|
||||
|
||||
def test_add_reader_or_writer_transport_fd(self):
|
||||
def assert_raises():
|
||||
return self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r'File descriptor .* is used by transport')
|
||||
|
||||
async def runner():
|
||||
tr, pr = await self.loop.create_connection(
|
||||
lambda: asyncio.Protocol(), sock=rsock)
|
||||
|
||||
try:
|
||||
cb = lambda: None
|
||||
|
||||
with assert_raises():
|
||||
self.loop.add_reader(rsock, cb)
|
||||
with assert_raises():
|
||||
self.loop.add_reader(rsock.fileno(), cb)
|
||||
|
||||
with assert_raises():
|
||||
self.loop.remove_reader(rsock)
|
||||
with assert_raises():
|
||||
self.loop.remove_reader(rsock.fileno())
|
||||
|
||||
with assert_raises():
|
||||
self.loop.add_writer(rsock, cb)
|
||||
with assert_raises():
|
||||
self.loop.add_writer(rsock.fileno(), cb)
|
||||
|
||||
with assert_raises():
|
||||
self.loop.remove_writer(rsock)
|
||||
with assert_raises():
|
||||
self.loop.remove_writer(rsock.fileno())
|
||||
|
||||
finally:
|
||||
tr.close()
|
||||
|
||||
rsock, wsock = socket.socketpair()
|
||||
try:
|
||||
self.loop.run_until_complete(runner())
|
||||
finally:
|
||||
rsock.close()
|
||||
wsock.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
Enhance add_reader/writer check that socket is not used by some transport.
|
||||
Before, only cases when add_reader/writer were called with an int FD were
|
||||
supported. Now the check is implemented correctly for all file-like
|
||||
objects.
|
Loading…
Reference in New Issue