Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 52 additions & 7 deletions src/specify_cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,16 +684,44 @@ def preset_add(

elif from_url:
# Validate URL scheme before downloading
from ipaddress import ip_address
from urllib.parse import urlparse as _urlparse

_parsed = _urlparse(from_url)
_is_localhost = _parsed.hostname in ("localhost", "127.0.0.1", "::1")
if _parsed.scheme != "https" and not (_parsed.scheme == "http" and _is_localhost):
console.print(f"[red]Error:[/red] URL must use HTTPS (got {_parsed.scheme}://). HTTP is only allowed for localhost.")

def _is_allowed_download_url(parsed_url):
host = parsed_url.hostname
if not host:
return False
is_loopback = host == "localhost"
if not is_loopback:
try:
is_loopback = ip_address(host).is_loopback
except ValueError:
# Host is not an IP literal (e.g., a regular hostname); treat as non-loopback.
pass
return parsed_url.scheme == "https" or (parsed_url.scheme == "http" and is_loopback)

def _validate_download_redirect(old_url, new_url):
if not _is_allowed_download_url(_urlparse(new_url)):
import urllib.error

raise urllib.error.URLError(
"redirect target must use HTTPS with a hostname, "
"or HTTP for localhost/loopback"
)

if not _is_allowed_download_url(_parsed):
console.print(
"[red]Error:[/red] URL must use HTTPS with a hostname, "
"or HTTP for localhost/loopback."
)
raise typer.Exit(1)

console.print(f"Installing preset from [cyan]{from_url}[/cyan]...")
import urllib.error
import tempfile
import shutil

with tempfile.TemporaryDirectory() as tmpdir:
zip_path = Path(tmpdir) / "preset.zip"
Expand All @@ -707,8 +735,25 @@ def preset_add(
from_url = _resolved_from_url
_preset_extra_headers = {"Accept": "application/octet-stream"}

with _open_url(from_url, timeout=60, extra_headers=_preset_extra_headers) as response:
zip_path.write_bytes(response.read())
with _open_url(
from_url,
timeout=60,
extra_headers=_preset_extra_headers,
redirect_validator=_validate_download_redirect,
) as response:
final_url = response.geturl() if hasattr(response, "geturl") else from_url
if not _is_allowed_download_url(_urlparse(final_url)):
console.print(
"[red]Error:[/red] Preset URL redirected to a disallowed URL: "
f"{final_url}. Redirect targets must use HTTPS with a hostname, "
"or HTTP for localhost/loopback."
)
raise typer.Exit(1)
with zip_path.open("wb") as output:
try:
shutil.copyfileobj(response, output)
except TypeError:
output.write(response.read())
except urllib.error.URLError as e:
console.print(f"[red]Error:[/red] Failed to download: {e}")
raise typer.Exit(1)
Expand Down Expand Up @@ -1186,7 +1231,7 @@ def preset_catalog_add(
})

config["catalogs"] = catalogs
config_path.write_text(yaml.dump(config, default_flow_style=False, sort_keys=False, allow_unicode=True), encoding="utf-8")
config_path.write_text(yaml.safe_dump(config, default_flow_style=False, sort_keys=False, allow_unicode=True), encoding="utf-8")

install_label = "install allowed" if install_allowed else "discovery only"
console.print(f"\n[green]✓[/green] Added catalog '[bold]{name}[/bold]' ({install_label})")
Expand Down Expand Up @@ -1226,7 +1271,7 @@ def preset_catalog_remove(
raise typer.Exit(1)

config["catalogs"] = catalogs
config_path.write_text(yaml.dump(config, default_flow_style=False, sort_keys=False, allow_unicode=True), encoding="utf-8")
config_path.write_text(yaml.safe_dump(config, default_flow_style=False, sort_keys=False, allow_unicode=True), encoding="utf-8")

console.print(f"[green]✓[/green] Removed catalog '{name}'")
if not catalogs:
Expand Down
37 changes: 31 additions & 6 deletions src/specify_cli/authentication/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import urllib.error
import urllib.request
from fnmatch import fnmatch
from typing import Callable
from urllib.parse import urlparse

from . import get_provider
Expand Down Expand Up @@ -56,22 +57,36 @@ def _hostname_in_hosts(hostname: str, hosts: tuple[str, ...]) -> bool:
return any(p == hostname or fnmatch(hostname, p) for p in hosts)


RedirectValidator = Callable[[str, str], None]


class _StripAuthOnRedirect(urllib.request.HTTPRedirectHandler):
"""Drop ``Authorization`` when a redirect leaves the entry's declared hosts."""
"""Drop ``Authorization`` when a redirect leaves trusted hosts or downgrades."""

def __init__(self, hosts: tuple[str, ...]) -> None:
def __init__(
self,
hosts: tuple[str, ...],
redirect_validator: RedirectValidator | None = None,
) -> None:
super().__init__()
self._hosts = hosts
self._redirect_validator = redirect_validator

def redirect_request(self, req, fp, code, msg, headers, newurl):
if self._redirect_validator is not None:
self._redirect_validator(req.full_url, newurl)

original_auth = (
req.get_header("Authorization")
or req.unredirected_hdrs.get("Authorization")
)
new_req = super().redirect_request(req, fp, code, msg, headers, newurl)
if new_req is not None:
hostname = (urlparse(newurl).hostname or "").lower()
if _hostname_in_hosts(hostname, self._hosts):
old_scheme = urlparse(req.full_url).scheme
new_parsed = urlparse(newurl)
hostname = (new_parsed.hostname or "").lower()
is_https_downgrade = old_scheme == "https" and new_parsed.scheme != "https"
if _hostname_in_hosts(hostname, self._hosts) and not is_https_downgrade:
if original_auth:
new_req.add_unredirected_header("Authorization", original_auth)
else:
Expand Down Expand Up @@ -103,7 +118,12 @@ def build_request(url: str, extra_headers: dict[str, str] | None = None) -> urll
return urllib.request.Request(url, headers=headers)


def open_url(url: str, timeout: int = 10, extra_headers: dict[str, str] | None = None):
def open_url(
url: str,
timeout: int = 10,
extra_headers: dict[str, str] | None = None,
redirect_validator: RedirectValidator | None = None,
):
"""Open *url* with config-driven auth, redirect stripping, and fallthrough.

1. Find ``auth.json`` entries whose hosts match the URL.
Expand All @@ -113,6 +133,8 @@ def open_url(url: str, timeout: int = 10, extra_headers: dict[str, str] | None =
5. Non-auth errors (404, 500, network) raise immediately.

*extra_headers* (e.g. ``Accept``) are merged into every attempt.
*redirect_validator*, when provided, is called with ``(old_url, new_url)``
before following each redirect and may raise to reject the redirect.
"""
entries = find_entries_for_url(url, _load_config())

Expand All @@ -135,7 +157,7 @@ def _make_req(auth_headers: dict[str, str]) -> urllib.request.Request:
continue

req = _make_req(provider.auth_headers(token, entry.auth))
opener = urllib.request.build_opener(_StripAuthOnRedirect(entry.hosts))
opener = urllib.request.build_opener(_StripAuthOnRedirect(entry.hosts, redirect_validator))
try:
return opener.open(req, timeout=timeout)
except urllib.error.HTTPError as exc:
Expand All @@ -146,4 +168,7 @@ def _make_req(auth_headers: dict[str, str]) -> urllib.request.Request:

# No entry worked (or none matched) — unauthenticated fallback
req = _make_req({})
if redirect_validator is not None:
opener = urllib.request.build_opener(_StripAuthOnRedirect((), redirect_validator))
return opener.open(req, timeout=timeout)
return urllib.request.urlopen(req, timeout=timeout) # noqa: S310
29 changes: 29 additions & 0 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,35 @@ def test_redirect_outside_hosts_strips_auth(self):
assert new_req.headers.get("Authorization") is None
assert new_req.unredirected_hdrs.get("Authorization") is None

def test_https_to_http_same_host_redirect_strips_auth(self):
from specify_cli.authentication.http import _StripAuthOnRedirect
from urllib.request import Request
import io
handler = _StripAuthOnRedirect(("github.com",))
req = Request("https://github.com/org/repo", headers={"Authorization": "Bearer tok"})
new_req = handler.redirect_request(req, io.BytesIO(b""), 302, "Found", {},
"http://github.com/org/repo")
assert new_req is not None
assert new_req.headers.get("Authorization") is None
assert new_req.unredirected_hdrs.get("Authorization") is None

def test_redirect_validator_can_reject_before_following_redirect(self):
import urllib.error
from specify_cli.authentication.http import _StripAuthOnRedirect
from urllib.request import Request
import io

def reject_http(old_url, new_url):
if new_url.startswith("http://"):
raise urllib.error.URLError("scheme downgrade")

handler = _StripAuthOnRedirect(("github.com",), reject_http)
req = Request("https://github.com/org/repo", headers={"Authorization": "Bearer tok"})

with pytest.raises(urllib.error.URLError, match="scheme downgrade"):
handler.redirect_request(req, io.BytesIO(b""), 302, "Found", {},
"http://github.com/org/repo")

def test_multi_hop_redirect_within_hosts_preserves_auth(self):
"""Auth survives a multi-hop redirect chain within allowed hosts."""
from specify_cli.authentication.http import _StripAuthOnRedirect
Expand Down
2 changes: 1 addition & 1 deletion tests/test_github_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,4 +187,4 @@ def capturing_open(url, timeout=None, extra_headers=None):
capturing_open,
)
assert len(captured_urls) == 1
assert "releases/tags/v1%23beta" in captured_urls[0]
assert "releases/tags/v1%23beta" in captured_urls[0]
Loading