Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)

logger = logging.getLogger(__name__)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)

Expand Down Expand Up @@ -332,21 +334,29 @@ async def send_notification(
message=JSONRPCMessage(jsonrpc_notification),
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
)
await self._write_stream.send(session_message)
try:
await self._write_stream.send(session_message)
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
logger.debug("Notification dropped - transport closed")
return

async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
await self._write_stream.send(session_message)
else:
jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0",
id=request_id,
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))

try:
await self._write_stream.send(session_message)
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
logger.debug("Response for %s dropped - transport closed", request_id)
return

async def _receive_loop(self) -> None:
async with (
Expand Down
22 changes: 22 additions & 0 deletions tests/server/test_lowlevel_exception_handling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import AsyncMock, Mock

import anyio
import pytest

import mcp.types as types
Expand Down Expand Up @@ -72,3 +73,24 @@ async def test_normal_message_handling_not_affected():

# Verify _handle_request was called
server._handle_request.assert_called_once()


@pytest.mark.anyio
async def test_handle_request_drops_response_when_transport_is_closed():
"""Closed write streams during respond should be treated as expected drops."""
server = Server("test-server")
session = Mock(spec=ServerSession)

responder = Mock(spec=RequestResponder)
responder.request_id = 1
responder.respond = AsyncMock(side_effect=anyio.ClosedResourceError())

await server._handle_request(
responder,
types.PingRequest(method="ping"),
session,
{},
raise_exceptions=False,
)

responder.respond.assert_called_once()
66 changes: 64 additions & 2 deletions tests/server/test_session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, cast

import anyio
import pytest
Expand All @@ -11,7 +11,7 @@
from mcp.server.session import ServerSession
from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
from mcp.shared.session import BaseSession, RequestResponder
from mcp.types import (
ClientNotification,
Completion,
Expand Down Expand Up @@ -220,6 +220,68 @@ async def mock_client():
assert received_protocol_version == "2024-11-05"


class _ClosedWriteStream:
async def send(self, item: SessionMessage) -> None:
raise anyio.ClosedResourceError


class _OpenWriteStream:
def __init__(self):
self.items: list[SessionMessage] = []

async def send(self, item: SessionMessage) -> None:
self.items.append(item)


class _FakeResult:
def __init__(self, payload: dict[str, Any]):
self._payload = payload

def model_dump(self, **kwargs: Any) -> dict[str, Any]:
return dict(self._payload)


class _FakeNotification:
def __init__(self, payload: dict[str, Any]):
self._payload = payload

def model_dump(self, **kwargs: Any) -> dict[str, Any]:
return dict(self._payload)


@pytest.mark.anyio
async def test_base_session_send_response_ignores_closed_write_stream():
session = cast(Any, object.__new__(BaseSession))
session._write_stream = _ClosedWriteStream()

await cast(Any, BaseSession)._send_response(session, 1, _FakeResult({"ok": True}))


@pytest.mark.anyio
async def test_base_session_send_notification_ignores_closed_write_stream():
session = cast(Any, object.__new__(BaseSession))
session._write_stream = _ClosedWriteStream()

await cast(Any, BaseSession).send_notification(
session,
_FakeNotification({"method": "notifications/progress", "params": {"progress": 1}}),
)


@pytest.mark.anyio
async def test_base_session_send_notification_still_writes_when_open():
open_stream = _OpenWriteStream()
session = cast(Any, object.__new__(BaseSession))
session._write_stream = open_stream

await cast(Any, BaseSession).send_notification(
session,
_FakeNotification({"method": "notifications/progress", "params": {"progress": 1}}),
)

assert len(open_stream.items) == 1


@pytest.mark.anyio
async def test_ping_request_before_initialization():
"""Test that ping requests are allowed before initialization is complete."""
Expand Down
Loading