diff --git a/pyproject.toml b/pyproject.toml index a5d2c3d80..3ff21a251 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ dev = [ "pillow>=12.0", "strict-no-cover", "logfire>=3.0.0", + "opentelemetry-sdk>=1.39.1", ] docs = [ "mkdocs>=1.6.1", diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 4d35f8a90..1c855ae48 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -1,8 +1,10 @@ from __future__ import annotations +from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import Any, Generic +from typing import Any, Generic, Protocol +from pydantic import BaseModel from typing_extensions import TypeVar from mcp.server._typed_request import TypedServerRequestMixin @@ -81,3 +83,35 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, * if meta: params["_meta"] = meta await self.notify("notifications/message", params) + + +HandlerResult = BaseModel | dict[str, Any] | None +"""What a request handler (or middleware) may return. `ServerRunner` serializes +all three to a result dict.""" + +CallNext = Callable[[], Awaitable[HandlerResult]] + +_MwLifespanT = TypeVar("_MwLifespanT", contravariant=True) + + +class ContextMiddleware(Protocol[_MwLifespanT]): + """Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``. + + Runs *inside* `ServerRunner._on_request` after params validation and + `Context` construction. Wraps registered handlers (including ``ping``) but + not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed + outermost-first on `Server.middleware`. + + `Server[L].middleware` holds `ContextMiddleware[L]`, so an app-specific + middleware sees `ctx.lifespan: L`. A reusable middleware (no app-specific + types) can be typed `ContextMiddleware[object]` — `Context` is covariant in + `LifespanT`, so it registers on any `Server[L]`. 
+ """ + + async def __call__( + self, + ctx: Context[_MwLifespanT, TransportContext], + method: str, + params: BaseModel, + call_next: CallNext, + ) -> HandlerResult: ... diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 59de0ace4..9dc44708f 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -58,7 +58,7 @@ async def main(): from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings -from mcp.server.context import ServerRequestContext +from mcp.server.context import ContextMiddleware, ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.models import InitializationOptions @@ -199,6 +199,9 @@ def __init__( ] = {} self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None + # Context-tier middleware consumed by `ServerRunner`. Additive; the + # existing `run()` path ignores it. 
+ self.middleware: list[ContextMiddleware[LifespanResultT]] = [] logger.debug("Initializing server %r", name) # Populate internal handler dicts from on_* kwargs @@ -246,6 +249,16 @@ def _has_handler(self, method: str) -> bool: """Check if a handler is registered for the given method.""" return method in self._request_handlers or method in self._notification_handlers + # --- ServerRegistry protocol (consumed by ServerRunner) ------------------ + + def get_request_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: + """Return the handler for a request method, or ``None``.""" + return self._request_handlers.get(method) + + def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: + """Return the handler for a notification method, or ``None``.""" + return self._notification_handlers.get(method) + # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py new file mode 100644 index 000000000..bb3af0443 --- /dev/null +++ b/src/mcp/server/runner.py @@ -0,0 +1,295 @@ +"""`ServerRunner` — per-connection orchestrator over a `Dispatcher`. + +`ServerRunner` is the bridge between the dispatcher layer (`on_request` / +`on_notify`, untyped dicts) and the user's handler layer (typed `Context`, +typed params). One instance per client connection. 
It: + +* handles the ``initialize`` handshake and populates `Connection` +* gates requests until initialized (``ping`` exempt) +* looks up the handler in the server's registry, validates params, builds + `Context`, runs the middleware chain, returns the result dict +* drives ``dispatcher.run()`` and the per-connection lifespan + +`ServerRunner` consumes any `ServerRegistry` — the lowlevel `Server` satisfies +it via additive methods so the existing ``Server.run()`` path is unaffected. +""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable, Mapping, Sequence +from dataclasses import dataclass, field +from functools import partial, reduce +from typing import Any, Generic, Protocol, cast + +import anyio.abc +from opentelemetry.trace import SpanKind, StatusCode +from pydantic import BaseModel +from typing_extensions import TypeVar + +from mcp.server.connection import Connection +from mcp.server.context import CallNext, Context, ContextMiddleware +from mcp.server.lowlevel.server import NotificationOptions +from mcp.shared._otel import extract_trace_context, otel_span +from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest +from mcp.shared.exceptions import MCPError +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + INVALID_REQUEST, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + CallToolRequestParams, + CompleteRequestParams, + GetPromptRequestParams, + Implementation, + InitializeRequestParams, + InitializeResult, + NotificationParams, + PaginatedRequestParams, + ProgressNotificationParams, + ReadResourceRequestParams, + RequestParams, + ServerCapabilities, + SetLevelRequestParams, + SubscribeRequestParams, + UnsubscribeRequestParams, +) + +__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner", "otel_middleware"] + +logger = logging.getLogger(__name__) + +LifespanT = TypeVar("LifespanT", default=Any) +ServerTransportT = 
TypeVar("ServerTransportT", bound=TransportContext, default=TransportContext) + +Handler = Callable[..., Awaitable[Any]] +"""A request/notification handler: ``(ctx, params) -> result``. Typed loosely +so the existing `ServerRequestContext`-based handlers and the new +`Context`-based handlers both fit during the transition. +""" + + +_INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) + +# TODO: remove this lookup once `Server` stores (params_type, handler) in its +# registry directly. This is scaffolding so ServerRunner can validate params +# without changing the existing `_request_handlers` dict shape. +_PARAMS_FOR_METHOD: dict[str, type[BaseModel]] = { + "ping": RequestParams, + "tools/list": PaginatedRequestParams, + "tools/call": CallToolRequestParams, + "prompts/list": PaginatedRequestParams, + "prompts/get": GetPromptRequestParams, + "resources/list": PaginatedRequestParams, + "resources/templates/list": PaginatedRequestParams, + "resources/read": ReadResourceRequestParams, + "resources/subscribe": SubscribeRequestParams, + "resources/unsubscribe": UnsubscribeRequestParams, + "logging/setLevel": SetLevelRequestParams, + "completion/complete": CompleteRequestParams, +} +"""Spec method → params model. Scaffolding while the lowlevel `Server`'s +`_request_handlers` stores handler-only; the registry refactor should make this +the registry's responsibility (or store params types alongside handlers).""" + +_PARAMS_FOR_NOTIFICATION: dict[str, type[BaseModel]] = { + "notifications/initialized": NotificationParams, + "notifications/roots/list_changed": NotificationParams, + "notifications/progress": ProgressNotificationParams, +} + + +class ServerRegistry(Protocol): + """The handler registry `ServerRunner` consumes. + + The lowlevel `Server` satisfies this via additive methods. + """ + + @property + def name(self) -> str: ... + @property + def version(self) -> str | None: ... + + @property + def middleware(self) -> Sequence[ContextMiddleware[Any]]: ... 
+
+    def get_request_handler(self, method: str) -> Handler | None: ...
+    def get_notification_handler(self, method: str) -> Handler | None: ...
+    def get_capabilities(
+        self, notification_options: Any, experimental_capabilities: dict[str, dict[str, Any]]
+    ) -> ServerCapabilities: ...
+
+
+def otel_middleware(next_on_request: OnRequest) -> OnRequest:
+    """Dispatch-tier middleware that wraps each request in an OpenTelemetry span.
+
+    Mirrors the span shape of the existing `Server._handle_request`: span name
+    ``"MCP handle <method>[ <target>]"``, ``mcp.method.name`` attribute, W3C
+    trace context extracted from ``params._meta`` (SEP-414), and an ERROR
+    status if the handler raises.
+    """
+
+    async def wrapped(
+        dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None
+    ) -> dict[str, Any]:
+        target: str | None
+        match params:
+            case {"name": str() as target}:
+                pass
+            case _:
+                target = None
+        parent: Any | None
+        match params:
+            case {"_meta": {**meta}}:
+                parent = extract_trace_context(meta)
+            case _:
+                parent = None
+        span_name = f"MCP handle {method}{f' {target}' if target else ''}"
+        with otel_span(
+            span_name,
+            kind=SpanKind.SERVER,
+            attributes={"mcp.method.name": method},
+            context=parent,
+            record_exception=False,
+            set_status_on_exception=False,
+        ) as span:
+            try:
+                return await next_on_request(dctx, method, params)
+            except MCPError as e:
+                span.set_status(StatusCode.ERROR, e.error.message)
+                raise
+            except Exception as e:
+                span.record_exception(e)
+                span.set_status(StatusCode.ERROR, str(e))
+                raise
+
+    return wrapped
+
+
+def _dump_result(result: Any) -> dict[str, Any]:
+    if result is None:
+        return {}
+    if isinstance(result, BaseModel):
+        return result.model_dump(by_alias=True, mode="json", exclude_none=True)
+    if isinstance(result, dict):
+        return cast(dict[str, Any], result)
+    raise TypeError(f"handler returned {type(result).__name__}; expected BaseModel, dict, or None")
+
+
+@dataclass
+class ServerRunner(Generic[LifespanT,
ServerTransportT]): + """Per-connection orchestrator. One instance per client connection.""" + + server: ServerRegistry + dispatcher: Dispatcher[ServerTransportT] + lifespan_state: LifespanT + has_standalone_channel: bool + stateless: bool = False + dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware]) + + connection: Connection = field(init=False) + _initialized: bool = field(init=False) + + def __post_init__(self) -> None: + self._initialized = self.stateless + self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel) + + async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: + """Drive the dispatcher until the underlying channel closes. + + Composes `dispatch_middleware` over `_on_request` and hands the result + to `dispatcher.run()`. ``task_status.started()`` is forwarded so callers + can ``await tg.start(runner.run)`` and resume once the dispatcher is + ready to accept requests. + """ + await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) + + def _compose_on_request(self) -> OnRequest: + """Wrap `_on_request` in `dispatch_middleware`, outermost-first. + + Dispatch-tier middleware sees raw ``(dctx, method, params) -> dict`` + and wraps everything — initialize, METHOD_NOT_FOUND, validation + failures included. `run()` calls this once and hands the result to + `dispatcher.run()`. 
+ """ + return reduce(lambda h, mw: mw(h), reversed(self.dispatch_middleware), self._on_request) + + async def _on_request( + self, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, + ) -> dict[str, Any]: + if method == "initialize": + return self._handle_initialize(params) + if not self._initialized and method not in _INIT_EXEMPT: + raise MCPError( + code=INVALID_REQUEST, + message=f"Received {method!r} before initialization was complete", + ) + handler = self.server.get_request_handler(method) + if handler is None: + raise MCPError(code=METHOD_NOT_FOUND, message=f"Method not found: {method}") + # TODO: scaffolding — params_type comes from a static lookup until the + # registry stores it alongside the handler. + params_type = _PARAMS_FOR_METHOD.get(method, RequestParams) + # ValidationError propagates; the dispatcher's exception boundary maps + # it to INVALID_PARAMS. + typed_params = params_type.model_validate(params or {}) + ctx = self._make_context(dctx, typed_params) + call: CallNext = partial(handler, ctx, typed_params) + for mw in reversed(self.server.middleware): + call = partial(mw, ctx, method, typed_params, call) + return _dump_result(await call()) + + async def _on_notify( + self, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, + ) -> None: + if method == "notifications/initialized": + self._initialized = True + self.connection.initialized.set() + return + if not self._initialized: + logger.debug("dropped %s: received before initialization", method) + return + handler = self.server.get_notification_handler(method) + if handler is None: + logger.debug("no handler for notification %s", method) + return + params_type = _PARAMS_FOR_NOTIFICATION.get(method, NotificationParams) + typed_params = params_type.model_validate(params or {}) + ctx = self._make_context(dctx, typed_params) + await handler(ctx, typed_params) + + def _make_context( + self, dctx: 
DispatchContext[TransportContext], typed_params: BaseModel + ) -> Context[LifespanT, ServerTransportT]: + # `OnRequest` delivers `DispatchContext[TransportContext]`; this + # ServerRunner instance was constructed for a specific + # `ServerTransportT`, so the narrow is safe by construction. + narrowed = cast(DispatchContext[ServerTransportT], dctx) + meta = getattr(typed_params, "meta", None) + return Context(narrowed, lifespan=self.lifespan_state, connection=self.connection, meta=meta) + + def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]: + init = InitializeRequestParams.model_validate(params or {}) + self.connection.client_info = init.client_info + self.connection.client_capabilities = init.capabilities + # TODO: real version negotiation. This always responds with LATEST, + # which is wrong — the server should pick the highest version both + # sides support and compute a per-connection feature set from it. + # See FOLLOWUPS: "Consolidate per-connection mode/negotiation". 
+ self.connection.protocol_version = ( + init.protocol_version if init.protocol_version in {LATEST_PROTOCOL_VERSION} else LATEST_PROTOCOL_VERSION + ) + self._initialized = True + self.connection.initialized.set() + result = InitializeResult( + protocol_version=self.connection.protocol_version, + capabilities=self.server.get_capabilities(NotificationOptions(), {}), + server_info=Implementation(name=self.server.name, version=self.server.version or "0.0.0"), + ) + return _dump_result(result) diff --git a/src/mcp/shared/_otel.py b/src/mcp/shared/_otel.py index 170e873a0..553b8a0bc 100644 --- a/src/mcp/shared/_otel.py +++ b/src/mcp/shared/_otel.py @@ -20,9 +20,18 @@ def otel_span( kind: SpanKind, attributes: dict[str, Any] | None = None, context: Context | None = None, + record_exception: bool = True, + set_status_on_exception: bool = True, ) -> Iterator[Any]: """Create an OTel span.""" - with _tracer.start_as_current_span(name, kind=kind, attributes=attributes, context=context) as span: + with _tracer.start_as_current_span( + name, + kind=kind, + attributes=attributes, + context=context, + record_exception=record_exception, + set_status_on_exception=set_status_on_exception, + ) as span: yield span diff --git a/tests/conftest.py b/tests/conftest.py index af7e47993..b83c47213 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,16 @@ +import os + import pytest +# OpenTelemetry's `set_tracer_provider` is set-once per process, so the suite +# uses a single span-capture mechanism: logfire's `capfire` fixture (its +# `configure()` swaps span processors on repeat calls rather than re-setting +# the provider). Logfire's default `distributed_tracing=None` emits a +# RuntimeWarning + diagnostic span when incoming W3C trace context is +# extracted; several tests exercise that propagation deliberately, so opt in +# suite-wide. Set before logfire is imported anywhere. 
+os.environ.setdefault("LOGFIRE_DISTRIBUTED_TRACING", "true") + @pytest.fixture def anyio_backend(): diff --git a/tests/server/conftest.py b/tests/server/conftest.py new file mode 100644 index 000000000..290ccc957 --- /dev/null +++ b/tests/server/conftest.py @@ -0,0 +1,45 @@ +"""Shared fixtures for server-side tests.""" + +from collections.abc import Iterator + +import pytest +from logfire.testing import CaptureLogfire, TestExporter +from opentelemetry.sdk.trace import ReadableSpan + + +class SpanCapture: + """Thin adapter over logfire's `TestExporter` for asserting on MCP spans. + + `finished()` returns the raw `ReadableSpan` objects emitted by the + ``mcp-python-sdk`` instrumentation scope, filtered to exclude logfire's + synthetic ``pending_span`` markers, so tests can assert directly on + `.name`, `.kind`, `.status`, `.attributes`, `.parent`, `.events`. + """ + + def __init__(self, exporter: TestExporter) -> None: + self._exporter = exporter + + def clear(self) -> None: + self._exporter.clear() + + def finished(self) -> list[ReadableSpan]: + return [ + s + for s in self._exporter.exported_spans + if s.instrumentation_scope is not None + and s.instrumentation_scope.name == "mcp-python-sdk" + and not (s.attributes and s.attributes.get("logfire.span_type") == "pending_span") + ] + + +@pytest.fixture +def spans(capfire: CaptureLogfire) -> Iterator[SpanCapture]: + """In-memory MCP span capture, cleared before and after each test. + + Backed by the project-level `capfire` override (see ``tests/conftest.py``) + so there is a single global tracer provider for the suite. + """ + capture = SpanCapture(capfire.exporter) + capture.clear() + yield capture + capture.clear() diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py new file mode 100644 index 000000000..843b0ae8b --- /dev/null +++ b/tests/server/test_runner.py @@ -0,0 +1,340 @@ +"""Tests for `ServerRunner`. + +End-to-end over `DirectDispatcher` with a real lowlevel `Server` as the +registry. 
The `connected_runner` helper starts both sides and (by default) +performs the initialize handshake, so each test exercises only the behaviour +under test. +""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import anyio +import anyio.lowlevel +import pytest +from opentelemetry.trace import SpanKind, StatusCode + +from mcp.server.connection import Connection +from mcp.server.context import Context +from mcp.server.lowlevel.server import Server +from mcp.server.runner import ServerRunner, otel_middleware +from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair +from mcp.shared.dispatcher import DispatchMiddleware +from mcp.shared.exceptions import MCPError +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + INTERNAL_ERROR, + INVALID_REQUEST, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + ClientCapabilities, + Implementation, + InitializeRequestParams, + Tool, +) + +from ..shared.test_dispatcher import Recorder, echo_handlers +from .conftest import SpanCapture + + +def _initialize_params() -> dict[str, Any]: + return InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test-client", version="1.0"), + ).model_dump(by_alias=True, exclude_none=True) + + +_seen_ctx: list[Context[Any, TransportContext]] = [] +SrvT = Server[dict[str, Any]] + + +@pytest.fixture +def server() -> SrvT: + """A lowlevel Server with one tools/list handler registered.""" + _seen_ctx.clear() + + async def list_tools(ctx: Any, params: Any) -> Any: + # ctx is typed `Any` because Server's on_list_tools kwarg expects the + # legacy ServerRequestContext shape; ServerRunner passes the new + # `Context`. The transition is intentional — Handler is loosely typed. 
+ _seen_ctx.append(ctx) + return {"tools": [Tool(name="t", input_schema={"type": "object"}).model_dump(by_alias=True)]} + + return Server(name="test-server", version="0.0.1", on_list_tools=list_tools) + + +@asynccontextmanager +async def connected_runner( + server: SrvT, + *, + initialized: bool = True, + stateless: bool = False, + has_standalone_channel: bool = True, + dispatch_middleware: list[DispatchMiddleware] | None = None, +) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[None, TransportContext]]]: + """Yield ``(client, runner)`` running over an in-memory dispatcher pair. + + Starts the client (echo handlers) and `runner.run()` in a task group, wraps + the body in ``anyio.fail_after(5)``, and cancels on exit. When + ``initialized`` is true the helper performs the real ``initialize`` request + before yielding, so tests start past the init-gate via the public path. + """ + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=has_standalone_channel, + stateless=stateless, + dispatch_middleware=dispatch_middleware or [], + ) + c_req, c_notify = echo_handlers(Recorder()) + body_exc: BaseException | None = None + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(runner.run) + try: + with anyio.fail_after(5): + if initialized: + await client.send_raw_request("initialize", _initialize_params()) + yield client, runner + except BaseException as e: + # Capture and re-raise outside the task group so test failures + # surface as the original exception, not an ExceptionGroup wrapper. 
+ body_exc = e + client.close() + server_d.close() + if body_exc is not None: + raise body_exc + + +@pytest.mark.anyio +async def test_connected_runner_propagates_body_exception_unwrapped(server: SrvT): + """The harness re-raises body exceptions as-is, not as ``ExceptionGroup``.""" + with pytest.raises(RuntimeError, match="boom"): + async with connected_runner(server): + raise RuntimeError("boom") + + +@pytest.mark.anyio +async def test_runner_handles_initialize_and_populates_connection(server: SrvT): + async with connected_runner(server, initialized=False) as (client, runner): + result = await client.send_raw_request("initialize", _initialize_params()) + assert result["serverInfo"]["name"] == "test-server" + assert "tools" in result["capabilities"] + assert runner.connection.client_info is not None + assert runner.connection.client_info.name == "test-client" + assert runner.connection.protocol_version == LATEST_PROTOCOL_VERSION + assert runner._initialized is True + + +@pytest.mark.anyio +async def test_runner_gates_requests_before_initialize(server: SrvT): + async with connected_runner(server, initialized=False) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INVALID_REQUEST + # ping is exempt from the gate + assert await client.send_raw_request("ping", None) == {} + + +@pytest.mark.anyio +async def test_runner_routes_to_handler_and_builds_context(server: SrvT): + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + ctx = _seen_ctx[0] + assert isinstance(ctx, Context) + assert ctx.lifespan is None + assert isinstance(ctx.connection, Connection) + assert ctx.transport.kind == "direct" + + +@pytest.mark.anyio +async def test_runner_unknown_method_raises_method_not_found(server: SrvT): + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as 
exc: + await client.send_raw_request("nonexistent/method", None) + assert exc.value.error.code == METHOD_NOT_FOUND + + +@pytest.mark.anyio +async def test_runner_on_notify_initialized_sets_flag_and_connection_event(server: SrvT): + async with connected_runner(server, initialized=False) as (client, runner): + await client.notify("notifications/initialized", None) + await runner.connection.initialized.wait() + assert runner._initialized is True + + +@pytest.mark.anyio +async def test_runner_on_notify_routes_to_registered_handler(server: SrvT): + seen: list[tuple[Any, Any]] = [] + + async def on_roots_changed(ctx: Any, params: Any) -> None: + seen.append((ctx, params)) + + server._notification_handlers["notifications/roots/list_changed"] = on_roots_changed + async with connected_runner(server) as (client, _): + await client.notify("notifications/roots/list_changed", None) + # DirectDispatcher delivers synchronously; one yield is enough. + await anyio.lowlevel.checkpoint() + assert len(seen) == 1 + assert isinstance(seen[0][0], Context) + + +@pytest.mark.anyio +async def test_runner_on_notify_drops_before_init_and_unknown_methods(server: SrvT): + async with connected_runner(server, initialized=False) as (client, _): + await client.notify("notifications/roots/list_changed", None) # before init: dropped + await client.notify("notifications/initialized", None) + await client.notify("notifications/unknown", None) # no handler: dropped + # No exception raised; both drops are silent. 
+ + +@pytest.mark.anyio +async def test_runner_dispatch_middleware_wraps_everything_including_initialize(server: SrvT): + seen_methods: list[str] = [] + + def trace_mw(next_on_request: Any) -> Any: + async def wrapped(dctx: Any, method: str, params: Any) -> Any: + seen_methods.append(method) + return await next_on_request(dctx, method, params) + + return wrapped + + async with connected_runner(server, dispatch_middleware=[trace_mw]) as (client, _): + await client.send_raw_request("tools/list", None) + assert seen_methods == ["initialize", "tools/list"] + + +@pytest.mark.anyio +async def test_runner_server_middleware_wraps_handlers_but_not_initialize(server: SrvT): + seen_methods: list[str] = [] + + async def ctx_mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: + seen_methods.append(method) + return await call_next() + + server.middleware.append(ctx_mw) + async with connected_runner(server) as (client, _): + await client.send_raw_request("ping", None) + await client.send_raw_request("tools/list", None) + # initialize (sent by the helper) NOT wrapped; ping and tools/list ARE. 
+ assert seen_methods == ["ping", "tools/list"] + + +@pytest.mark.anyio +async def test_runner_server_middleware_runs_outermost_first(server: SrvT): + order: list[str] = [] + + def make_mw(tag: str) -> Any: + async def mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: + order.append(f"{tag}-in") + result = await call_next() + order.append(f"{tag}-out") + return result + + return mw + + server.middleware.extend([make_mw("a"), make_mw("b")]) + async with connected_runner(server) as (client, _): + await client.send_raw_request("tools/list", None) + assert order == ["a-in", "b-in", "b-out", "a-out"] + + +@pytest.mark.anyio +async def test_runner_handler_returning_none_yields_empty_result(server: SrvT): + async def set_level(ctx: Any, params: Any) -> None: + return None + + server._request_handlers["logging/setLevel"] = set_level + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("logging/setLevel", {"level": "info"}) + assert result == {} + + +@pytest.mark.anyio +async def test_runner_handler_returning_unsupported_type_surfaces_as_internal_error(server: SrvT): + async def bad_return(ctx: Any, params: Any) -> int: + return 42 + + server._request_handlers["tools/list"] = bad_return + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INTERNAL_ERROR + assert "int" in exc.value.error.message + + +@pytest.mark.anyio +async def test_runner_stateless_skips_init_gate(server: SrvT): + async with connected_runner(server, initialized=False, stateless=True, has_standalone_channel=False) as (client, _): + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + + +@pytest.mark.anyio +async def test_otel_middleware_emits_server_span_with_method_and_target(server: SrvT, spans: SpanCapture): + async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: + 
return {"content": [], "isError": False} + + server._request_handlers["tools/call"] = call_tool + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + result = await client.send_raw_request("tools/call", {"name": "mytool", "arguments": {}}) + assert result == {"content": [], "isError": False} + [span] = spans.finished() + assert span.name == "MCP handle tools/call mytool" + assert span.kind == SpanKind.SERVER + assert span.attributes is not None + assert span.attributes["mcp.method.name"] == "tools/call" + assert span.status.status_code == StatusCode.UNSET + + +@pytest.mark.anyio +async def test_otel_middleware_extracts_parent_context_from_meta(server: SrvT, spans: SpanCapture): + parent_span_id = "b7ad6b7169203331" + traceparent = f"00-0af7651916cd43dd8448eb211c80319c-{parent_span_id}-01" + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + await client.send_raw_request("tools/list", {"_meta": {"traceparent": traceparent}}) + [span] = spans.finished() + assert span.parent is not None + assert format(span.parent.span_id, "016x") == parent_span_id + assert span.context is not None + assert format(span.context.trace_id, "032x") == "0af7651916cd43dd8448eb211c80319c" + + +@pytest.mark.anyio +async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, spans: SpanCapture): + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + with pytest.raises(MCPError) as exc: + await client.send_raw_request("nonexistent/method", None) + assert exc.value.error.code == METHOD_NOT_FOUND + [span] = spans.finished() + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "Method not found: nonexistent/method" + # MCPError is a protocol-level response, not a crash — no traceback event. 
+ assert not [e for e in span.events if e.name == "exception"] + + +@pytest.mark.anyio +async def test_otel_middleware_records_error_status_on_handler_exception(server: SrvT, spans: SpanCapture): + async def failing(ctx: Any, params: Any) -> Any: + raise ValueError("handler blew up") + + server._request_handlers["tools/list"] = failing + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INTERNAL_ERROR + [span] = spans.finished() + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "handler blew up" + [event] = [e for e in span.events if e.name == "exception"] + assert event.attributes is not None + assert event.attributes["exception.type"] == "ValueError" diff --git a/tests/shared/test_otel.py b/tests/shared/test_otel.py index ec7ff78cc..a7df4c429 100644 --- a/tests/shared/test_otel.py +++ b/tests/shared/test_otel.py @@ -10,9 +10,6 @@ pytestmark = pytest.mark.anyio -# Logfire warns about propagated trace context by default (distributed_tracing=None). -# This is expected here since we're testing cross-boundary context propagation. -@pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test_client_and_server_spans(capfire: CaptureLogfire): """Verify that calling a tool produces client and server spans with correct attributes.""" server = MCPServer("test") diff --git a/uv.lock b/uv.lock index 705d014aa..71bac4cce 100644 --- a/uv.lock +++ b/uv.lock @@ -872,6 +872,7 @@ dev = [ { name = "inline-snapshot" }, { name = "logfire" }, { name = "mcp", extra = ["cli", "ws"] }, + { name = "opentelemetry-sdk" }, { name = "pillow" }, { name = "pyright" }, { name = "pytest" }, @@ -924,6 +925,7 @@ dev = [ { name = "inline-snapshot", specifier = ">=0.23.0" }, { name = "logfire", specifier = ">=3.0.0" }, { name = "mcp", extras = ["cli", "ws"], editable = "." 
}, + { name = "opentelemetry-sdk", specifier = ">=1.39.1" }, { name = "pillow", specifier = ">=12.0" }, { name = "pyright", specifier = ">=1.1.400" }, { name = "pytest", specifier = ">=8.3.4" },