diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 35a83fcf1..48334e1be 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -38,6 +38,8 @@ SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) + +logger = logging.getLogger(__name__) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification) @@ -332,13 +334,16 @@ async def send_notification( message=JSONRPCMessage(jsonrpc_notification), metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, ) - await self._write_stream.send(session_message) + try: + await self._write_stream.send(session_message) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("Notification dropped - transport closed") + return async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) - await self._write_stream.send(session_message) else: jsonrpc_response = JSONRPCResponse( jsonrpc="2.0", @@ -346,7 +351,12 @@ async def _send_response(self, request_id: RequestId, response: SendResultT | Er result=response.model_dump(by_alias=True, mode="json", exclude_none=True), ) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) + + try: await self._write_stream.send(session_message) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("Response for %s dropped - transport closed", request_id) + return async def _receive_loop(self) -> None: async with ( diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py index 5d4c3347f..c3081960a 100644 --- a/tests/server/test_lowlevel_exception_handling.py +++ b/tests/server/test_lowlevel_exception_handling.py @@ -1,5 +1,6 @@ from unittest.mock import AsyncMock, Mock +import anyio import pytest import mcp.types as types @@ -72,3 +73,24 @@ async def test_normal_message_handling_not_affected(): # Verify _handle_request was called server._handle_request.assert_called_once() + + +@pytest.mark.anyio +async def test_handle_request_drops_response_when_transport_is_closed(): + """Closed write streams during respond should be treated as expected drops.""" + server = Server("test-server") + session = Mock(spec=ServerSession) + + responder = Mock(spec=RequestResponder) + responder.request_id = 1 + responder.respond = AsyncMock(side_effect=anyio.ClosedResourceError()) + + await server._handle_request( + responder, + types.PingRequest(method="ping"), + session, + {}, + raise_exceptions=False, + ) + + responder.respond.assert_called_once() diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 34f9c6e28..5caadd88c 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, cast import anyio import pytest @@ -11,7 +11,7 @@ from mcp.server.session import ServerSession from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder +from mcp.shared.session import BaseSession, RequestResponder from mcp.types import ( ClientNotification, Completion, @@ -220,6 +220,68 @@ async def mock_client(): assert received_protocol_version == "2024-11-05" +class _ClosedWriteStream: + async def send(self, item: SessionMessage) -> None: + raise anyio.ClosedResourceError + + +class _OpenWriteStream: + def __init__(self): + self.items: list[SessionMessage] = [] + + async def send(self, item: SessionMessage) -> None: + self.items.append(item) + + +class _FakeResult: + def __init__(self, payload: dict[str, Any]): + self._payload = payload + + def model_dump(self, **kwargs: Any) -> dict[str, Any]: + return dict(self._payload) + + +class _FakeNotification: + def __init__(self, payload: dict[str, Any]): + self._payload = payload + + def model_dump(self, **kwargs: Any) -> dict[str, Any]: + return dict(self._payload) + + +@pytest.mark.anyio +async def test_base_session_send_response_ignores_closed_write_stream(): + session = cast(Any, object.__new__(BaseSession)) + session._write_stream = _ClosedWriteStream() + + await cast(Any, BaseSession)._send_response(session, 1, _FakeResult({"ok": True})) + + +@pytest.mark.anyio +async def test_base_session_send_notification_ignores_closed_write_stream(): + session = cast(Any, object.__new__(BaseSession)) + session._write_stream = _ClosedWriteStream() + + await cast(Any, BaseSession).send_notification( + session, + _FakeNotification({"method": "notifications/progress", "params": {"progress": 1}}), + ) + + +@pytest.mark.anyio +async def test_base_session_send_notification_still_writes_when_open(): + open_stream = _OpenWriteStream() + session = cast(Any, object.__new__(BaseSession)) + session._write_stream = open_stream + + await cast(Any, BaseSession).send_notification( + session, + _FakeNotification({"method": "notifications/progress", "params": {"progress": 1}}), + ) + + assert len(open_stream.items) == 1 + + @pytest.mark.anyio async def test_ping_request_before_initialization(): """Test that ping requests are allowed before initialization is complete."""