From 88d0a7f9f3ecbf05fff7caf0591530739dfe370e Mon Sep 17 00:00:00 2001 From: Soham Dahivalkar Date: Wed, 10 Jun 2026 17:30:26 +0530 Subject: [PATCH] Expose sampling capabilities on high-level clients --- README.v2.md | 31 +++++ src/mcp/client/client.py | 5 + src/mcp/client/session_group.py | 2 + tests/client/test_client.py | 33 ++++++ tests/client/test_session_group.py | 1 + tests/interaction/_connect.py | 8 ++ tests/interaction/_requirements.py | 9 -- tests/interaction/lowlevel/test_sampling.py | 125 ++++++++++++++++++-- 8 files changed, 198 insertions(+), 16 deletions(-) diff --git a/README.v2.md b/README.v2.md index bae230c3f9..fc97503205 100644 --- a/README.v2.md +++ b/README.v2.md @@ -963,6 +963,37 @@ async def generate_poem(topic: str, ctx: Context) -> str: _Full example: [examples/snippets/servers/sampling.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/sampling.py)_ +Clients that support sampling can also advertise fine-grained capabilities. For example, pass +`SamplingCapability(context=SamplingContextCapability())` when the client is prepared to handle +`include_context="thisServer"` or `include_context="allServers"` requests: + +```python +from mcp import ClientSession, types +from mcp.client.context import ClientRequestContext + + +async def handle_sampling( + context: ClientRequestContext, + params: types.CreateMessageRequestParams, +) -> types.CreateMessageResult: + include_context = params.include_context or "none" + return types.CreateMessageResult( + role="assistant", + content=types.TextContent(text=f"Handled sampling with context policy: {include_context}"), + model="example-client-model", + ) + + +async def run(read, write): + async with ClientSession( + read, + write, + sampling_callback=handle_sampling, + sampling_capabilities=types.SamplingCapability(context=types.SamplingContextCapability()), + ) as session: + await session.initialize() +``` + ### Logging and Notifications Tools can send logs and notifications through the context: diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index b33fea4052..3ab393b447 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -30,6 +30,7 @@ ReadResourceResult, RequestParamsMeta, ResourceTemplateReference, + SamplingCapability, ) @@ -79,6 +80,9 @@ async def main(): sampling_callback: SamplingFnT | None = None """Callback for handling sampling requests.""" + sampling_capabilities: SamplingCapability | None = None + """Fine-grained sampling capabilities advertised when sampling_callback is provided.""" + list_roots_callback: ListRootsFnT | None = None """Callback for handling list roots requests.""" @@ -121,6 +125,7 @@ async def __aenter__(self) -> Client: write_stream=write_stream, read_timeout_seconds=self.read_timeout_seconds, sampling_callback=self.sampling_callback, + sampling_capabilities=self.sampling_capabilities, list_roots_callback=self.list_roots_callback, logging_callback=self.logging_callback, message_handler=self.message_handler, diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 9610212642..b70756c032 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -75,6 +75,7 @@ class ClientSessionParameters: read_timeout_seconds: float | None = None sampling_callback: SamplingFnT | None = None + sampling_capabilities: types.SamplingCapability | None = None elicitation_callback: ElicitationFnT | None = None list_roots_callback: ListRootsFnT | None = None logging_callback: LoggingFnT | None = None @@ -305,6 +306,7 @@ async def _establish_session( write, read_timeout_seconds=session_params.read_timeout_seconds, sampling_callback=session_params.sampling_callback, + sampling_capabilities=session_params.sampling_capabilities, elicitation_callback=session_params.elicitation_callback, list_roots_callback=session_params.list_roots_callback, logging_callback=session_params.logging_callback, diff --git a/tests/client/test_client.py b/tests/client/test_client.py index ac52a9024a..d6183ca515 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -12,6 +12,7 @@ from inline_snapshot import snapshot from mcp import MCPError, types +from mcp.client import ClientRequestContext from mcp.client._memory import InMemoryTransport from mcp.client.client import Client from mcp.server import Server, ServerRequestContext @@ -113,6 +114,38 @@ async def test_client_is_initialized(app: MCPServer): assert client.initialize_result.server_info.name == "test" +async def test_client_forwards_sampling_capabilities(): + """Test that Client forwards fine-grained sampling capabilities to ClientSession.""" + + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="capabilities", input_schema={"type": "object"})]) + + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "capabilities" + assert ctx.session.client_params is not None + sampling = ctx.session.client_params.capabilities.sampling + has_context = sampling is not None and sampling.context is not None + return CallToolResult(content=[TextContent(text=str(has_context).lower())]) + + async def sampling_callback( + context: ClientRequestContext, params: types.CreateMessageRequestParams + ) -> types.CreateMessageResult: + raise NotImplementedError + + server = Server("introspector", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + + async with Client( + server, + sampling_callback=sampling_callback, + sampling_capabilities=types.SamplingCapability(context=types.SamplingContextCapability()), + ) as client: + result = await client.call_tool("capabilities", {}) + + assert result == CallToolResult(content=[TextContent(text="true")]) + + async def test_client_with_simple_server(simple_server: Server): """Test that from_server works with a basic Server instance.""" async with Client(simple_server) as client: diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f39..9d8df7ea49 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -373,6 +373,7 @@ async def test_client_session_group_establish_session_parameterized( mock_write_stream, read_timeout_seconds=None, sampling_callback=None, + sampling_capabilities=None, elicitation_callback=None, list_roots_callback=None, logging_callback=None, diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index 1faf4aa8d6..2337d0da5d 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -38,6 +38,7 @@ JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, + SamplingCapability, jsonrpc_message_adapter, ) from tests.interaction.transports._bridge import StreamingASGITransport @@ -64,6 +65,7 @@ def __call__( *, read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, + sampling_capabilities: SamplingCapability | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, @@ -78,6 +80,7 @@ async def connect_in_memory( *, read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, + sampling_capabilities: SamplingCapability | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, @@ -89,6 +92,7 @@ async def connect_in_memory( server, read_timeout_seconds=read_timeout_seconds, sampling_callback=sampling_callback, + sampling_capabilities=sampling_capabilities, list_roots_callback=list_roots_callback, logging_callback=logging_callback, message_handler=message_handler, @@ -108,6 +112,7 @@ async def connect_over_streamable_http( retry_interval: int | None = None, read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, + sampling_capabilities: SamplingCapability | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, @@ -135,6 +140,7 @@ async def connect_over_streamable_http( streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client), read_timeout_seconds=read_timeout_seconds, sampling_callback=sampling_callback, + sampling_capabilities=sampling_capabilities, list_roots_callback=list_roots_callback, logging_callback=logging_callback, message_handler=message_handler, @@ -321,6 +327,7 @@ async def connect_over_sse( *, read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, + sampling_capabilities: SamplingCapability | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, @@ -351,6 +358,7 @@ def httpx_client_factory( transport, read_timeout_seconds=read_timeout_seconds, sampling_callback=sampling_callback, + sampling_capabilities=sampling_capabilities, list_roots_callback=list_roots_callback, logging_callback=logging_callback, message_handler=message_handler, diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index caed8905d0..fe6f675bc8 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -1145,11 +1145,6 @@ def __post_init__(self) -> None: "A sampling request carrying tools and toolChoice reaches the client, and a tool_use response " "with a toolUse stop reason returns to the requesting handler." ), - deferred=( - "Not implemented in the SDK: Client does not expose ClientSession's sampling_capabilities " - "parameter, so a client can never declare sampling.tools and the server-side validator " - "rejects every tool-enabled request before it is sent." - ), ), "sampling:create-message:audio-content": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#audio-content", @@ -1197,10 +1192,6 @@ def __post_init__(self) -> None: "When the request includes tools, the client accepts a callback result whose content is an " "array including tool_use blocks." ), - deferred=( - "Not implemented in the SDK: requires declaring sampling.tools, which the high-level client " - "cannot do (see sampling:create:tools)." - ), ), "sampling:tool-result:no-mixed-content": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#tool-result-messages", diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py index 260e564192..71025a7102 100644 --- a/tests/interaction/lowlevel/test_sampling.py +++ b/tests/interaction/lowlevel/test_sampling.py @@ -23,7 +23,9 @@ ModelHint, ModelPreferences, SamplingCapability, + SamplingContextCapability, SamplingMessage, + SamplingToolsCapability, TextContent, ToolResultContent, ToolUseContent, @@ -91,9 +93,9 @@ async def sampling_callback( async def test_create_message_params_reach_callback(connect: Connect) -> None: """Every sampling parameter the handler supplies arrives at the client callback unchanged. - The client has not declared the sampling.context capability (Client cannot declare it), yet - include_context="thisServer" reaches the callback regardless: the spec's SHOULD NOT is not - enforced. See the divergence note on `sampling:context:server-gated-by-capability`. + This client intentionally has not declared the sampling.context capability, yet + include_context="thisServer" reaches the callback regardless. See the divergence note on + `sampling:context:server-gated-by-capability`. """ received: list[CreateMessageRequestParams] = [] @@ -312,8 +314,8 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def test_create_message_with_tools_is_rejected_for_unsupporting_client(connect: Connect) -> None: """A tool-enabled sampling request to a client that has not declared sampling.tools never leaves the server. - The client supports plain sampling but cannot declare the tools sub-capability (Client does not - expose it), so the server-side validator rejects the request before anything reaches the wire. + The client supports plain sampling but does not declare the tools sub-capability, so the + server-side validator rejects the request before anything reaches the wire. """ async def list_tools( @@ -349,6 +351,82 @@ async def sampling_callback( ) +@requirement("sampling:create:tools") +@requirement("sampling:result:with-tools-array-content") +async def test_create_message_with_tools_reaches_supporting_client(connect: Connect) -> None: + """A tool-enabled sampling request reaches a client that declared sampling.tools.""" + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="What is the weather?"))], + max_tokens=100, + tools=[ + types.Tool( + name="get_weather", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + ) + ], + tool_choice=types.ToolChoice(mode="required"), + ) + assert isinstance(result, CreateMessageResultWithTools) + content = result.content_as_list + assert isinstance(content[0], ToolUseContent) + return CallToolResult(content=[TextContent(text=f"{result.stop_reason}: {content[0].name}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResultWithTools: + received.append(params) + return CreateMessageResultWithTools( + role="assistant", + content=[ToolUseContent(id="call-1", name="get_weather", input={"city": "London"})], + model="mock-llm-1", + stop_reason="toolUse", + ) + + async with connect( + server, + sampling_callback=sampling_callback, + sampling_capabilities=SamplingCapability(tools=SamplingToolsCapability()), + ) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="toolUse: get_weather")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=TextContent(text="What is the weather?"))], + max_tokens=100, + tools=[ + types.Tool( + name="get_weather", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + ) + ], + tool_choice=types.ToolChoice(mode="required"), + ) + ] + ) + + @requirement("sampling:tool-result:no-mixed-content") async def test_create_message_with_mixed_tool_result_content_is_rejected(connect: Connect) -> None: """A sampling request whose user message mixes tool_result with other content never leaves the server. @@ -406,8 +484,7 @@ async def sampling_callback( async def test_a_client_with_a_sampling_callback_declares_the_sampling_capability(connect: Connect) -> None: """A client connecting with a sampling callback advertises the sampling capability to the server. - Client cannot declare any sub-capabilities (it does not expose ClientSession's - sampling_capabilities parameter), so the snapshot pins an empty SamplingCapability. + Without explicit sampling_capabilities, the snapshot pins an empty SamplingCapability. """ captured: list[SamplingCapability | None] = [] @@ -436,6 +513,40 @@ async def sampling_callback( assert captured == snapshot([SamplingCapability()]) +@requirement("sampling:capability:declare") +async def test_a_client_can_declare_sampling_context_capability(connect: Connect) -> None: + """A client can advertise the sampling.context sub-capability when configured.""" + captured: list[SamplingCapability | None] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="capabilities", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "capabilities" + assert ctx.session.client_params is not None + captured.append(ctx.session.client_params.capabilities.sampling) + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Registered only so the sampling capability is advertised; never called.""" + raise NotImplementedError + + async with connect( + server, + sampling_callback=sampling_callback, + sampling_capabilities=SamplingCapability(context=SamplingContextCapability()), + ) as client: + await client.call_tool("capabilities", {}) + + assert captured == snapshot([SamplingCapability(context=SamplingContextCapability())]) + + @requirement("sampling:create-message:audio-content") async def test_create_message_request_with_audio_content_reaches_callback(connect: Connect) -> None: """A sampling request message carrying audio content arrives at the client callback intact.