From 697c715741f8632b674edff38fb42bb125bd3d4e Mon Sep 17 00:00:00 2001 From: Matthieu Angibaud <276304331+mangibaud33@users.noreply.github.com> Date: Wed, 22 Apr 2026 16:50:00 +0200 Subject: [PATCH] feat(auth): persist oauth_metadata via TokenStorage to fix refresh after restart Closes part of #1318. Problem: OAuthClientProvider._initialize() restores tokens and client_info from storage but not oauth_metadata. After a restart with cached tokens, _refresh_token() falls back to /token (via get_authorization_base_url + urljoin), which is incorrect for servers whose token endpoint is at a non-standard path (e.g. HubSpot's MCP Auth App uses /oauth/v3/token). Refresh requests return 404, cascading into a full interactive OAuth flow that cannot complete in non-interactive environments (daemons, containers, long-running services). Fix: add optional get_oauth_metadata / set_oauth_metadata methods to the TokenStorage protocol. _initialize now restores metadata alongside tokens and client_info via a getattr fallback that preserves backward compatibility with storage implementations predating this API; they return None and the refresh path falls back to the legacy behaviour as before. Discovered metadata is persisted after OASM discovery in the 401 flow so subsequent restarts can resolve the correct token endpoint without rediscovery. Coverage: added 5 tests (3 feature + 2 to cover the getattr fallback branches and the Protocol default bodies). Also removed three "# pragma: no cover" markers on lines now exercised by the new tests per AGENTS.md. --- src/mcp/client/auth/oauth2.py | 40 +++++- tests/client/test_auth.py | 260 +++++++++++++++++++++++++++++++++- 2 files changed, 294 insertions(+), 6 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 72309f577..7ce36c6ee 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -87,6 +87,24 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None """Store client information.""" ... + async def get_oauth_metadata(self) -> OAuthMetadata | None: + """Get stored authorization server metadata. + + Optional: implementations may return ``None`` if metadata persistence + is not desired. Implementations that persist tokens across restarts + should also persist metadata so :meth:`OAuthClientProvider._refresh_token` + can resolve the correct token endpoint without rediscovering metadata + on every restart. + """ + return None + + async def set_oauth_metadata(self, metadata: OAuthMetadata) -> None: + """Store authorization server metadata. + + Optional: no-op by default. See :meth:`get_oauth_metadata`. + """ + return + @dataclass class OAuthContext: @@ -473,10 +491,19 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool: # p self.context.clear_tokens() return False - async def _initialize(self) -> None: # pragma: no cover - """Load stored tokens and client info.""" + async def _initialize(self) -> None: + """Load stored tokens, client info, and authorization server metadata.""" self.context.current_tokens = await self.context.storage.get_tokens() self.context.client_info = await self.context.storage.get_client_info() + # Restore authorization server metadata so ``_refresh_token`` can + # resolve the correct token endpoint without rediscovering it on + # every restart. ``getattr`` preserves backward compatibility with + # storage implementations predating ``get_oauth_metadata``: they + # return ``None`` and the refresh path falls back to the legacy + # ``/token`` behaviour as before. + meta_getter = getattr(self.context.storage, "get_oauth_metadata", None) + if meta_getter is not None: + self.context.oauth_metadata = await meta_getter() self._initialized = True def _add_auth_header(self, request: httpx.Request) -> None: @@ -507,7 +534,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. """HTTPX auth flow integration.""" async with self.context.lock: if not self._initialized: - await self._initialize() # pragma: no cover + await self._initialize() # Capture protocol version from request headers self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) @@ -572,6 +599,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. break if ok and asm: self.context.oauth_metadata = asm + # Persist so subsequent restarts can resolve the + # correct token endpoint without rediscovery. + meta_setter = getattr(self.context.storage, "set_oauth_metadata", None) + if meta_setter is not None: + await meta_setter(asm) break else: logger.debug(f"OAuth metadata discovery failed: {url}") @@ -612,7 +644,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Step 5: Perform authorization and complete token exchange token_response = yield await self._perform_authorization() await self._handle_token_response(token_response) - except Exception: # pragma: no cover + except Exception: logger.exception("OAuth flow error") raise diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index bb0bce4c9..f19ff0650 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -43,19 +43,26 @@ class MockTokenStorage: def __init__(self): self._tokens: OAuthToken | None = None self._client_info: OAuthClientInformationFull | None = None + self._oauth_metadata: OAuthMetadata | None = None async def get_tokens(self) -> OAuthToken | None: - return self._tokens # pragma: no cover + return self._tokens async def set_tokens(self, tokens: OAuthToken) -> None: self._tokens = tokens async def get_client_info(self) -> OAuthClientInformationFull | None: - return self._client_info # pragma: no cover + return self._client_info async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: self._client_info = client_info + async def get_oauth_metadata(self) -> OAuthMetadata | None: + return self._oauth_metadata + + async def set_oauth_metadata(self, metadata: OAuthMetadata) -> None: + self._oauth_metadata = metadata + @pytest.fixture def mock_storage(): @@ -2618,3 +2625,252 @@ async def callback_handler() -> tuple[str, str | None]: await auth_flow.asend(final_response) except StopAsyncIteration: pass + + +# --- Regression coverage for #1318: restore oauth_metadata on _initialize --- + + +@pytest.mark.anyio +async def test_initialize_restores_oauth_metadata( + oauth_provider: OAuthClientProvider, + mock_storage: MockTokenStorage, +): + """``_initialize`` should restore ``oauth_metadata`` from storage. + + Without this, ``_refresh_token`` loses the authoritative token endpoint + discovered during the prior session and falls back to ``/token`` + after every restart — a 404 for servers whose token endpoint sits on a + non-standard path. + """ + stored_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/oauth/v3/token"), + ) + await mock_storage.set_oauth_metadata(stored_metadata) + + await oauth_provider._initialize() + + assert oauth_provider.context.oauth_metadata is not None + assert str(oauth_provider.context.oauth_metadata.token_endpoint) == ("https://auth.example.com/oauth/v3/token") + + +@pytest.mark.anyio +async def test_refresh_token_uses_persisted_metadata_endpoint( + oauth_provider: OAuthClientProvider, + mock_storage: MockTokenStorage, + valid_tokens: OAuthToken, +): + """After a restart with persisted metadata, ``_refresh_token`` uses the + correct ``token_endpoint`` rather than the ``/token`` fallback. + """ + custom_token_endpoint = "https://auth.example.com/oauth/v3/token" + await mock_storage.set_tokens(valid_tokens) + await mock_storage.set_client_info( + OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + token_endpoint_auth_method="client_secret_post", + ) + ) + await mock_storage.set_oauth_metadata( + OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl(custom_token_endpoint), + ) + ) + + await oauth_provider._initialize() + request = await oauth_provider._refresh_token() + + assert str(request.url) == custom_token_endpoint + + +@pytest.mark.anyio +async def test_initialize_backward_compat_without_metadata_methods( + client_metadata: OAuthClientMetadata, + valid_tokens: OAuthToken, +): + """Storage implementations predating ``get_oauth_metadata`` keep working. + + Duck-typed ``TokenStorage`` instances written before this method was + introduced must not raise ``AttributeError`` on ``_initialize``. + """ + + class LegacyStorage: + """Duck-typed storage matching the pre-change ``TokenStorage``.""" + + def __init__(self, tokens: OAuthToken | None): + self._tokens = tokens + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + return self._tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self._tokens = tokens # pragma: no cover + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self._client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self._client_info = client_info # pragma: no cover + + legacy_storage = LegacyStorage(valid_tokens) + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=legacy_storage, # type: ignore[arg-type] + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + await provider._initialize() + + assert provider.context.current_tokens is valid_tokens + assert provider.context.oauth_metadata is None + + +@pytest.mark.anyio +async def test_token_storage_protocol_default_metadata_methods(): + """``TokenStorage`` provides no-op defaults for the optional metadata methods. + + Storage subclasses that don't care about metadata persistence can inherit + ``TokenStorage`` without overriding ``get_oauth_metadata`` / + ``set_oauth_metadata``; the default ``get`` returns ``None`` and the + default ``set`` is a no-op (equivalent to opting out of persistence). + """ + from mcp.client.auth.oauth2 import TokenStorage + + class DefaultStorage(TokenStorage): + async def get_tokens(self) -> OAuthToken | None: + return None # pragma: no cover + + async def set_tokens(self, tokens: OAuthToken) -> None: ... # pragma: no cover + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return None # pragma: no cover + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: ... # pragma: no cover + + storage = DefaultStorage() + assert await storage.get_oauth_metadata() is None + + metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + ) + # No-op: set completes without storing + await storage.set_oauth_metadata(metadata) + assert await storage.get_oauth_metadata() is None + + +@pytest.mark.anyio +async def test_auth_flow_discovery_with_legacy_storage_skips_metadata_persistence( + client_metadata: OAuthClientMetadata, +): + """OAuth discovery succeeds when storage lacks ``set_oauth_metadata``. + + Covers the ``getattr`` fallback branch in ``async_auth_flow`` that bypasses + persistence for storage implementations predating the metadata API. + """ + + class LegacyStorage: + """Duck-typed storage matching the pre-change ``TokenStorage``.""" + + def __init__(self) -> None: + self._tokens: OAuthToken | None = None + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + return self._tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self._tokens = tokens # pragma: no cover + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self._client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self._client_info = client_info # pragma: no cover + + legacy_storage = LegacyStorage() + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=legacy_storage, # type: ignore[arg-type] + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + test_request = httpx.Request("GET", "https://api.example.com/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # First yield: ``_initialize`` loads state from LegacyStorage (no tokens, + # no client info, no metadata fallback) and the original request goes out + # without an auth header. + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # 401 → triggers full OAuth flow + unauthorized_response = httpx.Response( + 401, + headers={ + "WWW-Authenticate": ( + 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + ) + }, + request=test_request, + ) + prm_request = await auth_flow.asend(unauthorized_response) + assert "oauth-protected-resource" in str(prm_request.url) + + # PRM discovery response + prm_response = httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}' + ), + request=prm_request, + ) + asm_request = await auth_flow.asend(prm_response) + assert str(asm_request.url).startswith("https://auth.example.com/") + + # OASM discovery response — this is where our set_oauth_metadata + # fallback (meta_setter is None) executes for LegacyStorage. + asm_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"registration_endpoint": "https://auth.example.com/register"}' + ), + request=asm_request, + ) + next_request = await auth_flow.asend(asm_response) + + # Discovery succeeded: flow advanced past metadata handling. + # (Legacy storage had no set_oauth_metadata, so persistence is skipped.) + assert next_request is not None + assert provider.context.oauth_metadata is not None + assert str(provider.context.oauth_metadata.token_endpoint) == "https://auth.example.com/token" + + await auth_flow.aclose()