Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 46 additions & 20 deletions src/mcp/server/transport_security.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""DNS rebinding protection for MCP server transports."""

import logging
from urllib.parse import urlparse

from pydantic import BaseModel, Field
from starlette.requests import Request
Expand All @@ -22,12 +23,24 @@ class TransportSecuritySettings(BaseModel):
allowed_hosts: list[str] = Field(default_factory=list)
"""List of allowed Host header values.

Supports exact matches, port wildcards, and subdomain wildcards:

- ``"example.com"`` — exact match
- ``"example.com:*"`` — any port on that host
- ``"*.example.com"`` — any subdomain (or the base domain itself)

Only applies when `enable_dns_rebinding_protection` is `True`.
"""

allowed_origins: list[str] = Field(default_factory=list)
"""List of allowed Origin header values.

Supports exact matches, port wildcards, and subdomain wildcards:

- ``"https://example.com"`` — exact match
- ``"https://example.com:*"`` — any port on that origin
- ``"https://*.example.com"`` — any subdomain (or the base domain itself) with HTTPS

Only applies when `enable_dns_rebinding_protection` is `True`.
"""

Expand All @@ -40,46 +53,61 @@ def __init__(self, settings: TransportSecuritySettings | None = None):
# If not specified, disable DNS rebinding protection by default for backwards compatibility
self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False)

def _validate_host(self, host: str | None) -> bool: # pragma: no cover
def _validate_host(self, host: str | None) -> bool:
"""Validate the Host header against allowed values."""
if not host:
logger.warning("Missing Host header in request")
return False

# Check exact match first
if host in self.settings.allowed_hosts:
return True

# Check wildcard port patterns
# Strip port for subdomain wildcard matching
host_without_port = host.split(":")[0]

for allowed in self.settings.allowed_hosts:
if allowed.endswith(":*"):
# Extract base host from pattern
# Port wildcard: e.g., "example.com:*" matches "example.com:8080"
base_host = allowed[:-2]
# Check if the actual host starts with base host and has a port
if host.startswith(base_host + ":"):
return True
elif allowed.startswith("*."):
# Subdomain wildcard: e.g., "*.example.com" matches "example.com"
# and "sub.example.com" (port is ignored)
suffix = allowed[2:]
if host_without_port == suffix or host_without_port.endswith("." + suffix):
return True

logger.warning(f"Invalid Host header: {host}")
return False

def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover
def _validate_origin(self, origin: str | None) -> bool:
"""Validate the Origin header against allowed values."""
# Origin can be absent for same-origin requests
if not origin:
return True

# Check exact match first
if origin in self.settings.allowed_origins:
return True

# Check wildcard port patterns
for allowed in self.settings.allowed_origins:
if allowed.endswith(":*"):
# Extract base origin from pattern
# Port wildcard: e.g., "https://example.com:*" matches "https://example.com:8080"
base_origin = allowed[:-2]
# Check if the actual origin starts with base origin and has a port
if origin.startswith(base_origin + ":"):
return True
elif "://*." in allowed:
# Subdomain wildcard: e.g., "https://*.example.com" matches
# "https://example.com" and "https://sub.example.com"
parsed_allowed = urlparse(allowed)
parsed_origin = urlparse(origin)
if parsed_allowed.scheme != parsed_origin.scheme:
continue
# hostname is "*.suffix" because "://*." is in the pattern
suffix = (parsed_allowed.hostname or "")[2:]
origin_hostname = parsed_origin.hostname or ""
if origin_hostname == suffix or origin_hostname.endswith("." + suffix):
return True

logger.warning(f"Invalid Origin header: {origin}")
return False
Expand All @@ -94,7 +122,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
Returns None if validation passes, or an error Response if validation fails.
"""
# Always validate Content-Type for POST requests
if is_post: # pragma: no branch
if is_post:
content_type = request.headers.get("content-type")
if not self._validate_content_type(content_type):
return Response("Invalid Content-Type header", status_code=400)
Expand All @@ -103,14 +131,12 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
if not self.settings.enable_dns_rebinding_protection:
return None

# Validate Host header # pragma: no cover
host = request.headers.get("host") # pragma: no cover
if not self._validate_host(host): # pragma: no cover
return Response("Invalid Host header", status_code=421) # pragma: no cover
host = request.headers.get("host")
if not self._validate_host(host):
return Response("Invalid Host header", status_code=421)

# Validate Origin header # pragma: no cover
origin = request.headers.get("origin") # pragma: no cover
if not self._validate_origin(origin): # pragma: no cover
return Response("Invalid Origin header", status_code=403) # pragma: no cover
origin = request.headers.get("origin")
if not self._validate_origin(origin):
return Response("Invalid Origin header", status_code=403)

return None # pragma: no cover
return None
218 changes: 218 additions & 0 deletions tests/server/test_transport_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
"""Unit tests for TransportSecurityMiddleware."""

import pytest
from starlette.requests import Request

from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings


def make_request(headers: dict[str, str], method: str = "GET") -> Request:
scope = {
"type": "http",
"method": method,
"path": "/",
"query_string": b"",
"headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()],
}
return Request(scope)


def make_middleware(
*,
allowed_hosts: list[str] | None = None,
allowed_origins: list[str] | None = None,
) -> TransportSecurityMiddleware:
return TransportSecurityMiddleware(
TransportSecuritySettings(
enable_dns_rebinding_protection=True,
allowed_hosts=allowed_hosts or [],
allowed_origins=allowed_origins or [],
)
)


# ---------------------------------------------------------------------------
# _validate_host
# ---------------------------------------------------------------------------


def test_validate_host_missing_header():
mw = make_middleware(allowed_hosts=["example.com"])
assert mw._validate_host(None) is False


def test_validate_host_exact_match():
mw = make_middleware(allowed_hosts=["example.com"])
assert mw._validate_host("example.com") is True


def test_validate_host_no_match():
mw = make_middleware(allowed_hosts=["example.com"])
assert mw._validate_host("evil.com") is False


def test_validate_host_port_wildcard_matches():
mw = make_middleware(allowed_hosts=["example.com:*"])
assert mw._validate_host("example.com:8080") is True


def test_validate_host_port_wildcard_different_host():
mw = make_middleware(allowed_hosts=["example.com:*"])
assert mw._validate_host("evil.com:8080") is False


def test_validate_host_subdomain_wildcard_base_domain():
# "*.example.com" should match the base domain itself
mw = make_middleware(allowed_hosts=["*.example.com"])
assert mw._validate_host("example.com") is True


def test_validate_host_subdomain_wildcard_with_subdomain():
mw = make_middleware(allowed_hosts=["*.example.com"])
assert mw._validate_host("app.example.com") is True


def test_validate_host_subdomain_wildcard_with_nested_subdomain():
mw = make_middleware(allowed_hosts=["*.example.com"])
assert mw._validate_host("api.staging.example.com") is True


def test_validate_host_subdomain_wildcard_with_port():
# Port should be stripped before subdomain matching
mw = make_middleware(allowed_hosts=["*.example.com"])
assert mw._validate_host("app.example.com:443") is True


def test_validate_host_subdomain_wildcard_no_match():
mw = make_middleware(allowed_hosts=["*.example.com"])
assert mw._validate_host("notexample.com") is False


def test_validate_host_subdomain_wildcard_suffix_collision():
# "fakeexample.com" must not match "*.example.com"
mw = make_middleware(allowed_hosts=["*.example.com"])
assert mw._validate_host("fakeexample.com") is False


# ---------------------------------------------------------------------------
# _validate_origin
# ---------------------------------------------------------------------------


def test_validate_origin_absent():
mw = make_middleware(allowed_origins=["https://example.com"])
assert mw._validate_origin(None) is True


def test_validate_origin_exact_match():
mw = make_middleware(allowed_origins=["https://example.com"])
assert mw._validate_origin("https://example.com") is True


def test_validate_origin_no_match():
mw = make_middleware(allowed_origins=["https://example.com"])
assert mw._validate_origin("https://evil.com") is False


def test_validate_origin_port_wildcard_matches():
mw = make_middleware(allowed_origins=["https://example.com:*"])
assert mw._validate_origin("https://example.com:8443") is True


def test_validate_origin_port_wildcard_different_host():
mw = make_middleware(allowed_origins=["https://example.com:*"])
assert mw._validate_origin("https://evil.com:8443") is False


def test_validate_origin_subdomain_wildcard_base_domain():
# "https://*.example.com" should match the base domain itself
mw = make_middleware(allowed_origins=["https://*.example.com"])
assert mw._validate_origin("https://example.com") is True


def test_validate_origin_subdomain_wildcard_with_subdomain():
mw = make_middleware(allowed_origins=["https://*.example.com"])
assert mw._validate_origin("https://app.example.com") is True


def test_validate_origin_subdomain_wildcard_scheme_mismatch():
mw = make_middleware(allowed_origins=["https://*.example.com"])
assert mw._validate_origin("http://app.example.com") is False


def test_validate_origin_subdomain_wildcard_no_match():
mw = make_middleware(allowed_origins=["https://*.example.com"])
assert mw._validate_origin("https://evil.com") is False


# ---------------------------------------------------------------------------
# validate_request (integration over the public method)
# ---------------------------------------------------------------------------


@pytest.mark.anyio
async def test_validate_request_post_invalid_content_type():
mw = make_middleware(allowed_hosts=["example.com"])
req = make_request({"host": "example.com", "content-type": "text/plain"}, method="POST")
resp = await mw.validate_request(req, is_post=True)
assert resp is not None
assert resp.status_code == 400


@pytest.mark.anyio
async def test_validate_request_post_valid_content_type_protection_disabled():
mw = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
req = make_request({"host": "example.com", "content-type": "application/json"}, method="POST")
resp = await mw.validate_request(req, is_post=True)
assert resp is None


@pytest.mark.anyio
async def test_validate_request_get_protection_disabled():
mw = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
req = make_request({"host": "evil.com"}, method="GET")
resp = await mw.validate_request(req, is_post=False)
assert resp is None


@pytest.mark.anyio
async def test_validate_request_get_invalid_host():
mw = make_middleware(allowed_hosts=["example.com"])
req = make_request({"host": "evil.com"}, method="GET")
resp = await mw.validate_request(req, is_post=False)
assert resp is not None
assert resp.status_code == 421


@pytest.mark.anyio
async def test_validate_request_post_invalid_host():
mw = make_middleware(allowed_hosts=["example.com"])
req = make_request({"host": "evil.com", "content-type": "application/json"}, method="POST")
resp = await mw.validate_request(req, is_post=True)
assert resp is not None
assert resp.status_code == 421


@pytest.mark.anyio
async def test_validate_request_invalid_origin():
mw = make_middleware(allowed_hosts=["example.com"], allowed_origins=["https://example.com"])
req = make_request({"host": "example.com", "origin": "https://evil.com"}, method="GET")
resp = await mw.validate_request(req, is_post=False)
assert resp is not None
assert resp.status_code == 403


@pytest.mark.anyio
async def test_validate_request_all_valid():
mw = make_middleware(allowed_hosts=["example.com"], allowed_origins=["https://example.com"])
req = make_request({"host": "example.com", "origin": "https://example.com"}, method="GET")
resp = await mw.validate_request(req, is_post=False)
assert resp is None


@pytest.mark.anyio
async def test_validate_request_wildcard_host_end_to_end():
mw = make_middleware(allowed_hosts=["*.example.com"], allowed_origins=["https://*.example.com"])
req = make_request({"host": "api.example.com", "origin": "https://app.example.com"}, method="GET")
resp = await mw.validate_request(req, is_post=False)
assert resp is None
Loading