diff --git a/tests/test_tcp.py b/tests/test_tcp.py index 382b3814..47f18ba9 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -737,6 +737,130 @@ async def test(): with s1, s2: loop.run_until_complete(test()) + def test_create_connection_sock_cancel_detaches(self): + async def client(addr): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(False) + try: + sock.connect(addr) + except BlockingIOError: + pass + await asyncio.sleep(0.01) + + task = asyncio.ensure_future( + self.loop.create_connection(asyncio.Protocol, sock=sock)) + await asyncio.sleep(0) + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + # After cancellation the socket must be detached (fd == -1) + # so that its __del__ won't close a recycled fd. + self.assertEqual(sock.fileno(), -1) + + def _recv_or_abort(sock): + try: + sock.recv_all(1) + except ConnectionAbortedError: + pass + + with self.tcp_server(_recv_or_abort, + max_clients=1, + backlog=1) as srv: + self.loop.run_until_complete(client(srv.addr)) + + def test_create_connection_sock_cancel_fd_leak(self): + # Regression test for https://github.com/MagicStack/uvloop/issues/645 + # and https://github.com/aio-libs/aiohttp/issues/10506 + # + # When create_connection(sock=sock) is cancelled, the socket must + # be detached so its close()/`__del__` won't double-close the fd. + # Without the fix, libuv closes the fd but the socket object still + # references it, enabling a chain of fd corruption and data leak: + # + # 1. cancel → libuv closes fd N + # 2. New connection (victim) reuses fd N + # 3. Stale sock.close() closes fd N → breaks the victim + # 4. Another fd N is opened (new connection) + # 5. Victim writev(N) → data goes to the wrong connection + + async def test(): + srv = await asyncio.start_server( + lambda r, w: w.close(), + '127.0.0.1', 0, + family=socket.AF_INET) + addr = srv.sockets[0].getsockname() + + # --- Step 1: create_connection with sock= and cancel it --- + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(False) + await self.loop.sock_connect(sock, addr) + stale_fd = sock.fileno() + + task = self.loop.create_task( + self.loop.create_connection(asyncio.Protocol, sock=sock) + ) + await asyncio.sleep(0) + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + # --- Step 2: a victim connection reuses the fd --- + victim_tr, _ = await self.loop.create_connection( + asyncio.Protocol, *addr) + victim_fd = victim_tr.get_extra_info('socket').fileno() + if victim_fd != stale_fd: + victim_tr.close() + sock.close() + srv.close() + await srv.wait_closed() + raise unittest.SkipTest( + f'fd not reused (got {victim_fd}, need {stale_fd})') + + # --- Step 3: stale sock.close() must NOT kill the victim --- + # Allocate the socketpair BEFORE sock.close() so the pair + # fds don't collide with stale_fd. + spy_a, spy_b = socket.socketpair() + spy_b.setblocking(False) + + sock.close() + + # Check whether sock.close() broke the victim's fd. + victim_broken = False + try: + os.fstat(victim_fd) + except OSError: + victim_broken = True + + if victim_broken: + # The victim's fd was killed — place a spy socket on + # the freed fd (in production this would be a new + # incoming connection). + os.dup2(spy_a.fileno(), stale_fd) + spy_a.close() + + # Victim writes. If victim_broken, writev(stale_fd) goes + # to the spy; otherwise it goes to the real connection. + victim_tr.write(b'LEAKED') + + try: + leaked = spy_b.recv(4096) + except BlockingIOError: + leaked = b'' + + if victim_broken: + os.close(stale_fd) + spy_b.close() + victim_tr.close() + srv.close() + await srv.wait_closed() + + self.assertEqual(leaked, b'', + f"Data leaked to an unrelated socket: " + f"got {leaked!r}") + + self.loop.run_until_complete(test()) + class Test_UV_TCP(_TestTCP, tb.UVTestCase): diff --git a/tests/test_unix.py b/tests/test_unix.py index 0d670e39..d66dc708 100644 --- a/tests/test_unix.py +++ b/tests/test_unix.py @@ -404,6 +404,117 @@ def test_create_unix_connection_6(self): lambda: None, path='/tmp/a', ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT)) + def test_create_unix_connection_sock_cancel_detaches(self): + async def test(): + srv_path = os.path.join(tempfile.mkdtemp(), 'test.sock') + srv = await asyncio.start_unix_server( + lambda r, w: w.close(), path=srv_path) + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.setblocking(False) + try: + sock.connect(srv_path) + except BlockingIOError: + pass + await asyncio.sleep(0.01) + + task = asyncio.ensure_future( + self.loop.create_unix_connection( + asyncio.Protocol, sock=sock)) + await asyncio.sleep(0) + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertEqual(sock.fileno(), -1) + + srv.close() + await srv.wait_closed() + if os.path.exists(srv_path): + os.unlink(srv_path) + + self.loop.run_until_complete(test()) + + def test_create_unix_connection_sock_cancel_fd_leak(self): + # Same as test_create_connection_sock_cancel_fd_leak but for + # the create_unix_connection(sock=) path. + + async def test(): + srv_path = os.path.join(tempfile.mkdtemp(), 'test.sock') + srv = await asyncio.start_unix_server( + lambda r, w: w.close(), path=srv_path) + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.setblocking(False) + await self.loop.sock_connect(sock, srv_path) + stale_fd = sock.fileno() + + task = self.loop.create_task( + self.loop.create_unix_connection( + asyncio.Protocol, sock=sock)) + await asyncio.sleep(0) + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + # Create victim that reuses the fd. + victim_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + victim_sock.setblocking(False) + await self.loop.sock_connect(victim_sock, srv_path) + victim_tr, _ = await self.loop.create_unix_connection( + asyncio.Protocol, sock=victim_sock) + victim_fd = victim_tr.get_extra_info('socket').fileno() + if victim_fd != stale_fd: + victim_tr.close() + sock.close() + srv.close() + await srv.wait_closed() + if os.path.exists(srv_path): + os.unlink(srv_path) + raise unittest.SkipTest( + f'fd not reused (got {victim_fd}, need {stale_fd})') + + spy_a, spy_b = socket.socketpair() + spy_b.setblocking(False) + + sock.close() + + victim_broken = False + try: + os.fstat(victim_fd) + except OSError: + victim_broken = True + + if victim_broken: + os.dup2(spy_a.fileno(), stale_fd) + spy_a.close() + + victim_tr.write(b'LEAKED') + + try: + leaked = spy_b.recv(4096) + except BlockingIOError: + leaked = b'' + + if victim_broken: + os.close(stale_fd) + spy_b.close() + victim_tr.close() + # Let pending callbacks (e.g. server-side connection_lost + # from the cancelled connection) run before closing the + # server, to avoid triggering call_exception_handler(). + await asyncio.sleep(0) + srv.close() + await srv.wait_closed() + if os.path.exists(srv_path): + os.unlink(srv_path) + + self.assertEqual(leaked, b'', + f"Data leaked to an unrelated socket: " + f"got {leaked!r}") + + self.loop.run_until_complete(test()) + class Test_UV_Unix(_TestUnix, tb.UVTestCase): diff --git a/uvloop/loop.pyx b/uvloop/loop.pyx index 577d45a4..b316bd7f 100644 --- a/uvloop/loop.pyx +++ b/uvloop/loop.pyx @@ -2053,6 +2053,9 @@ cdef class Loop: tr = TCPTransport.new(self, protocol, None, waiter, context) try: # libuv will make socket non-blocking + # We are not detaching the PSO from the now-libuv-managed + # FD here because of: + # https://github.com/python/asyncio/pull/449 tr._open(sock.fileno()) tr._init_protocol() await waiter @@ -2065,6 +2068,15 @@ cdef class Loop: # up in `Transport._call_connection_made()`, and calling # `_close()` before it is fine. tr._close() + # Fix for: + # * https://github.com/MagicStack/uvloop/issues/645 + # * https://github.com/MagicStack/uvloop/issues/738 + # The underlying FD is closed in tr._close(), the owner of + # `sock` must not get a chance to double-close the same FD + # sometime later, because that FD may be reused by a new + # connection under load. So we detach the PSO from the + # already-closed FD here. + sock.detach() raise tr._attach_fileobj(sock) @@ -2306,7 +2318,9 @@ cdef class Loop: except (KeyboardInterrupt, SystemExit): raise except BaseException: + # See comments in create_connection() for more information tr._close() + sock.detach() raise tr._attach_fileobj(sock)