diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 1ed9842c0..a21bd6f8c 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -1,6 +1,7 @@ """DNS rebinding protection for MCP server transports.""" import logging +from urllib.parse import urlparse from pydantic import BaseModel, Field from starlette.requests import Request @@ -22,12 +23,24 @@ class TransportSecuritySettings(BaseModel): allowed_hosts: list[str] = Field(default_factory=list) """List of allowed Host header values. + Supports exact matches, port wildcards, and subdomain wildcards: + + - ``"example.com"`` — exact match + - ``"example.com:*"`` — any port on that host + - ``"*.example.com"`` — any subdomain (or the base domain itself) + Only applies when `enable_dns_rebinding_protection` is `True`. """ allowed_origins: list[str] = Field(default_factory=list) """List of allowed Origin header values. + Supports exact matches, port wildcards, and subdomain wildcards: + + - ``"https://example.com"`` — exact match + - ``"https://example.com:*"`` — any port on that origin + - ``"https://*.example.com"`` — any subdomain (or the base domain itself) with HTTPS + Only applies when `enable_dns_rebinding_protection` is `True`. """ @@ -40,46 +53,61 @@ def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) - def _validate_host(self, host: str | None) -> bool: # pragma: no cover + def _validate_host(self, host: str | None) -> bool: """Validate the Host header against allowed values.""" if not host: logger.warning("Missing Host header in request") return False - # Check exact match first if host in self.settings.allowed_hosts: return True - # Check wildcard port patterns + # Strip port for subdomain wildcard matching + host_without_port = host.split(":")[0] + for allowed in self.settings.allowed_hosts: if allowed.endswith(":*"): - # Extract base host from pattern + # Port wildcard: e.g., "example.com:*" matches "example.com:8080" base_host = allowed[:-2] - # Check if the actual host starts with base host and has a port if host.startswith(base_host + ":"): return True + elif allowed.startswith("*."): + # Subdomain wildcard: e.g., "*.example.com" matches "example.com" + # and "sub.example.com" (port is ignored) + suffix = allowed[2:] + if host_without_port == suffix or host_without_port.endswith("." + suffix): + return True logger.warning(f"Invalid Host header: {host}") return False - def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover + def _validate_origin(self, origin: str | None) -> bool: """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests if not origin: return True - # Check exact match first if origin in self.settings.allowed_origins: return True - # Check wildcard port patterns for allowed in self.settings.allowed_origins: if allowed.endswith(":*"): - # Extract base origin from pattern + # Port wildcard: e.g., "https://example.com:*" matches "https://example.com:8080" base_origin = allowed[:-2] - # Check if the actual origin starts with base origin and has a port if origin.startswith(base_origin + ":"): return True + elif "://*." in allowed: + # Subdomain wildcard: e.g., "https://*.example.com" matches + # "https://example.com" and "https://sub.example.com" + parsed_allowed = urlparse(allowed) + parsed_origin = urlparse(origin) + if parsed_allowed.scheme != parsed_origin.scheme: + continue + # hostname is "*.suffix" because "://*." is in the pattern + suffix = (parsed_allowed.hostname or "")[2:] + origin_hostname = parsed_origin.hostname or "" + if origin_hostname == suffix or origin_hostname.endswith("." + suffix): + return True logger.warning(f"Invalid Origin header: {origin}") return False @@ -94,7 +122,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res Returns None if validation passes, or an error Response if validation fails. """ # Always validate Content-Type for POST requests - if is_post: # pragma: no branch + if is_post: content_type = request.headers.get("content-type") if not self._validate_content_type(content_type): return Response("Invalid Content-Type header", status_code=400) @@ -103,14 +131,12 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res if not self.settings.enable_dns_rebinding_protection: return None - # Validate Host header # pragma: no cover - host = request.headers.get("host") # pragma: no cover - if not self._validate_host(host): # pragma: no cover - return Response("Invalid Host header", status_code=421) # pragma: no cover + host = request.headers.get("host") + if not self._validate_host(host): + return Response("Invalid Host header", status_code=421) - # Validate Origin header # pragma: no cover - origin = request.headers.get("origin") # pragma: no cover - if not self._validate_origin(origin): # pragma: no cover - return Response("Invalid Origin header", status_code=403) # pragma: no cover + origin = request.headers.get("origin") + if not self._validate_origin(origin): + return Response("Invalid Origin header", status_code=403) - return None # pragma: no cover + return None diff --git a/tests/server/test_transport_security.py b/tests/server/test_transport_security.py new file mode 100644 index 000000000..afc109c4c --- /dev/null +++ b/tests/server/test_transport_security.py @@ -0,0 +1,218 @@ +"""Unit tests for TransportSecurityMiddleware.""" + +import pytest +from starlette.requests import Request + +from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings + + +def make_request(headers: dict[str, str], method: str = "GET") -> Request: + scope = { + "type": "http", + "method": method, + "path": "/", + "query_string": b"", + "headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()], + } + return Request(scope) + + +def make_middleware( + *, + allowed_hosts: list[str] | None = None, + allowed_origins: list[str] | None = None, +) -> TransportSecurityMiddleware: + return TransportSecurityMiddleware( + TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=allowed_hosts or [], + allowed_origins=allowed_origins or [], + ) + ) + + +# --------------------------------------------------------------------------- +# _validate_host +# --------------------------------------------------------------------------- + + +def test_validate_host_missing_header(): + mw = make_middleware(allowed_hosts=["example.com"]) + assert mw._validate_host(None) is False + + +def test_validate_host_exact_match(): + mw = make_middleware(allowed_hosts=["example.com"]) + assert mw._validate_host("example.com") is True + + +def test_validate_host_no_match(): + mw = make_middleware(allowed_hosts=["example.com"]) + assert mw._validate_host("evil.com") is False + + +def test_validate_host_port_wildcard_matches(): + mw = make_middleware(allowed_hosts=["example.com:*"]) + assert mw._validate_host("example.com:8080") is True + + +def test_validate_host_port_wildcard_different_host(): + mw = make_middleware(allowed_hosts=["example.com:*"]) + assert mw._validate_host("evil.com:8080") is False + + +def test_validate_host_subdomain_wildcard_base_domain(): + # "*.example.com" should match the base domain itself + mw = make_middleware(allowed_hosts=["*.example.com"]) + assert mw._validate_host("example.com") is True + + +def test_validate_host_subdomain_wildcard_with_subdomain(): + mw = make_middleware(allowed_hosts=["*.example.com"]) + assert mw._validate_host("app.example.com") is True + + +def test_validate_host_subdomain_wildcard_with_nested_subdomain(): + mw = make_middleware(allowed_hosts=["*.example.com"]) + assert mw._validate_host("api.staging.example.com") is True + + +def test_validate_host_subdomain_wildcard_with_port(): + # Port should be stripped before subdomain matching + mw = make_middleware(allowed_hosts=["*.example.com"]) + assert mw._validate_host("app.example.com:443") is True + + +def test_validate_host_subdomain_wildcard_no_match(): + mw = make_middleware(allowed_hosts=["*.example.com"]) + assert mw._validate_host("notexample.com") is False + + +def test_validate_host_subdomain_wildcard_suffix_collision(): + # "fakeexample.com" must not match "*.example.com" + mw = make_middleware(allowed_hosts=["*.example.com"]) + assert mw._validate_host("fakeexample.com") is False + + +# --------------------------------------------------------------------------- +# _validate_origin +# --------------------------------------------------------------------------- + + +def test_validate_origin_absent(): + mw = make_middleware(allowed_origins=["https://example.com"]) + assert mw._validate_origin(None) is True + + +def test_validate_origin_exact_match(): + mw = make_middleware(allowed_origins=["https://example.com"]) + assert mw._validate_origin("https://example.com") is True + + +def test_validate_origin_no_match(): + mw = make_middleware(allowed_origins=["https://example.com"]) + assert mw._validate_origin("https://evil.com") is False + + +def test_validate_origin_port_wildcard_matches(): + mw = make_middleware(allowed_origins=["https://example.com:*"]) + assert mw._validate_origin("https://example.com:8443") is True + + +def test_validate_origin_port_wildcard_different_host(): + mw = make_middleware(allowed_origins=["https://example.com:*"]) + assert mw._validate_origin("https://evil.com:8443") is False + + +def test_validate_origin_subdomain_wildcard_base_domain(): + # "https://*.example.com" should match the base domain itself + mw = make_middleware(allowed_origins=["https://*.example.com"]) + assert mw._validate_origin("https://example.com") is True + + +def test_validate_origin_subdomain_wildcard_with_subdomain(): + mw = make_middleware(allowed_origins=["https://*.example.com"]) + assert mw._validate_origin("https://app.example.com") is True + + +def test_validate_origin_subdomain_wildcard_scheme_mismatch(): + mw = make_middleware(allowed_origins=["https://*.example.com"]) + assert mw._validate_origin("http://app.example.com") is False + + +def test_validate_origin_subdomain_wildcard_no_match(): + mw = make_middleware(allowed_origins=["https://*.example.com"]) + assert mw._validate_origin("https://evil.com") is False + + +# --------------------------------------------------------------------------- +# validate_request (integration over the public method) +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_validate_request_post_invalid_content_type(): + mw = make_middleware(allowed_hosts=["example.com"]) + req = make_request({"host": "example.com", "content-type": "text/plain"}, method="POST") + resp = await mw.validate_request(req, is_post=True) + assert resp is not None + assert resp.status_code == 400 + + +@pytest.mark.anyio +async def test_validate_request_post_valid_content_type_protection_disabled(): + mw = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + req = make_request({"host": "example.com", "content-type": "application/json"}, method="POST") + resp = await mw.validate_request(req, is_post=True) + assert resp is None + + +@pytest.mark.anyio +async def test_validate_request_get_protection_disabled(): + mw = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + req = make_request({"host": "evil.com"}, method="GET") + resp = await mw.validate_request(req, is_post=False) + assert resp is None + + +@pytest.mark.anyio +async def test_validate_request_get_invalid_host(): + mw = make_middleware(allowed_hosts=["example.com"]) + req = make_request({"host": "evil.com"}, method="GET") + resp = await mw.validate_request(req, is_post=False) + assert resp is not None + assert resp.status_code == 421 + + +@pytest.mark.anyio +async def test_validate_request_post_invalid_host(): + mw = make_middleware(allowed_hosts=["example.com"]) + req = make_request({"host": "evil.com", "content-type": "application/json"}, method="POST") + resp = await mw.validate_request(req, is_post=True) + assert resp is not None + assert resp.status_code == 421 + + +@pytest.mark.anyio +async def test_validate_request_invalid_origin(): + mw = make_middleware(allowed_hosts=["example.com"], allowed_origins=["https://example.com"]) + req = make_request({"host": "example.com", "origin": "https://evil.com"}, method="GET") + resp = await mw.validate_request(req, is_post=False) + assert resp is not None + assert resp.status_code == 403 + + +@pytest.mark.anyio +async def test_validate_request_all_valid(): + mw = make_middleware(allowed_hosts=["example.com"], allowed_origins=["https://example.com"]) + req = make_request({"host": "example.com", "origin": "https://example.com"}, method="GET") + resp = await mw.validate_request(req, is_post=False) + assert resp is None + + +@pytest.mark.anyio +async def test_validate_request_wildcard_host_end_to_end(): + mw = make_middleware(allowed_hosts=["*.example.com"], allowed_origins=["https://*.example.com"]) + req = make_request({"host": "api.example.com", "origin": "https://app.example.com"}, method="GET") + resp = await mw.validate_request(req, is_post=False) + assert resp is None