From b243f812981164bc1cccc53e6019b1530c58090b Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 23 Apr 2026 23:42:43 -0700 Subject: [PATCH 1/2] fix(v1.x): drop responses on closed write streams --- src/mcp/shared/session.py | 14 ++++++-- tests/server/test_session.py | 63 +++++++++++++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 3 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 35a83fcf1..48334e1be 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -38,6 +38,8 @@ SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) + +logger = logging.getLogger(__name__) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification) @@ -332,13 +334,16 @@ async def send_notification( message=JSONRPCMessage(jsonrpc_notification), metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, ) - await self._write_stream.send(session_message) + try: + await self._write_stream.send(session_message) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("Notification dropped - transport closed") + return async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) - await self._write_stream.send(session_message) else: jsonrpc_response = JSONRPCResponse( jsonrpc="2.0", @@ -346,7 +351,12 @@ async def _send_response(self, request_id: RequestId, response: SendResultT | Er result=response.model_dump(by_alias=True, mode="json", exclude_none=True), ) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) + + try: await self._write_stream.send(session_message) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("Response for %s dropped - transport closed", request_id) + return async def _receive_loop(self) -> None: async with ( diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 34f9c6e28..90fba0ecb 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -11,7 +11,7 @@ from mcp.server.session import ServerSession from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder +from mcp.shared.session import BaseSession, RequestResponder from mcp.types import ( ClientNotification, Completion, @@ -220,6 +220,67 @@ async def mock_client(): assert received_protocol_version == "2024-11-05" +class _ClosedWriteStream: + async def send(self, item): + raise anyio.ClosedResourceError + + +class _OpenWriteStream: + def __init__(self): + self.items: list[SessionMessage] = [] + + async def send(self, item): + self.items.append(item) + + +class _FakeResult: + def __init__(self, payload: dict[str, Any]): + self._payload = payload + + def model_dump(self, **kwargs): + return dict(self._payload) + + +class _FakeNotification: + def __init__(self, payload: dict[str, Any]): + self._payload = payload + + def model_dump(self, **kwargs): + return dict(self._payload) + + +@pytest.mark.anyio +async def test_base_session_send_response_ignores_closed_write_stream(): + session = object.__new__(BaseSession) + session._write_stream = _ClosedWriteStream() + + await BaseSession._send_response(session, 1, _FakeResult({"ok": True})) + + +@pytest.mark.anyio +async def test_base_session_send_notification_ignores_closed_write_stream(): + session = object.__new__(BaseSession) + session._write_stream = _ClosedWriteStream() + + await BaseSession.send_notification( + session, + _FakeNotification({"method": "notifications/progress", "params": {"progress": 1}}), + ) + + +@pytest.mark.anyio +async def test_base_session_send_notification_still_writes_when_open(): + session = object.__new__(BaseSession) + session._write_stream = _OpenWriteStream() + + await BaseSession.send_notification( + session, + _FakeNotification({"method": "notifications/progress", "params": {"progress": 1}}), + ) + + assert len(session._write_stream.items) == 1 + + @pytest.mark.anyio async def test_ping_request_before_initialization(): """Test that ping requests are allowed before initialization is complete.""" From 8176935c800a7c8ab31ee5a1a7080af2339f56d7 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 24 Apr 2026 00:10:09 -0700 Subject: [PATCH 2/2] test(v1.x): fix closed-stream regression CI --- .../test_lowlevel_exception_handling.py | 22 +++++++++++++++ tests/server/test_session.py | 27 ++++++++++--------- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py index 5d4c3347f..c3081960a 100644 --- a/tests/server/test_lowlevel_exception_handling.py +++ b/tests/server/test_lowlevel_exception_handling.py @@ -1,5 +1,6 @@ from unittest.mock import AsyncMock, Mock +import anyio import pytest import mcp.types as types @@ -72,3 +73,24 @@ async def test_normal_message_handling_not_affected(): # Verify _handle_request was called server._handle_request.assert_called_once() + + +@pytest.mark.anyio +async def test_handle_request_drops_response_when_transport_is_closed(): + """Closed write streams during respond should be treated as expected drops.""" + server = Server("test-server") + session = Mock(spec=ServerSession) + + responder = Mock(spec=RequestResponder) + responder.request_id = 1 + responder.respond = AsyncMock(side_effect=anyio.ClosedResourceError()) + + await server._handle_request( + responder, + types.PingRequest(method="ping"), + session, + {}, + raise_exceptions=False, + ) + + responder.respond.assert_called_once() diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 90fba0ecb..5caadd88c 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, cast import anyio import pytest @@ -221,7 +221,7 @@ async def mock_client(): class _ClosedWriteStream: - async def send(self, item): + async def send(self, item: SessionMessage) -> None: raise anyio.ClosedResourceError @@ -229,7 +229,7 @@ class _OpenWriteStream: def __init__(self): self.items: list[SessionMessage] = [] - async def send(self, item): + async def send(self, item: SessionMessage) -> None: self.items.append(item) @@ -237,7 +237,7 @@ class _FakeResult: def __init__(self, payload: dict[str, Any]): self._payload = payload - def model_dump(self, **kwargs): + def model_dump(self, **kwargs: Any) -> dict[str, Any]: return dict(self._payload) @@ -245,24 +245,24 @@ class _FakeNotification: def __init__(self, payload: dict[str, Any]): self._payload = payload - def model_dump(self, **kwargs): + def model_dump(self, **kwargs: Any) -> dict[str, Any]: return dict(self._payload) @pytest.mark.anyio async def test_base_session_send_response_ignores_closed_write_stream(): - session = object.__new__(BaseSession) + session = cast(Any, object.__new__(BaseSession)) session._write_stream = _ClosedWriteStream() - await BaseSession._send_response(session, 1, _FakeResult({"ok": True})) + await cast(Any, BaseSession)._send_response(session, 1, _FakeResult({"ok": True})) @pytest.mark.anyio async def test_base_session_send_notification_ignores_closed_write_stream(): - session = object.__new__(BaseSession) + session = cast(Any, object.__new__(BaseSession)) session._write_stream = _ClosedWriteStream() - await BaseSession.send_notification( + await cast(Any, BaseSession).send_notification( session, _FakeNotification({"method": "notifications/progress", "params": {"progress": 1}}), ) @@ -270,15 +270,16 @@ async def test_base_session_send_notification_ignores_closed_write_stream(): @pytest.mark.anyio async def test_base_session_send_notification_still_writes_when_open(): - session = object.__new__(BaseSession) - session._write_stream = _OpenWriteStream() + open_stream = _OpenWriteStream() + session = cast(Any, object.__new__(BaseSession)) + session._write_stream = open_stream - await BaseSession.send_notification( + await cast(Any, BaseSession).send_notification( session, _FakeNotification({"method": "notifications/progress", "params": {"progress": 1}}), ) - assert len(session._write_stream.items) == 1 + assert len(open_stream.items) == 1 @pytest.mark.anyio