From 8f1f6587c9d955d71fc175a159547f2edb5084e1 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Thu, 11 Jun 2026 01:37:17 +0800 Subject: [PATCH 1/2] fix: preserve OAuth endpoint query params --- src/mcp/client/auth/oauth2.py | 27 ++++++++++++++------ tests/client/test_auth.py | 46 +++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 8 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 3c546fda2b..48d074d240 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -9,10 +9,10 @@ import secrets import string import time -from collections.abc import AsyncGenerator, Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping from dataclasses import dataclass, field from typing import Any, Protocol -from urllib.parse import quote, urlencode, urljoin, urlparse +from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlparse, urlunparse import anyio import httpx @@ -53,6 +53,13 @@ logger = logging.getLogger(__name__) +def _append_url_query_params(url: str, params: Mapping[str, str]) -> str: + parsed = urlparse(url) + query_params = parse_qsl(parsed.query, keep_blank_values=True) + query_params.extend(params.items()) + return urlunparse(parsed._replace(query=urlencode(query_params))) + + class PKCEParameters(BaseModel): """PKCE (Proof Key for Code Exchange) parameters.""" @@ -327,14 +334,17 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: if not self.context.client_info: raise OAuthFlowError("No client info available for authorization") # pragma: no cover + client_id = self.context.client_info.client_id + if not client_id: + raise OAuthFlowError("No client ID available for authorization") # pragma: no cover # Generate PKCE parameters pkce_params = PKCEParameters.generate() state = secrets.token_urlsafe(32) - auth_params = { + auth_params: dict[str, str] = { "response_type": "code", - "client_id": self.context.client_info.client_id, + "client_id": client_id, "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), "state": state, "code_challenge": pkce_params.code_challenge, @@ -345,15 +355,16 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: if self.context.should_include_resource_param(self.context.protocol_version): auth_params["resource"] = self.context.get_resource_url() # RFC 8707 - if self.context.client_metadata.scope: # pragma: no branch - auth_params["scope"] = self.context.client_metadata.scope + scope = self.context.client_metadata.scope + if scope: # pragma: no branch + auth_params["scope"] = scope # OIDC requires prompt=consent when offline_access is requested # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess - if "offline_access" in self.context.client_metadata.scope.split(): + if "offline_access" in scope.split(): auth_params["prompt"] = "consent" - authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" + authorization_url = _append_url_query_params(auth_endpoint, auth_params) await self.context.redirect_handler(authorization_url) # Wait for callback diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index bb0bce4c92..5a3ecf371c 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -606,6 +606,52 @@ async def test_token_exchange_request_authorization_code(self, oauth_provider: O assert "client_id=test_client" in content assert "client_secret=test_secret" in content + @pytest.mark.anyio + async def test_authorization_endpoint_preserves_existing_query_params( + self, oauth_provider: OAuthClientProvider + ): + """Authorization endpoint query params should survive OAuth parameter injection.""" + captured_auth_url: str | None = None + captured_state: str | None = None + + async def redirect_handler(url: str) -> None: + nonlocal captured_auth_url, captured_state + captured_auth_url = url + captured_state = parse_qs(urlparse(url).query)["state"][0] + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", captured_state + + oauth_provider.context.redirect_handler = redirect_handler + oauth_provider.context.callback_handler = callback_handler + oauth_provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://test.salesforce.com"), + authorization_endpoint=AnyHttpUrl( + "https://test.salesforce.com/services/oauth2/authorize?prompt=select_account" + ), + token_endpoint=AnyHttpUrl("https://test.salesforce.com/services/oauth2/token"), + ) + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + auth_code, code_verifier = await oauth_provider._perform_authorization_code_grant() + + assert auth_code == "test_auth_code" + assert code_verifier + assert captured_auth_url is not None + parsed = urlparse(captured_auth_url) + params = parse_qs(parsed.query) + assert parsed.scheme == "https" + assert parsed.netloc == "test.salesforce.com" + assert parsed.path == "/services/oauth2/authorize" + assert params["prompt"] == ["select_account"] + assert params["response_type"] == ["code"] + assert params["client_id"] == ["test_client"] + assert params["redirect_uri"] == ["http://localhost:3030/callback"] + @pytest.mark.anyio async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): """Test refresh token request building.""" From a13bb4147ac05b8cf334e5d77dc815e0fdf3a023 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Thu, 11 Jun 2026 03:29:31 +0800 Subject: [PATCH 2/2] style: format OAuth query endpoint test --- tests/client/test_auth.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 5a3ecf371c..07b74b6a9d 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -607,9 +607,7 @@ async def test_token_exchange_request_authorization_code(self, oauth_provider: O assert "client_secret=test_secret" in content @pytest.mark.anyio - async def test_authorization_endpoint_preserves_existing_query_params( - self, oauth_provider: OAuthClientProvider - ): + async def test_authorization_endpoint_preserves_existing_query_params(self, oauth_provider: OAuthClientProvider): """Authorization endpoint query params should survive OAuth parameter injection.""" captured_auth_url: str | None = None captured_state: str | None = None