diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 72309f577..7ce36c6ee 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -87,6 +87,24 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None """Store client information.""" ... + async def get_oauth_metadata(self) -> OAuthMetadata | None: + """Get stored authorization server metadata. + + Optional: implementations may return ``None`` if metadata persistence + is not desired. Implementations that persist tokens across restarts + should also persist metadata so :meth:`OAuthClientProvider._refresh_token` + can resolve the correct token endpoint without rediscovering metadata + on every restart. + """ + return None + + async def set_oauth_metadata(self, metadata: OAuthMetadata) -> None: + """Store authorization server metadata. + + Optional: no-op by default. See :meth:`get_oauth_metadata`. + """ + return + @dataclass class OAuthContext: @@ -473,10 +491,19 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool: # p self.context.clear_tokens() return False - async def _initialize(self) -> None: # pragma: no cover - """Load stored tokens and client info.""" + async def _initialize(self) -> None: + """Load stored tokens, client info, and authorization server metadata.""" self.context.current_tokens = await self.context.storage.get_tokens() self.context.client_info = await self.context.storage.get_client_info() + # Restore authorization server metadata so ``_refresh_token`` can + # resolve the correct token endpoint without rediscovering it on + # every restart. ``getattr`` preserves backward compatibility with + # storage implementations predating ``get_oauth_metadata``: they + # return ``None`` and the refresh path falls back to the legacy + # ``/token`` behaviour as before. + meta_getter = getattr(self.context.storage, "get_oauth_metadata", None) + if meta_getter is not None: + self.context.oauth_metadata = await meta_getter() self._initialized = True def _add_auth_header(self, request: httpx.Request) -> None: @@ -507,7 +534,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. """HTTPX auth flow integration.""" async with self.context.lock: if not self._initialized: - await self._initialize() # pragma: no cover + await self._initialize() # Capture protocol version from request headers self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) @@ -572,6 +599,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. break if ok and asm: self.context.oauth_metadata = asm + # Persist so subsequent restarts can resolve the + # correct token endpoint without rediscovery. + meta_setter = getattr(self.context.storage, "set_oauth_metadata", None) + if meta_setter is not None: + await meta_setter(asm) break else: logger.debug(f"OAuth metadata discovery failed: {url}") @@ -612,7 +644,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Step 5: Perform authorization and complete token exchange token_response = yield await self._perform_authorization() await self._handle_token_response(token_response) - except Exception: # pragma: no cover + except Exception: logger.exception("OAuth flow error") raise diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index bb0bce4c9..f19ff0650 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -43,19 +43,26 @@ class MockTokenStorage: def __init__(self): self._tokens: OAuthToken | None = None self._client_info: OAuthClientInformationFull | None = None + self._oauth_metadata: OAuthMetadata | None = None async def get_tokens(self) -> OAuthToken | None: - return self._tokens # pragma: no cover + return self._tokens async def set_tokens(self, tokens: OAuthToken) -> None: self._tokens = tokens async def get_client_info(self) -> OAuthClientInformationFull | None: - return self._client_info # pragma: no cover + return self._client_info async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: self._client_info = client_info + async def get_oauth_metadata(self) -> OAuthMetadata | None: + return self._oauth_metadata + + async def set_oauth_metadata(self, metadata: OAuthMetadata) -> None: + self._oauth_metadata = metadata + @pytest.fixture def mock_storage(): @@ -2618,3 +2625,252 @@ async def callback_handler() -> tuple[str, str | None]: await auth_flow.asend(final_response) except StopAsyncIteration: pass + + +# --- Regression coverage for #1318: restore oauth_metadata on _initialize --- + + +@pytest.mark.anyio +async def test_initialize_restores_oauth_metadata( + oauth_provider: OAuthClientProvider, + mock_storage: MockTokenStorage, +): + """``_initialize`` should restore ``oauth_metadata`` from storage. + + Without this, ``_refresh_token`` loses the authoritative token endpoint + discovered during the prior session and falls back to ``/token`` + after every restart — a 404 for servers whose token endpoint sits on a + non-standard path. + """ + stored_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/oauth/v3/token"), + ) + await mock_storage.set_oauth_metadata(stored_metadata) + + await oauth_provider._initialize() + + assert oauth_provider.context.oauth_metadata is not None + assert str(oauth_provider.context.oauth_metadata.token_endpoint) == ("https://auth.example.com/oauth/v3/token") + + +@pytest.mark.anyio +async def test_refresh_token_uses_persisted_metadata_endpoint( + oauth_provider: OAuthClientProvider, + mock_storage: MockTokenStorage, + valid_tokens: OAuthToken, +): + """After a restart with persisted metadata, ``_refresh_token`` uses the + correct ``token_endpoint`` rather than the ``/token`` fallback. + """ + custom_token_endpoint = "https://auth.example.com/oauth/v3/token" + await mock_storage.set_tokens(valid_tokens) + await mock_storage.set_client_info( + OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + token_endpoint_auth_method="client_secret_post", + ) + ) + await mock_storage.set_oauth_metadata( + OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl(custom_token_endpoint), + ) + ) + + await oauth_provider._initialize() + request = await oauth_provider._refresh_token() + + assert str(request.url) == custom_token_endpoint + + +@pytest.mark.anyio +async def test_initialize_backward_compat_without_metadata_methods( + client_metadata: OAuthClientMetadata, + valid_tokens: OAuthToken, +): + """Storage implementations predating ``get_oauth_metadata`` keep working. + + Duck-typed ``TokenStorage`` instances written before this method was + introduced must not raise ``AttributeError`` on ``_initialize``. + """ + + class LegacyStorage: + """Duck-typed storage matching the pre-change ``TokenStorage``.""" + + def __init__(self, tokens: OAuthToken | None): + self._tokens = tokens + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + return self._tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self._tokens = tokens # pragma: no cover + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self._client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self._client_info = client_info # pragma: no cover + + legacy_storage = LegacyStorage(valid_tokens) + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=legacy_storage, # type: ignore[arg-type] + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + await provider._initialize() + + assert provider.context.current_tokens is valid_tokens + assert provider.context.oauth_metadata is None + + +@pytest.mark.anyio +async def test_token_storage_protocol_default_metadata_methods(): + """``TokenStorage`` provides no-op defaults for the optional metadata methods. + + Storage subclasses that don't care about metadata persistence can inherit + ``TokenStorage`` without overriding ``get_oauth_metadata`` / + ``set_oauth_metadata``; the default ``get`` returns ``None`` and the + default ``set`` is a no-op (equivalent to opting out of persistence). + """ + from mcp.client.auth.oauth2 import TokenStorage + + class DefaultStorage(TokenStorage): + async def get_tokens(self) -> OAuthToken | None: + return None # pragma: no cover + + async def set_tokens(self, tokens: OAuthToken) -> None: ... # pragma: no cover + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return None # pragma: no cover + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: ... # pragma: no cover + + storage = DefaultStorage() + assert await storage.get_oauth_metadata() is None + + metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + ) + # No-op: set completes without storing + await storage.set_oauth_metadata(metadata) + assert await storage.get_oauth_metadata() is None + + +@pytest.mark.anyio +async def test_auth_flow_discovery_with_legacy_storage_skips_metadata_persistence( + client_metadata: OAuthClientMetadata, +): + """OAuth discovery succeeds when storage lacks ``set_oauth_metadata``. + + Covers the ``getattr`` fallback branch in ``async_auth_flow`` that bypasses + persistence for storage implementations predating the metadata API. + """ + + class LegacyStorage: + """Duck-typed storage matching the pre-change ``TokenStorage``.""" + + def __init__(self) -> None: + self._tokens: OAuthToken | None = None + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + return self._tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self._tokens = tokens # pragma: no cover + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self._client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self._client_info = client_info # pragma: no cover + + legacy_storage = LegacyStorage() + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=legacy_storage, # type: ignore[arg-type] + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + test_request = httpx.Request("GET", "https://api.example.com/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # First yield: ``_initialize`` loads state from LegacyStorage (no tokens, + # no client info, no metadata fallback) and the original request goes out + # without an auth header. + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # 401 → triggers full OAuth flow + unauthorized_response = httpx.Response( + 401, + headers={ + "WWW-Authenticate": ( + 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + ) + }, + request=test_request, + ) + prm_request = await auth_flow.asend(unauthorized_response) + assert "oauth-protected-resource" in str(prm_request.url) + + # PRM discovery response + prm_response = httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}' + ), + request=prm_request, + ) + asm_request = await auth_flow.asend(prm_response) + assert str(asm_request.url).startswith("https://auth.example.com/") + + # OASM discovery response — this is where our set_oauth_metadata + # fallback (meta_setter is None) executes for LegacyStorage. + asm_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"registration_endpoint": "https://auth.example.com/register"}' + ), + request=asm_request, + ) + next_request = await auth_flow.asend(asm_response) + + # Discovery succeeded: flow advanced past metadata handling. + # (Legacy storage had no set_oauth_metadata, so persistence is skipped.) + assert next_request is not None + assert provider.context.oauth_metadata is not None + assert str(provider.context.oauth_metadata.token_endpoint) == "https://auth.example.com/token" + + await auth_flow.aclose()