diff --git a/cuda_pathfinder/cuda/pathfinder/__init__.py b/cuda_pathfinder/cuda/pathfinder/__init__.py index dc818dfd08..3a3f12f3b3 100644 --- a/cuda_pathfinder/cuda/pathfinder/__init__.py +++ b/cuda_pathfinder/cuda/pathfinder/__init__.py @@ -7,10 +7,16 @@ # cuda_pathfinder/docs/source/api.rst # to keep the documentation in sync. -from cuda.pathfinder._binaries.find_nvidia_binary_utility import ( - find_nvidia_binary_utility as find_nvidia_binary_utility, -) from cuda.pathfinder._binaries.supported_nvidia_binaries import SUPPORTED_BINARIES as _SUPPORTED_BINARIES +from cuda.pathfinder._compatibility_guard_rails import ( + CompatibilityCheckError as CompatibilityCheckError, +) +from cuda.pathfinder._compatibility_guard_rails import ( + CompatibilityGuardRails as CompatibilityGuardRails, +) +from cuda.pathfinder._compatibility_guard_rails import ( + CompatibilityInsufficientMetadataError as CompatibilityInsufficientMetadataError, +) from cuda.pathfinder._dynamic_libs.load_dl_common import ( DynamicLibNotAvailableError as DynamicLibNotAvailableError, ) @@ -19,16 +25,38 @@ DynamicLibUnknownError as DynamicLibUnknownError, ) from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL as LoadedDL -from cuda.pathfinder._dynamic_libs.load_nvidia_dynamic_lib import load_nvidia_dynamic_lib as load_nvidia_dynamic_lib from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import ( SUPPORTED_LIBNAMES as SUPPORTED_NVIDIA_LIBNAMES, ) from cuda.pathfinder._headers.find_nvidia_headers import LocatedHeaderDir as LocatedHeaderDir -from cuda.pathfinder._headers.find_nvidia_headers import find_nvidia_header_directory as find_nvidia_header_directory -from cuda.pathfinder._headers.find_nvidia_headers import ( +from cuda.pathfinder._headers.supported_nvidia_headers import SUPPORTED_HEADERS_CTK as _SUPPORTED_HEADERS_CTK +from cuda.pathfinder._process_wide_compatibility_guard_rails import ( + find_bitcode_lib as find_bitcode_lib, +) +from 
cuda.pathfinder._process_wide_compatibility_guard_rails import ( + find_nvidia_binary_utility as find_nvidia_binary_utility, +) +from cuda.pathfinder._process_wide_compatibility_guard_rails import ( + find_nvidia_header_directory as find_nvidia_header_directory, +) +from cuda.pathfinder._process_wide_compatibility_guard_rails import ( + find_static_lib as find_static_lib, +) +from cuda.pathfinder._process_wide_compatibility_guard_rails import ( + load_nvidia_dynamic_lib as load_nvidia_dynamic_lib, +) +from cuda.pathfinder._process_wide_compatibility_guard_rails import ( + locate_bitcode_lib as locate_bitcode_lib, +) +from cuda.pathfinder._process_wide_compatibility_guard_rails import ( locate_nvidia_header_directory as locate_nvidia_header_directory, ) -from cuda.pathfinder._headers.supported_nvidia_headers import SUPPORTED_HEADERS_CTK as _SUPPORTED_HEADERS_CTK +from cuda.pathfinder._process_wide_compatibility_guard_rails import ( + locate_static_lib as locate_static_lib, +) +from cuda.pathfinder._process_wide_compatibility_guard_rails import ( + process_wide_compatibility_guard_rails as _process_wide_compatibility_guard_rails, +) from cuda.pathfinder._static_libs.find_bitcode_lib import ( SUPPORTED_BITCODE_LIBS as _SUPPORTED_BITCODE_LIBS, ) @@ -38,12 +66,6 @@ from cuda.pathfinder._static_libs.find_bitcode_lib import ( LocatedBitcodeLib as LocatedBitcodeLib, ) -from cuda.pathfinder._static_libs.find_bitcode_lib import ( - find_bitcode_lib as find_bitcode_lib, -) -from cuda.pathfinder._static_libs.find_bitcode_lib import ( - locate_bitcode_lib as locate_bitcode_lib, -) from cuda.pathfinder._static_libs.find_static_lib import ( SUPPORTED_STATIC_LIBS as _SUPPORTED_STATIC_LIBS, ) @@ -53,16 +75,16 @@ from cuda.pathfinder._static_libs.find_static_lib import ( StaticLibNotFoundError as StaticLibNotFoundError, ) -from cuda.pathfinder._static_libs.find_static_lib import ( - find_static_lib as find_static_lib, -) -from cuda.pathfinder._static_libs.find_static_lib import ( - 
locate_static_lib as locate_static_lib, -) from cuda.pathfinder._utils.env_vars import get_cuda_path_or_home as get_cuda_path_or_home from cuda.pathfinder._version import __version__ # isort: skip +#: Process-wide default compatibility guard rails instance. Public APIs can +#: delegate through this singleton while the explicit ``CompatibilityGuardRails`` +#: class remains available for advanced use cases. +process_wide_compatibility_guard_rails = _process_wide_compatibility_guard_rails + + # Indirections to help Sphinx find the docstrings. #: Mapping from short CUDA Toolkit (CTK) library names to their canonical #: header basenames (used to validate a discovered include directory). diff --git a/cuda_pathfinder/cuda/pathfinder/_compatibility_guard_rails.py b/cuda_pathfinder/cuda/pathfinder/_compatibility_guard_rails.py new file mode 100644 index 0000000000..35737c1cf5 --- /dev/null +++ b/cuda_pathfinder/cuda/pathfinder/_compatibility_guard_rails.py @@ -0,0 +1,661 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import functools +import importlib.metadata +import os +import re +from collections.abc import Mapping +from dataclasses import dataclass +from pathlib import Path +from typing import TypeAlias, cast + +from cuda.pathfinder._binaries.find_nvidia_binary_utility import ( + find_nvidia_binary_utility as _find_nvidia_binary_utility, +) +from cuda.pathfinder._binaries.supported_nvidia_binaries import SUPPORTED_BINARIES_ALL +from cuda.pathfinder._dynamic_libs.lib_descriptor import LIB_DESCRIPTORS +from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL +from cuda.pathfinder._dynamic_libs.load_nvidia_dynamic_lib import ( + load_nvidia_dynamic_lib as _load_nvidia_dynamic_lib, +) +from cuda.pathfinder._headers.find_nvidia_headers import ( + LocatedHeaderDir, +) +from cuda.pathfinder._headers.find_nvidia_headers import ( + locate_nvidia_header_directory as _locate_nvidia_header_directory, +) +from cuda.pathfinder._headers.header_descriptor import HEADER_DESCRIPTORS +from cuda.pathfinder._static_libs.find_bitcode_lib import ( + LocatedBitcodeLib, +) +from cuda.pathfinder._static_libs.find_bitcode_lib import ( + locate_bitcode_lib as _locate_bitcode_lib, +) +from cuda.pathfinder._static_libs.find_static_lib import ( + LocatedStaticLib, +) +from cuda.pathfinder._static_libs.find_static_lib import ( + locate_static_lib as _locate_static_lib, +) +from cuda.pathfinder._utils.driver_info import ( + DriverCudaVersion, + QueryDriverCudaVersionError, + query_driver_cuda_version, +) +from cuda.pathfinder._utils.toolkit_info import ReadCudaHeaderVersionError, read_cuda_header_version + +ItemKind: TypeAlias = str +PackagedWith: TypeAlias = str +ConstraintOperator: TypeAlias = str +ConstraintArg: TypeAlias = int | str | tuple[str, int] | None + +_CTK_VERSION_RE = re.compile(r"^(?P<major>\d+)\.(?P<minor>\d+)") +_REQUIRES_DIST_RE = re.compile(r"^\s*(?P<name>[A-Za-z0-9_.-]+)\s*(?P<specifier_text>[^;]*)(?:\s*;|$)") +_VERSION_SPECIFIER_RE = 
re.compile(r"^\s*(?P<operator>==|<=|>=|<|>)\s*(?P<version>[0-9][A-Za-z0-9.+-]*?(?:\.\*)?)\s*$") + +_STATIC_LIBS_PACKAGED_WITH: dict[str, PackagedWith] = { + "cudadevrt": "ctk", +} +_BITCODE_LIBS_PACKAGED_WITH: dict[str, PackagedWith] = { + "device": "ctk", + "nvshmem_device": "other", +} +_BINARY_PACKAGED_WITH: dict[str, PackagedWith] = dict.fromkeys(SUPPORTED_BINARIES_ALL, "ctk") + + +class CompatibilityCheckError(RuntimeError): + """Raised when compatibility checks reject a resolved item.""" + + +class CompatibilityInsufficientMetadataError(CompatibilityCheckError): + """Raised when v1 compatibility checks cannot reach a definitive answer.""" + + +@dataclass(frozen=True, slots=True) +class CtkMetadata: + ctk_version: CtkVersion + ctk_root: str | None + source: str + + +@dataclass(frozen=True, slots=True) +class CtkVersion: + major: int + minor: int + + def __str__(self) -> str: + return f"{self.major}.{self.minor}" + + +@dataclass(frozen=True, slots=True) +class ComparisonConstraint: + operator: ConstraintOperator + value: int + + def matches(self, candidate: int) -> bool: + if self.operator == "==": + return candidate == self.value + if self.operator == "<": + return candidate < self.value + if self.operator == "<=": + return candidate <= self.value + if self.operator == ">": + return candidate > self.value + if self.operator == ">=": + return candidate >= self.value + raise AssertionError(f"Unsupported operator: {self.operator!r}") + + def __str__(self) -> str: + return f"{self.operator}{self.value}" + + +@dataclass(frozen=True, slots=True) +class VersionSpecifier: + operator: ConstraintOperator + version: str + + +@dataclass(frozen=True, slots=True) +class ResolvedItem: + name: str + kind: ItemKind + packaged_with: PackagedWith + abs_path: str + found_via: str | None + ctk_root: str | None + ctk_version: CtkVersion | None + ctk_version_source: str | None + + def describe(self) -> str: + found_via = "" if self.found_via is None else f" via {self.found_via}" + return f"{self.kind} 
{self.name!r}{found_via} at {self.abs_path!r}" + + +@dataclass(frozen=True, slots=True) +class CompatibilityResult: + status: str + message: str + + def require_compatible(self) -> None: + if self.status == "compatible": + return + if self.status == "insufficient_metadata": + raise CompatibilityInsufficientMetadataError(self.message) + raise CompatibilityCheckError(self.message) + + +def _coerce_constraint(name: str, raw_value: ConstraintArg) -> ComparisonConstraint | None: + if raw_value is None: + return None + if isinstance(raw_value, int): + return ComparisonConstraint("==", raw_value) + if isinstance(raw_value, tuple): + if len(raw_value) != 2: + raise ValueError(f"{name} tuple constraints must have exactly two elements.") + operator, value = raw_value + if operator not in ("==", "<", "<=", ">", ">="): + raise ValueError(f"{name} has unsupported operator {operator!r}.") + if not isinstance(value, int): + raise ValueError(f"{name} constraint value must be an integer.") + return ComparisonConstraint(operator, value) + if isinstance(raw_value, str): + match = re.fullmatch(r"\s*(==|<|<=|>|>=)?\s*(\d+)\s*", raw_value) + if match is None: + raise ValueError(f"{name} must be an int, a (operator, value) tuple, or a string like '>=12'.") + operator = match.group(1) or "==" + value = int(match.group(2)) + return ComparisonConstraint(operator, value) + raise ValueError(f"{name} must be an int, a (operator, value) tuple, or a string like '>=12'.") + + +def _parse_ctk_version(cuda_version: str) -> CtkVersion | None: + match = _CTK_VERSION_RE.match(cuda_version) + if match is None: + return None + return CtkVersion(major=int(match.group("major")), minor=int(match.group("minor"))) + + +def _normalize_distribution_name(name: str) -> str: + return re.sub(r"[-_.]+", "-", name).lower() + + +def _distribution_name(dist: importlib.metadata.Distribution) -> str | None: + # Work around mypy's typing of Distribution.metadata as PackageMetadata: + # the runtime object behaves like a 
string mapping, but mypy does not + # expose Mapping.get() on PackageMetadata. + metadata = cast(Mapping[str, str], dist.metadata) + return metadata.get("Name") + + +def _release_version_parts(version: str) -> tuple[int, ...] | None: + match = re.match(r"^\d+(?:\.\d+)*", version) + if match is None: + return None + return tuple(int(part) for part in match.group(0).split(".")) + + +def _compare_release_versions(lhs: tuple[int, ...], rhs: tuple[int, ...]) -> int: + max_len = max(len(lhs), len(rhs)) + lhs_padded = lhs + (0,) * (max_len - len(lhs)) + rhs_padded = rhs + (0,) * (max_len - len(rhs)) + if lhs_padded < rhs_padded: + return -1 + if lhs_padded > rhs_padded: + return 1 + return 0 + + +def _parse_version_specifiers(specifier_text: str) -> tuple[VersionSpecifier, ...]: + stripped = specifier_text.strip() + if not stripped: + return () + parsed: list[VersionSpecifier] = [] + for raw_clause in stripped.split(","): + match = _VERSION_SPECIFIER_RE.match(raw_clause) + if match is None: + return () + parsed.append(VersionSpecifier(operator=match.group("operator"), version=match.group("version"))) + return tuple(parsed) + + +def _version_satisfies_specifiers(version: str, specifiers: tuple[VersionSpecifier, ...]) -> bool: + if not specifiers: + return False + for specifier in specifiers: + if specifier.operator == "==": + prefix = specifier.version.removesuffix(".*") + if version == prefix or version.startswith(prefix + "."): + continue + return False + candidate_parts = _release_version_parts(version) + required_parts = _release_version_parts(specifier.version) + if candidate_parts is None or required_parts is None: + return False + comparison = _compare_release_versions(candidate_parts, required_parts) + if specifier.operator == "<" and comparison < 0: + continue + if specifier.operator == "<=" and comparison <= 0: + continue + if specifier.operator == ">" and comparison > 0: + continue + if specifier.operator == ">=" and comparison >= 0: + continue + return False + 
return True + + +@functools.cache +def _owned_distribution_candidates(abs_path: str) -> tuple[tuple[str, str], ...]: + normalized_abs_path = os.path.normpath(os.path.abspath(abs_path)) + matches: set[tuple[str, str]] = set() + for dist in importlib.metadata.distributions(): + dist_name = _distribution_name(dist) + if not dist_name: + continue + for file in dist.files or (): + candidate_abs_path = os.path.normpath(os.path.abspath(str(dist.locate_file(file)))) + if candidate_abs_path == normalized_abs_path: + matches.add((dist_name, dist.version)) + return tuple(sorted(matches)) + + +@functools.cache +def _cuda_toolkit_requirement_maps() -> tuple[ + tuple[str, CtkVersion, dict[str, tuple[tuple[VersionSpecifier, ...], ...]]], ... +]: + results: list[tuple[str, CtkVersion, dict[str, tuple[tuple[VersionSpecifier, ...], ...]]]] = [] + for dist in importlib.metadata.distributions(): + dist_name = _distribution_name(dist) + if _normalize_distribution_name(dist_name or "") != "cuda-toolkit": + continue + ctk_version = _parse_ctk_version(dist.version) + if ctk_version is None: + continue + requirement_map: dict[str, set[tuple[VersionSpecifier, ...]]] = {} + for requirement in dist.requires or (): + match = _REQUIRES_DIST_RE.match(requirement) + if match is None: + continue + req_name = _normalize_distribution_name(match.group("name")) + parsed_specifiers = _parse_version_specifiers(match.group("specifier_text")) + if not parsed_specifiers: + continue + requirement_map.setdefault(req_name, set()).add(parsed_specifiers) + results.append( + ( + dist.version, + ctk_version, + { + name: tuple( + sorted( + specifier_sets, + key=lambda specifiers: tuple( + (specifier.operator, specifier.version) for specifier in specifiers + ), + ) + ) + for name, specifier_sets in requirement_map.items() + }, + ) + ) + return tuple(results) + + +def _wheel_metadata_for_abs_path(abs_path: str) -> CtkMetadata | None: + matched_versions: dict[CtkVersion, str] = {} + for owner_name, owner_version in 
_owned_distribution_candidates(abs_path): + normalized_owner_name = _normalize_distribution_name(owner_name) + for toolkit_dist_version, ctk_version, requirement_map in _cuda_toolkit_requirement_maps(): + requirement_specifier_sets = requirement_map.get(normalized_owner_name, ()) + if not any( + _version_satisfies_specifiers(owner_version, specifiers) for specifiers in requirement_specifier_sets + ): + continue + matched_versions[ctk_version] = ( + f"wheel metadata via {owner_name}=={owner_version} pinned by cuda-toolkit=={toolkit_dist_version}" + ) + if len(matched_versions) != 1: + return None + [(ctk_version, source)] = matched_versions.items() + return CtkMetadata(ctk_version=ctk_version, ctk_root=None, source=source) + + +def _normalized_ctk_root_for_cuda_header(cuda_header_path: Path) -> Path: + ctk_root = cuda_header_path.parent.parent + if ctk_root.parent.name == "targets": + return ctk_root.parent.parent + return ctk_root + + +@functools.cache +def _cuda_header_metadata_for_ctk_root_candidate(ctk_root_candidate: str) -> CtkMetadata | None: + candidate_path = Path(ctk_root_candidate) + header_paths: list[Path] = [] + + direct_header = candidate_path / "include" / "cuda.h" + if direct_header.is_file(): + header_paths.append(direct_header) + + targets_dir = candidate_path / "targets" + if targets_dir.is_dir(): + header_paths.extend(sorted(path for path in targets_dir.glob("*/include/cuda.h") if path.is_file())) + + matches: list[tuple[CtkVersion, Path, Path]] = [] + for cuda_header_path in header_paths: + try: + version = read_cuda_header_version(str(cuda_header_path)) + except ReadCudaHeaderVersionError: + continue + matches.append( + ( + CtkVersion(major=version.major, minor=version.minor), + _normalized_ctk_root_for_cuda_header(cuda_header_path), + cuda_header_path, + ) + ) + + if not matches: + return None + + ctk_version, ctk_root, source_path = matches[0] + if any(other_version != ctk_version for other_version, _other_root, _other_source in 
matches[1:]): + return None + + return CtkMetadata( + ctk_version=ctk_version, + ctk_root=str(ctk_root), + source=f"cuda.h at {source_path}", + ) + + +def _ctk_metadata_for_abs_path(abs_path: str) -> CtkMetadata | None: + current = Path(abs_path) + if current.is_file(): + current = current.parent + for candidate in (current, *current.parents): + ctk_metadata = _cuda_header_metadata_for_ctk_root_candidate(str(candidate)) + if ctk_metadata is not None: + return ctk_metadata + return _wheel_metadata_for_abs_path(abs_path) + + +def _resolve_item( + *, + name: str, + kind: ItemKind, + packaged_with: PackagedWith, + abs_path: str, + found_via: str | None, +) -> ResolvedItem: + ctk_metadata = _ctk_metadata_for_abs_path(abs_path) + return ResolvedItem( + name=name, + kind=kind, + packaged_with=packaged_with, + abs_path=abs_path, + found_via=found_via, + ctk_root=None if ctk_metadata is None else ctk_metadata.ctk_root, + ctk_version=None if ctk_metadata is None else ctk_metadata.ctk_version, + ctk_version_source=None if ctk_metadata is None else ctk_metadata.source, + ) + + +def _resolve_dynamic_lib_item(libname: str, loaded: LoadedDL) -> ResolvedItem: + if loaded.abs_path is None: + raise CompatibilityInsufficientMetadataError( + f"Could not determine an absolute path for dynamic library {libname!r}." + ) + desc = LIB_DESCRIPTORS[libname] + return _resolve_item( + name=libname, + kind="dynamic-lib", + packaged_with=desc.packaged_with, + abs_path=loaded.abs_path, + found_via=loaded.found_via, + ) + + +def _resolve_header_item(libname: str, located: LocatedHeaderDir) -> ResolvedItem: + if located.abs_path is None: + raise CompatibilityInsufficientMetadataError( + f"Could not determine an absolute path for header directory {libname!r}." 
+ ) + desc = HEADER_DESCRIPTORS[libname] + metadata_abs_path = os.path.join(located.abs_path, desc.header_basename) + return _resolve_item( + name=libname, + kind="header-dir", + packaged_with=desc.packaged_with, + abs_path=metadata_abs_path, + found_via=located.found_via, + ) + + +def _resolve_static_lib_item(located: LocatedStaticLib) -> ResolvedItem: + packaged_with = _STATIC_LIBS_PACKAGED_WITH[located.name] + return _resolve_item( + name=located.name, + kind="static-lib", + packaged_with=packaged_with, + abs_path=located.abs_path, + found_via=located.found_via, + ) + + +def _resolve_bitcode_lib_item(located: LocatedBitcodeLib) -> ResolvedItem: + packaged_with = _BITCODE_LIBS_PACKAGED_WITH[located.name] + return _resolve_item( + name=located.name, + kind="bitcode-lib", + packaged_with=packaged_with, + abs_path=located.abs_path, + found_via=located.found_via, + ) + + +def _resolve_binary_item(utility_name: str, abs_path: str) -> ResolvedItem: + packaged_with = _BINARY_PACKAGED_WITH[utility_name] + return _resolve_item( + name=utility_name, + kind="binary", + packaged_with=packaged_with, + abs_path=abs_path, + found_via=None, + ) + + +def compatibility_check( + driver_cuda_version: DriverCudaVersion, item1: ResolvedItem, item2: ResolvedItem +) -> CompatibilityResult: + for item in (item1, item2): + if item.packaged_with != "ctk": + return CompatibilityResult( + status="insufficient_metadata", + message=( + "v1 compatibility checks only give definitive answers for " + f"packaged_with='ctk' items. {item.describe()} is packaged_with={item.packaged_with!r}." + ), + ) + if item.ctk_version is None or item.ctk_version_source is None: + return CompatibilityResult( + status="insufficient_metadata", + message=( + "v1 compatibility checks require either an enclosing CUDA Toolkit root " + "with cuda.h or wheel metadata that can be traced to an installed " + f"cuda-toolkit distribution. Could not determine the CTK version for {item.describe()}." 
+ ), + ) + + assert item1.ctk_version is not None + assert item2.ctk_version is not None + + if item1.ctk_version != item2.ctk_version: + return CompatibilityResult( + status="incompatible", + message=( + f"{item1.describe()} resolves to CTK {item1.ctk_version}, while " + f"{item2.describe()} resolves to CTK {item2.ctk_version}. " + "v1 requires an exact CTK major.minor match." + ), + ) + + if driver_cuda_version.major < item1.ctk_version.major: + return CompatibilityResult( + status="incompatible", + message=( + f"Driver version {driver_cuda_version.encoded} only supports CUDA major version {driver_cuda_version.major}, " + f"but {item1.describe()} requires CTK {item1.ctk_version}. " + "v1 requires driver_major >= ctk_major." + ), + ) + + return CompatibilityResult( + status="compatible", + message=( + f"{item1.describe()} and {item2.describe()} both resolve to CTK {item1.ctk_version}, " + f"and driver version {driver_cuda_version.encoded} satisfies the v1 driver guard rail." + ), + ) + + +class CompatibilityGuardRails: + """Resolve CUDA artifacts while enforcing minimal v1 compatibility guard rails.""" + + def __init__( + self, + *, + ctk_major: ConstraintArg = None, + ctk_minor: ConstraintArg = None, + driver_cuda_version: DriverCudaVersion | None = None, + ) -> None: + self._ctk_major_constraint = _coerce_constraint("ctk_major", ctk_major) + self._ctk_minor_constraint = _coerce_constraint("ctk_minor", ctk_minor) + self._configured_driver_cuda_version = driver_cuda_version + self._driver_cuda_version = driver_cuda_version + self._resolved_items: list[ResolvedItem] = [] + + def _get_driver_cuda_version(self) -> DriverCudaVersion: + if self._driver_cuda_version is None: + try: + self._driver_cuda_version = query_driver_cuda_version() + except QueryDriverCudaVersionError as exc: + raise CompatibilityCheckError( + "Failed to query the CUDA driver version needed for compatibility checks." 
+ ) from exc + return self._driver_cuda_version + + def _enforce_supported_packaging(self, item: ResolvedItem) -> None: + if item.packaged_with == "ctk": + return + raise CompatibilityInsufficientMetadataError( + "v1 compatibility checks only give definitive answers for " + f"packaged_with='ctk' items, plus compatibility-neutral driver libraries. " + f"{item.describe()} is packaged_with={item.packaged_with!r}." + ) + + def _enforce_ctk_metadata(self, item: ResolvedItem) -> None: + if item.ctk_version is not None and item.ctk_version_source is not None: + return + raise CompatibilityInsufficientMetadataError( + "v1 compatibility checks require either an enclosing CUDA Toolkit root " + "with cuda.h or wheel metadata that can be traced to an installed " + f"cuda-toolkit distribution. Could not determine the CTK version for {item.describe()}." + ) + + def _enforce_constraints(self, item: ResolvedItem) -> None: + assert item.ctk_version is not None + if self._ctk_major_constraint is not None and not self._ctk_major_constraint.matches(item.ctk_version.major): + raise CompatibilityCheckError( + f"{item.describe()} resolves to CTK {item.ctk_version}, which does not satisfy " + f"ctk_major{self._ctk_major_constraint}." + ) + if self._ctk_minor_constraint is not None and not self._ctk_minor_constraint.matches(item.ctk_version.minor): + raise CompatibilityCheckError( + f"{item.describe()} resolves to CTK {item.ctk_version}, which does not satisfy " + f"ctk_minor{self._ctk_minor_constraint}." 
+ ) + + def _anchor_item(self) -> ResolvedItem | None: + for item in self._resolved_items: + if item.packaged_with == "ctk": + return item + return None + + def _remember(self, item: ResolvedItem) -> None: + if item not in self._resolved_items: + self._resolved_items.append(item) + + def _reset_for_testing(self) -> None: + self._driver_cuda_version = self._configured_driver_cuda_version + self._resolved_items.clear() + + def _register_and_check(self, item: ResolvedItem) -> None: + # Driver libraries come from the installed display driver rather than a + # CUDA Toolkit line, so they do not need CTK metadata and must not lock + # the process-wide CTK anchor. + if item.packaged_with == "driver": + self._remember(item) + return + self._enforce_supported_packaging(item) + self._enforce_ctk_metadata(item) + self._enforce_constraints(item) + anchor = self._anchor_item() + if anchor is None: + anchor = item + compatibility_check(self._get_driver_cuda_version(), anchor, item).require_compatible() + self._remember(item) + + def load_nvidia_dynamic_lib(self, libname: str) -> LoadedDL: + """Load a CUDA dynamic library and reject v1-incompatible resolutions.""" + loaded = _load_nvidia_dynamic_lib(libname) + self._register_and_check(_resolve_dynamic_lib_item(libname, loaded)) + return loaded + + def locate_nvidia_header_directory(self, libname: str) -> LocatedHeaderDir | None: + """Locate a CUDA header directory and reject v1-incompatible resolutions.""" + located = _locate_nvidia_header_directory(libname) + if located is None: + return None + self._register_and_check(_resolve_header_item(libname, located)) + return located + + def find_nvidia_header_directory(self, libname: str) -> str | None: + """Locate a CUDA header directory and return only the path string.""" + located = self.locate_nvidia_header_directory(libname) + return None if located is None else located.abs_path + + def locate_static_lib(self, name: str) -> LocatedStaticLib: + """Locate a CUDA static library and 
reject v1-incompatible resolutions.""" + located = _locate_static_lib(name) + self._register_and_check(_resolve_static_lib_item(located)) + return located + + def find_static_lib(self, name: str) -> str: + """Locate a CUDA static library and return only the path string.""" + abs_path = self.locate_static_lib(name).abs_path + assert isinstance(abs_path, str) + return abs_path + + def locate_bitcode_lib(self, name: str) -> LocatedBitcodeLib: + """Locate a CUDA bitcode library and reject v1-incompatible resolutions.""" + located = _locate_bitcode_lib(name) + self._register_and_check(_resolve_bitcode_lib_item(located)) + return located + + def find_bitcode_lib(self, name: str) -> str: + """Locate a CUDA bitcode library and return only the path string.""" + abs_path = self.locate_bitcode_lib(name).abs_path + assert isinstance(abs_path, str) + return abs_path + + def find_nvidia_binary_utility(self, utility_name: str) -> str | None: + """Locate a CUDA binary utility and reject v1-incompatible resolutions.""" + abs_path = _find_nvidia_binary_utility(utility_name) + if abs_path is None: + return None + self._register_and_check(_resolve_binary_item(utility_name, abs_path)) + assert isinstance(abs_path, str) + return abs_path diff --git a/cuda_pathfinder/cuda/pathfinder/_process_wide_compatibility_guard_rails.py b/cuda_pathfinder/cuda/pathfinder/_process_wide_compatibility_guard_rails.py new file mode 100644 index 0000000000..d66e8243be --- /dev/null +++ b/cuda_pathfinder/cuda/pathfinder/_process_wide_compatibility_guard_rails.py @@ -0,0 +1,220 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +import sys +from collections.abc import Callable +from typing import Protocol, TypeVar, cast + +from cuda.pathfinder._binaries.find_nvidia_binary_utility import ( + find_nvidia_binary_utility as _find_nvidia_binary_utility, +) +from cuda.pathfinder._compatibility_guard_rails import ( + CompatibilityGuardRails, + CompatibilityInsufficientMetadataError, +) +from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL +from cuda.pathfinder._dynamic_libs.load_nvidia_dynamic_lib import ( + load_nvidia_dynamic_lib as _load_nvidia_dynamic_lib, +) +from cuda.pathfinder._headers.find_nvidia_headers import ( + LocatedHeaderDir, +) +from cuda.pathfinder._headers.find_nvidia_headers import ( + find_nvidia_header_directory as _find_nvidia_header_directory_impl, +) +from cuda.pathfinder._headers.find_nvidia_headers import ( + locate_nvidia_header_directory as _locate_nvidia_header_directory, +) +from cuda.pathfinder._static_libs.find_bitcode_lib import ( + LocatedBitcodeLib, +) +from cuda.pathfinder._static_libs.find_bitcode_lib import ( + find_bitcode_lib as _find_bitcode_lib, +) +from cuda.pathfinder._static_libs.find_bitcode_lib import ( + locate_bitcode_lib as _locate_bitcode_lib, +) +from cuda.pathfinder._static_libs.find_static_lib import ( + LocatedStaticLib, +) +from cuda.pathfinder._static_libs.find_static_lib import ( + find_static_lib as _find_static_lib, +) +from cuda.pathfinder._static_libs.find_static_lib import ( + locate_static_lib as _locate_static_lib, +) + +_T = TypeVar("_T") +_COMPATIBILITY_GUARD_RAILS_ENV_VAR = "CUDA_PATHFINDER_COMPATIBILITY_GUARD_RAILS" +_COMPATIBILITY_GUARD_RAILS_MODES = ("off", "best_effort", "strict") + + +class _ProcessWideGuardRailsApi(Protocol): + def load_nvidia_dynamic_lib(self, libname: str) -> LoadedDL: ... + + def locate_nvidia_header_directory(self, libname: str) -> LocatedHeaderDir | None: ... 
+ + def find_nvidia_header_directory(self, libname: str) -> str | None: ... + + def locate_static_lib(self, name: str) -> LocatedStaticLib: ... + + def find_static_lib(self, name: str) -> str: ... + + def locate_bitcode_lib(self, name: str) -> LocatedBitcodeLib: ... + + def find_bitcode_lib(self, name: str) -> str: ... + + def find_nvidia_binary_utility(self, utility_name: str) -> str | None: ... + + +class _PublicPathfinderModule(Protocol): + process_wide_compatibility_guard_rails: object + + +process_wide_compatibility_guard_rails: CompatibilityGuardRails = CompatibilityGuardRails() + + +def _compatibility_guard_rails_mode() -> str: + value = os.environ.get(_COMPATIBILITY_GUARD_RAILS_ENV_VAR) + if not value: + return "strict" + if value in _COMPATIBILITY_GUARD_RAILS_MODES: + return value + allowed_values = ", ".join(repr(mode) for mode in _COMPATIBILITY_GUARD_RAILS_MODES) + raise RuntimeError( + f"Invalid {_COMPATIBILITY_GUARD_RAILS_ENV_VAR}={value!r}. " + f"Allowed values: {allowed_values}. Unset or empty defaults to 'strict'." 
+ ) + + +def _public_module() -> _PublicPathfinderModule | None: + public_module = sys.modules.get("cuda.pathfinder") + if public_module is None: + return None + return cast(_PublicPathfinderModule, public_module) + + +def _current_process_wide_compatibility_guard_rails() -> _ProcessWideGuardRailsApi: + public_module = _public_module() + if public_module is None: + return cast(_ProcessWideGuardRailsApi, process_wide_compatibility_guard_rails) + return cast(_ProcessWideGuardRailsApi, public_module.process_wide_compatibility_guard_rails) + + +def _reset_process_wide_compatibility_guard_rails() -> None: + current = _current_process_wide_compatibility_guard_rails() + if isinstance(current, CompatibilityGuardRails): + current._reset_for_testing() + return + public_module = _public_module() + if public_module is None: + global process_wide_compatibility_guard_rails + process_wide_compatibility_guard_rails = CompatibilityGuardRails() + return + public_module.process_wide_compatibility_guard_rails = CompatibilityGuardRails() + + +def _try_process_wide_guard_rails_then_fallback(guard_rails_call: Callable[[], _T], raw_call: Callable[[], _T]) -> _T: + mode = _compatibility_guard_rails_mode() + if mode == "off": + return raw_call() + try: + return guard_rails_call() + except CompatibilityInsufficientMetadataError: + if mode == "best_effort": + return raw_call() + raise + + +def _cache_clear_with_process_state_reset(cache_clear: Callable[[], object]) -> Callable[[], None]: + def clear() -> None: + cache_clear() + _reset_process_wide_compatibility_guard_rails() + + return clear + + +def load_nvidia_dynamic_lib(libname: str) -> LoadedDL: + """Load a CUDA dynamic library via the process-wide compatibility guard rails.""" + return _try_process_wide_guard_rails_then_fallback( + lambda: _current_process_wide_compatibility_guard_rails().load_nvidia_dynamic_lib(libname), + lambda: _load_nvidia_dynamic_lib(libname), + ) + + +def locate_nvidia_header_directory(libname: str) -> 
LocatedHeaderDir | None: + """Locate a CUDA header directory via the process-wide compatibility guard rails.""" + return _try_process_wide_guard_rails_then_fallback( + lambda: _current_process_wide_compatibility_guard_rails().locate_nvidia_header_directory(libname), + lambda: _locate_nvidia_header_directory(libname), + ) + + +def find_nvidia_header_directory(libname: str) -> str | None: + """Locate a CUDA header directory and return its path string.""" + abs_path = _try_process_wide_guard_rails_then_fallback( + lambda: _current_process_wide_compatibility_guard_rails().find_nvidia_header_directory(libname), + lambda: _find_nvidia_header_directory_impl(libname), + ) + assert abs_path is None or isinstance(abs_path, str) + return abs_path + + +def locate_static_lib(name: str) -> LocatedStaticLib: + """Locate a CUDA static library via the process-wide compatibility guard rails.""" + return _try_process_wide_guard_rails_then_fallback( + lambda: _current_process_wide_compatibility_guard_rails().locate_static_lib(name), + lambda: _locate_static_lib(name), + ) + + +def find_static_lib(name: str) -> str: + """Locate a CUDA static library and return its path string.""" + abs_path = _try_process_wide_guard_rails_then_fallback( + lambda: _current_process_wide_compatibility_guard_rails().find_static_lib(name), + lambda: _find_static_lib(name), + ) + assert isinstance(abs_path, str) + return abs_path + + +def locate_bitcode_lib(name: str) -> LocatedBitcodeLib: + """Locate a CUDA bitcode library via the process-wide compatibility guard rails.""" + return _try_process_wide_guard_rails_then_fallback( + lambda: _current_process_wide_compatibility_guard_rails().locate_bitcode_lib(name), + lambda: _locate_bitcode_lib(name), + ) + + +def find_bitcode_lib(name: str) -> str: + """Locate a CUDA bitcode library and return its path string.""" + abs_path = _try_process_wide_guard_rails_then_fallback( + lambda: _current_process_wide_compatibility_guard_rails().find_bitcode_lib(name), + 
lambda: _find_bitcode_lib(name), + ) + assert isinstance(abs_path, str) + return abs_path + + +def find_nvidia_binary_utility(utility_name: str) -> str | None: + """Locate a CUDA binary utility via the process-wide compatibility guard rails.""" + abs_path = _try_process_wide_guard_rails_then_fallback( + lambda: _current_process_wide_compatibility_guard_rails().find_nvidia_binary_utility(utility_name), + lambda: _find_nvidia_binary_utility(utility_name), + ) + assert abs_path is None or isinstance(abs_path, str) + return abs_path + + +load_nvidia_dynamic_lib.cache_clear = _cache_clear_with_process_state_reset( # type: ignore[attr-defined] + _load_nvidia_dynamic_lib.cache_clear +) +locate_nvidia_header_directory.cache_clear = _cache_clear_with_process_state_reset( # type: ignore[attr-defined] + _locate_nvidia_header_directory.cache_clear +) +find_nvidia_binary_utility.cache_clear = _cache_clear_with_process_state_reset( # type: ignore[attr-defined] + _find_nvidia_binary_utility.cache_clear +) diff --git a/cuda_pathfinder/cuda/pathfinder/_static_libs/find_static_lib.py b/cuda_pathfinder/cuda/pathfinder/_static_libs/find_static_lib.py index 22cea7daad..804b1c04be 100644 --- a/cuda_pathfinder/cuda/pathfinder/_static_libs/find_static_lib.py +++ b/cuda_pathfinder/cuda/pathfinder/_static_libs/find_static_lib.py @@ -28,7 +28,7 @@ class LocatedStaticLib: class _StaticLibInfo(TypedDict): filename: str ctk_rel_paths: tuple[str, ...] - conda_rel_path: str + conda_rel_paths: tuple[str, ...] site_packages_dirs: tuple[str, ...] 
@@ -36,7 +36,7 @@ class _StaticLibInfo(TypedDict): "cudadevrt": { "filename": "cudadevrt.lib" if IS_WINDOWS else "libcudadevrt.a", "ctk_rel_paths": (os.path.join("lib", "x64"),) if IS_WINDOWS else ("lib64", "lib"), - "conda_rel_path": os.path.join("lib", "x64") if IS_WINDOWS else "lib", + "conda_rel_paths": ((os.path.join("lib", "x64"), "lib") if IS_WINDOWS else ("lib",)), "site_packages_dirs": ( ("nvidia/cu13/lib/x64", "nvidia/cuda_runtime/lib/x64") if IS_WINDOWS @@ -66,7 +66,7 @@ def __init__(self, name: str) -> None: self.config: _StaticLibInfo = _SUPPORTED_STATIC_LIBS_INFO[name] self.filename: str = self.config["filename"] self.ctk_rel_paths: tuple[str, ...] = self.config["ctk_rel_paths"] - self.conda_rel_path: str = self.config["conda_rel_path"] + self.conda_rel_paths: tuple[str, ...] = self.config["conda_rel_paths"] self.site_packages_dirs: tuple[str, ...] = self.config["site_packages_dirs"] self.error_messages: list[str] = [] self.attachments: list[str] = [] @@ -86,9 +86,10 @@ def try_with_conda_prefix(self) -> str | None: return None anchor = os.path.join(conda_prefix, "Library") if IS_WINDOWS else conda_prefix - file_path = os.path.join(anchor, self.conda_rel_path, self.filename) - if os.path.isfile(file_path): - return file_path + for rel_path in self.conda_rel_paths: + file_path = os.path.join(anchor, rel_path, self.filename) + if os.path.isfile(file_path): + return file_path return None def try_with_cuda_home(self) -> str | None: diff --git a/cuda_pathfinder/cuda/pathfinder/_utils/driver_info.py b/cuda_pathfinder/cuda/pathfinder/_utils/driver_info.py index a5d4d167d3..78e833f9ba 100644 --- a/cuda_pathfinder/cuda/pathfinder/_utils/driver_info.py +++ b/cuda_pathfinder/cuda/pathfinder/_utils/driver_info.py @@ -7,11 +7,13 @@ import functools from collections.abc import Callable from dataclasses import dataclass +from typing import cast from cuda.pathfinder._dynamic_libs.load_nvidia_dynamic_lib import ( load_nvidia_dynamic_lib as _load_nvidia_dynamic_lib, ) 
from cuda.pathfinder._utils.platform_aware import IS_WINDOWS +from cuda.pathfinder._utils.toolkit_info import EncodedCudaVersion class QueryDriverCudaVersionError(RuntimeError): @@ -19,7 +21,7 @@ class QueryDriverCudaVersionError(RuntimeError): @dataclass(frozen=True, slots=True) -class DriverCudaVersion: +class DriverCudaVersion(EncodedCudaVersion): """ CUDA-facing driver version reported by ``cuDriverGetVersion()``. @@ -41,21 +43,13 @@ class DriverCudaVersion: to ``Driver Version: 595.58.03``. """ - encoded: int - major: int - minor: int - @functools.cache def query_driver_cuda_version() -> DriverCudaVersion: """Return the CUDA driver version parsed into its major/minor components.""" try: encoded = _query_driver_cuda_version_int() - return DriverCudaVersion( - encoded=encoded, - major=encoded // 1000, - minor=(encoded % 1000) // 10, - ) + return cast(DriverCudaVersion, DriverCudaVersion.from_encoded(encoded)) except Exception as exc: raise QueryDriverCudaVersionError("Failed to query the CUDA driver version.") from exc diff --git a/cuda_pathfinder/cuda/pathfinder/_utils/toolkit_info.py b/cuda_pathfinder/cuda/pathfinder/_utils/toolkit_info.py new file mode 100644 index 0000000000..431727bf4b --- /dev/null +++ b/cuda_pathfinder/cuda/pathfinder/_utils/toolkit_info.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import functools +import re +from dataclasses import dataclass +from pathlib import Path +from typing import TypeVar + +_CUDA_VERSION_RE = re.compile(r"^\s*#\s*define\s+CUDA_VERSION\s+(?P\d+)\b", re.MULTILINE) +EncodedCudaVersionT = TypeVar("EncodedCudaVersionT", bound="EncodedCudaVersion") + + +@dataclass(frozen=True, slots=True) +class EncodedCudaVersion: + """CUDA major/minor version represented in CUDA's integer ``encoded`` form.""" + + encoded: int + major: int + minor: int + + @classmethod + def from_encoded(cls: type[EncodedCudaVersionT], encoded: int | str) -> EncodedCudaVersionT: + if isinstance(encoded, str): + try: + encoded_int = int(encoded) + except ValueError as exc: + raise ValueError( + f"{cls.__name__}.from_encoded() expected an integer or decimal string, got {encoded!r}." + ) from exc + elif isinstance(encoded, int): + encoded_int = encoded + else: + raise TypeError( + f"{cls.__name__}.from_encoded() expected an integer or decimal string, got {type(encoded).__name__}." + ) + if encoded_int < 0: + raise ValueError( + f"{cls.__name__}.from_encoded() expected a non-negative encoded CUDA version, got {encoded_int}." + ) + # CUDA encodes versions as major * 1000 + minor * 10. The least-significant + # decimal is ignored here: it is 0 in all CUDA releases and is not a patch version. 
+ return cls( + encoded=encoded_int, + major=encoded_int // 1000, + minor=(encoded_int % 1000) // 10, + ) + + +class ReadCudaHeaderVersionError(RuntimeError): + """Raised when ``read_cuda_header_version()`` cannot determine the CTK version from ``cuda.h``.""" + + +@dataclass(frozen=True, slots=True) +class CudaToolkitVersion(EncodedCudaVersion): + """CUDA Toolkit version encoded by the ``CUDA_VERSION`` macro in ``cuda.h``.""" + + +def parse_cuda_header_version(header_text: str) -> CudaToolkitVersion | None: + """Parse the CUDA Toolkit major/minor version from ``cuda.h`` text.""" + match = _CUDA_VERSION_RE.search(header_text) + if match is None: + return None + return CudaToolkitVersion.from_encoded(match.group("encoded")) + + +@functools.cache +def read_cuda_header_version(cuda_header_path: str) -> CudaToolkitVersion: + """Read and parse the CUDA Toolkit major/minor version from ``cuda.h``.""" + try: + header_text = Path(cuda_header_path).read_text(encoding="utf-8", errors="replace") + version = parse_cuda_header_version(header_text) + if version is None: + raise RuntimeError(f"{cuda_header_path!r} does not define CUDA_VERSION.") + return version + except Exception as exc: + raise ReadCudaHeaderVersionError( + f"Failed to read the CUDA Toolkit version from cuda.h at {cuda_header_path!r}." + ) from exc diff --git a/cuda_pathfinder/docs/source/api.rst b/cuda_pathfinder/docs/source/api.rst index e49478c09e..04290a4bbd 100644 --- a/cuda_pathfinder/docs/source/api.rst +++ b/cuda_pathfinder/docs/source/api.rst @@ -18,6 +18,10 @@ CUDA bitcode and static libraries. 
get_cuda_path_or_home + CompatibilityGuardRails + process_wide_compatibility_guard_rails + CompatibilityCheckError + CompatibilityInsufficientMetadataError SUPPORTED_NVIDIA_LIBNAMES load_nvidia_dynamic_lib LoadedDL diff --git a/cuda_pathfinder/tests/local_helpers.py b/cuda_pathfinder/tests/local_helpers.py index 7893ba8229..bfcfbe207c 100644 --- a/cuda_pathfinder/tests/local_helpers.py +++ b/cuda_pathfinder/tests/local_helpers.py @@ -4,6 +4,26 @@ import functools import importlib.metadata import re +from dataclasses import dataclass +from pathlib import Path + +import pytest + +from cuda.pathfinder._headers.find_nvidia_headers import ( + locate_nvidia_header_directory as locate_nvidia_header_directory_raw, +) +from cuda.pathfinder._utils import driver_info +from cuda.pathfinder._utils.toolkit_info import CudaToolkitVersion, read_cuda_header_version + + +@dataclass(frozen=True, slots=True) +class LocatedRealCudaToolkitVersion: + """Real-host CTK version discovered from ``cuda.h`` next to resolved ``cudart`` headers.""" + + version: CudaToolkitVersion + cuda_h_path: str + header_dir: str + found_via: str @functools.cache @@ -14,3 +34,46 @@ def have_distribution(name_pattern: str) -> bool: for dist in importlib.metadata.distributions() if "Name" in dist.metadata ) + + +@functools.cache +def locate_real_cuda_toolkit_version_from_cuda_h() -> LocatedRealCudaToolkitVersion | None: + """Return the real-host CTK version from ``cuda.h`` if ``cudart`` headers can be located.""" + located = locate_nvidia_header_directory_raw("cudart") + if located is None or located.abs_path is None: + return None + cuda_h_path = Path(located.abs_path) / "cuda.h" + if not cuda_h_path.is_file(): + return None + return LocatedRealCudaToolkitVersion( + version=read_cuda_header_version(str(cuda_h_path)), + cuda_h_path=str(cuda_h_path), + header_dir=located.abs_path, + found_via=located.found_via, + ) + + +def require_real_cuda_toolkit_version_from_cuda_h() -> LocatedRealCudaToolkitVersion: + 
"""Return the real-host CTK version from ``cuda.h`` or skip if it cannot be located.""" + located = locate_nvidia_header_directory_raw("cudart") + if located is None or located.abs_path is None: + pytest.skip("Could not locate cudart headers, so could not find cuda.h for a real CTK installation.") + cuda_h_path = Path(located.abs_path) / "cuda.h" + if not cuda_h_path.is_file(): + pytest.skip( + f"Located cudart headers via {located.found_via} at {located.abs_path!r}, but could not find cuda.h." + ) + return LocatedRealCudaToolkitVersion( + version=read_cuda_header_version(str(cuda_h_path)), + cuda_h_path=str(cuda_h_path), + header_dir=located.abs_path, + found_via=located.found_via, + ) + + +def require_real_driver_cuda_version() -> driver_info.DriverCudaVersion: + """Return the real-host CUDA driver version or skip if it cannot be queried.""" + try: + return driver_info.query_driver_cuda_version() + except driver_info.QueryDriverCudaVersionError as exc: + pytest.skip(f"Could not query the CUDA driver version for a real driver installation: {exc}") diff --git a/cuda_pathfinder/tests/test_compatibility_guard_rails.py b/cuda_pathfinder/tests/test_compatibility_guard_rails.py new file mode 100644 index 0000000000..672ba422ca --- /dev/null +++ b/cuda_pathfinder/tests/test_compatibility_guard_rails.py @@ -0,0 +1,773 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import importlib +import os +from pathlib import Path + +import pytest +from local_helpers import ( + have_distribution, + locate_real_cuda_toolkit_version_from_cuda_h, + require_real_cuda_toolkit_version_from_cuda_h, + require_real_driver_cuda_version, +) + +import cuda.pathfinder._compatibility_guard_rails as compatibility_module +from cuda import pathfinder +from cuda.pathfinder import ( + BitcodeLibNotFoundError, + CompatibilityCheckError, + CompatibilityGuardRails, + CompatibilityInsufficientMetadataError, + DynamicLibNotFoundError, + LoadedDL, + LocatedBitcodeLib, + LocatedHeaderDir, + LocatedStaticLib, + StaticLibNotFoundError, + process_wide_compatibility_guard_rails, +) +from cuda.pathfinder._dynamic_libs.load_nvidia_dynamic_lib import _resolve_system_loaded_abs_path_in_subprocess +from cuda.pathfinder._headers.find_nvidia_headers import ( + locate_nvidia_header_directory as locate_nvidia_header_directory_raw, +) +from cuda.pathfinder._utils import driver_info +from cuda.pathfinder._utils.driver_info import DriverCudaVersion, QueryDriverCudaVersionError +from cuda.pathfinder._utils.env_vars import get_cuda_path_or_home +from cuda.pathfinder._utils.toolkit_info import read_cuda_header_version + +STRICTNESS = os.environ.get("CUDA_PATHFINDER_TEST_COMPATIBILITY_GUARD_RAILS_STRICTNESS", "see_what_works") +assert STRICTNESS in ("see_what_works", "all_must_work") +COMPATIBILITY_GUARD_RAILS_ENV_VAR = "CUDA_PATHFINDER_COMPATIBILITY_GUARD_RAILS" +process_wide_module = importlib.import_module("cuda.pathfinder._process_wide_compatibility_guard_rails") + + +@pytest.fixture(autouse=True) +def _default_process_wide_guard_rails_mode(monkeypatch): + monkeypatch.delenv(COMPATIBILITY_GUARD_RAILS_ENV_VAR, raising=False) + + +@pytest.fixture +def clear_real_host_probe_caches(): + have_distribution.cache_clear() + locate_real_cuda_toolkit_version_from_cuda_h.cache_clear() + locate_nvidia_header_directory_raw.cache_clear() + 
_resolve_system_loaded_abs_path_in_subprocess.cache_clear() + get_cuda_path_or_home.cache_clear() + read_cuda_header_version.cache_clear() + driver_info._load_nvidia_dynamic_lib.cache_clear() + driver_info.query_driver_cuda_version.cache_clear() + yield + have_distribution.cache_clear() + locate_real_cuda_toolkit_version_from_cuda_h.cache_clear() + locate_nvidia_header_directory_raw.cache_clear() + _resolve_system_loaded_abs_path_in_subprocess.cache_clear() + get_cuda_path_or_home.cache_clear() + read_cuda_header_version.cache_clear() + driver_info._load_nvidia_dynamic_lib.cache_clear() + driver_info.query_driver_cuda_version.cache_clear() + + +def _write_cuda_h( + ctk_root: Path, + toolkit_version: str, + *, + include_dir_parts: tuple[str, ...] = ("targets", "x86_64-linux", "include"), +) -> None: + parts = toolkit_version.split(".") + if len(parts) < 2: + raise AssertionError(f"Expected at least major.minor in toolkit version, got {toolkit_version!r}") + encoded = int(parts[0]) * 1000 + int(parts[1]) * 10 + cuda_h_path = ctk_root.joinpath(*include_dir_parts, "cuda.h") + cuda_h_path.parent.mkdir(parents=True, exist_ok=True) + cuda_h_path.write_text( + f"#ifndef CUDA_H\n#define CUDA_H\n#define CUDA_VERSION {encoded}\n#endif\n", + encoding="utf-8", + ) + + +def _touch(path: Path) -> str: + path.parent.mkdir(parents=True, exist_ok=True) + path.touch() + return str(path) + + +def _loaded_dl(abs_path: str, *, found_via: str = "CUDA_PATH") -> LoadedDL: + return LoadedDL( + abs_path=abs_path, + was_already_loaded_from_elsewhere=False, + _handle_uint=1, + found_via=found_via, + ) + + +def _located_static_lib(name: str, abs_path: str) -> LocatedStaticLib: + return LocatedStaticLib( + name=name, + abs_path=abs_path, + filename=os.path.basename(abs_path), + found_via="CUDA_PATH", + ) + + +def _located_bitcode_lib(name: str, abs_path: str) -> LocatedBitcodeLib: + return LocatedBitcodeLib( + name=name, + abs_path=abs_path, + filename=os.path.basename(abs_path), + 
found_via="CUDA_PATH", + ) + + +def _driver_cuda_version(encoded: int) -> DriverCudaVersion: + return DriverCudaVersion.from_encoded(encoded) + + +class _FakeDistribution: + def __init__( + self, + *, + name: str, + version: str, + root: Path, + files: tuple[str, ...] = (), + requires: tuple[str, ...] = (), + ) -> None: + self.metadata = {"Name": name} + self.version = version + self.files = tuple(Path(file) for file in files) + self.requires = list(requires) + self._root = root + + def locate_file(self, file: Path) -> Path: + return self._root / file + + +def _assert_real_ctk_backed_path(path: str) -> None: + norm_path = os.path.normpath(os.path.abspath(path)) + if "site-packages" in Path(norm_path).parts: + return + current = Path(norm_path) + if current.is_file(): + current = current.parent + for candidate in (current, *current.parents): + if (candidate / "include" / "cuda.h").is_file(): + return + if any(path.is_file() for path in (candidate / "targets").glob("*/include/cuda.h")): + return + for env_var in ("CUDA_PATH", "CUDA_HOME"): + ctk_root = os.environ.get(env_var) + if not ctk_root: + continue + norm_ctk_root = os.path.normpath(os.path.abspath(ctk_root)) + if os.path.commonpath((norm_path, norm_ctk_root)) == norm_ctk_root: + return + raise AssertionError( + "Expected a site-packages path, a path under a CTK root with cuda.h, " + f"or a path under CUDA_PATH/CUDA_HOME, got {path!r}" + ) + + +class _DelegatingProcessWideGuardRails: + def __init__(self, method_name: str, return_value: object) -> None: + self._method_name = method_name + self._return_value = return_value + self.calls: list[tuple[str, tuple[object, ...]]] = [] + + def __getattr__(self, name: str): + if name != self._method_name: + raise AttributeError(name) + + def delegated(*args: object) -> object: + self.calls.append((name, args)) + return self._return_value + + return delegated + + +def test_process_wide_compatibility_guard_rails_is_public_singleton(): + assert 
process_wide_compatibility_guard_rails is pathfinder.process_wide_compatibility_guard_rails + assert isinstance(process_wide_compatibility_guard_rails, CompatibilityGuardRails) + + +@pytest.mark.parametrize( + ("public_api_name", "guard_rails_method_name", "args", "return_value"), + [ + ( + "load_nvidia_dynamic_lib", + "load_nvidia_dynamic_lib", + ("nvrtc",), + _loaded_dl("/opt/mock/libnvrtc.so.12"), + ), + ( + "locate_nvidia_header_directory", + "locate_nvidia_header_directory", + ("nvrtc",), + LocatedHeaderDir(abs_path="/opt/mock/include", found_via="CUDA_PATH"), + ), + ("find_nvidia_header_directory", "find_nvidia_header_directory", ("nvrtc",), "/opt/mock/include"), + ( + "locate_static_lib", + "locate_static_lib", + ("cudadevrt",), + _located_static_lib("cudadevrt", "/opt/mock/libcudadevrt.a"), + ), + ("find_static_lib", "find_static_lib", ("cudadevrt",), "/opt/mock/libcudadevrt.a"), + ( + "locate_bitcode_lib", + "locate_bitcode_lib", + ("device",), + _located_bitcode_lib("device", "/opt/mock/libdevice.10.bc"), + ), + ("find_bitcode_lib", "find_bitcode_lib", ("device",), "/opt/mock/libdevice.10.bc"), + ("find_nvidia_binary_utility", "find_nvidia_binary_utility", ("nvcc",), "/opt/mock/nvcc"), + ], +) +def test_public_apis_route_through_process_wide_guard_rails( + monkeypatch, public_api_name, guard_rails_method_name, args, return_value +): + fake_guard_rails = _DelegatingProcessWideGuardRails(guard_rails_method_name, return_value) + monkeypatch.setattr(pathfinder, "process_wide_compatibility_guard_rails", fake_guard_rails) + + result = getattr(pathfinder, public_api_name)(*args) + + assert result == return_value + assert fake_guard_rails.calls == [(guard_rails_method_name, args)] + + +def test_public_driver_libs_are_allowed_in_strict_mode(monkeypatch, tmp_path): + driver_lib_path = _touch(tmp_path / "driver-root" / "libnvidia-ml.so.1") + + monkeypatch.setattr( + compatibility_module, + "_load_nvidia_dynamic_lib", + lambda _libname: _loaded_dl(driver_lib_path, 
found_via="system-search"), + ) + monkeypatch.setattr( + pathfinder, + "process_wide_compatibility_guard_rails", + CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000)), + ) + + def fail_raw_fallback(_libname: str) -> LoadedDL: + pytest.fail("strict mode must not fall back to raw loading") + + monkeypatch.setattr(process_wide_module, "_load_nvidia_dynamic_lib", fail_raw_fallback) + + loaded = pathfinder.load_nvidia_dynamic_lib("nvml") + + assert loaded.abs_path == driver_lib_path + + +@pytest.mark.parametrize("env_value", [None, ""]) +def test_public_apis_default_to_strict_when_env_var_is_unset_or_empty(monkeypatch, tmp_path, env_value): + lib_path = _touch(tmp_path / "no-cuda-h" / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12") + + monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path)) + monkeypatch.setattr( + pathfinder, + "process_wide_compatibility_guard_rails", + CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000)), + ) + + def fail_raw_fallback(_libname: str) -> LoadedDL: + pytest.fail("strict mode must not fall back to raw loading") + + monkeypatch.setattr(process_wide_module, "_load_nvidia_dynamic_lib", fail_raw_fallback) + if env_value is None: + monkeypatch.delenv(COMPATIBILITY_GUARD_RAILS_ENV_VAR, raising=False) + else: + monkeypatch.setenv(COMPATIBILITY_GUARD_RAILS_ENV_VAR, env_value) + + with pytest.raises(CompatibilityInsufficientMetadataError, match="cuda.h"): + pathfinder.load_nvidia_dynamic_lib("nvrtc") + + +def test_public_apis_best_effort_fall_back_on_insufficient_metadata(monkeypatch, tmp_path): + guarded_lib_path = _touch(tmp_path / "no-cuda-h" / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12") + raw_loaded = _loaded_dl("/opt/mock/libnvrtc.so.12", found_via="system-search") + + monkeypatch.setenv(COMPATIBILITY_GUARD_RAILS_ENV_VAR, "best_effort") + monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: 
_loaded_dl(guarded_lib_path)) + monkeypatch.setattr(process_wide_module, "_load_nvidia_dynamic_lib", lambda _libname: raw_loaded) + monkeypatch.setattr( + pathfinder, + "process_wide_compatibility_guard_rails", + CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000)), + ) + + loaded = pathfinder.load_nvidia_dynamic_lib("nvrtc") + + assert loaded is raw_loaded + + +def test_public_apis_off_bypass_process_wide_guard_rails(monkeypatch): + raw_loaded = _loaded_dl("/opt/mock/libnvrtc.so.12", found_via="system-search") + fake_guard_rails = _DelegatingProcessWideGuardRails( + "load_nvidia_dynamic_lib", + _loaded_dl("/opt/mock/guard-rails/libnvrtc.so.12"), + ) + + monkeypatch.setenv(COMPATIBILITY_GUARD_RAILS_ENV_VAR, "off") + monkeypatch.setattr(pathfinder, "process_wide_compatibility_guard_rails", fake_guard_rails) + monkeypatch.setattr(process_wide_module, "_load_nvidia_dynamic_lib", lambda _libname: raw_loaded) + + loaded = pathfinder.load_nvidia_dynamic_lib("nvrtc") + + assert loaded is raw_loaded + assert fake_guard_rails.calls == [] + + +def test_public_apis_reject_invalid_guard_rails_mode(monkeypatch): + monkeypatch.setenv(COMPATIBILITY_GUARD_RAILS_ENV_VAR, "unexpected") + + with pytest.raises(RuntimeError, match=COMPATIBILITY_GUARD_RAILS_ENV_VAR) as exc_info: + pathfinder.find_nvidia_binary_utility("nvcc") + + message = str(exc_info.value) + assert "'off'" in message + assert "'best_effort'" in message + assert "'strict'" in message + + +def test_public_apis_share_process_wide_guard_rails_state(monkeypatch, tmp_path): + lib_root = tmp_path / "cuda-12.8" + hdr_root = tmp_path / "cuda-12.9" + _write_cuda_h(lib_root, "12.8.20250303") + _write_cuda_h(hdr_root, "12.9.20250531") + + lib_path = _touch(lib_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12") + hdr_dir = hdr_root / "targets" / "x86_64-linux" / "include" + _touch(hdr_dir / "nvrtc.h") + + monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: 
_loaded_dl(lib_path)) + monkeypatch.setattr( + compatibility_module, + "_locate_nvidia_header_directory", + lambda _libname: LocatedHeaderDir(abs_path=str(hdr_dir), found_via="CUDA_PATH"), + ) + monkeypatch.setattr( + pathfinder, + "process_wide_compatibility_guard_rails", + CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000)), + ) + + loaded = pathfinder.load_nvidia_dynamic_lib("nvrtc") + + assert loaded.abs_path == lib_path + with pytest.raises(CompatibilityCheckError, match="exact CTK major.minor match"): + pathfinder.find_nvidia_header_directory("nvrtc") + + +def test_load_dynamic_lib_then_find_headers_same_ctk_version(monkeypatch, tmp_path): + ctk_root = tmp_path / "cuda-12.9" + _write_cuda_h(ctk_root, "12.9.20250531") + lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12") + hdr_dir = ctk_root / "targets" / "x86_64-linux" / "include" + _touch(hdr_dir / "nvrtc.h") + + monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path)) + monkeypatch.setattr( + compatibility_module, + "_locate_nvidia_header_directory", + lambda _libname: LocatedHeaderDir(abs_path=str(hdr_dir), found_via="CUDA_PATH"), + ) + + guard_rails = CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000)) + + loaded = guard_rails.load_nvidia_dynamic_lib("nvrtc") + hdr_path = guard_rails.find_nvidia_header_directory("nvrtc") + + assert loaded.abs_path == lib_path + assert hdr_path == str(hdr_dir) + + +def test_exact_ctk_major_minor_match_is_required(monkeypatch, tmp_path): + lib_root = tmp_path / "cuda-12.8" + hdr_root = tmp_path / "cuda-12.9" + _write_cuda_h(lib_root, "12.8.20250303") + _write_cuda_h(hdr_root, "12.9.20250531") + + lib_path = _touch(lib_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12") + hdr_dir = hdr_root / "targets" / "x86_64-linux" / "include" + _touch(hdr_dir / "nvrtc.h") + + monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda 
_libname: _loaded_dl(lib_path)) + monkeypatch.setattr( + compatibility_module, + "_locate_nvidia_header_directory", + lambda _libname: LocatedHeaderDir(abs_path=str(hdr_dir), found_via="CUDA_PATH"), + ) + + guard_rails = CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000)) + guard_rails.load_nvidia_dynamic_lib("nvrtc") + + with pytest.raises(CompatibilityCheckError, match="exact CTK major.minor match"): + guard_rails.find_nvidia_header_directory("nvrtc") + + +def test_driver_major_must_not_be_older_than_ctk_major(monkeypatch, tmp_path): + ctk_root = tmp_path / "cuda-13.0" + _write_cuda_h(ctk_root, "13.0.20251003") + lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.13") + + monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path)) + + guard_rails = CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(12080)) + + with pytest.raises(CompatibilityCheckError, match="driver_major >= ctk_major"): + guard_rails.load_nvidia_dynamic_lib("nvrtc") + + +def test_missing_cuda_h_raises_insufficient_metadata(monkeypatch, tmp_path): + lib_path = _touch(tmp_path / "no-cuda-h" / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12") + + monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path)) + + guard_rails = CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000)) + + with pytest.raises(CompatibilityInsufficientMetadataError, match="cuda.h"): + guard_rails.load_nvidia_dynamic_lib("nvrtc") + + +def test_windows_style_ctk_root_uses_root_include_cuda_h(monkeypatch, tmp_path): + ctk_root = tmp_path / "cuda-13.2" + _write_cuda_h(ctk_root, "13.2.20251003", include_dir_parts=("include",)) + lib_path = _touch(ctk_root / "bin" / "x64" / "nvrtc64_130_0.dll") + + monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path)) + + guard_rails = 
CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000)) + + loaded = guard_rails.load_nvidia_dynamic_lib("nvrtc") + + assert loaded.abs_path == lib_path + + +def test_other_packaging_raises_insufficient_metadata(monkeypatch, tmp_path): + abs_path = _touch(tmp_path / "site-packages" / "nvidia" / "nvshmem" / "lib" / "libnvshmem_device.bc") + + monkeypatch.setattr( + compatibility_module, + "_locate_bitcode_lib", + lambda _name: _located_bitcode_lib("nvshmem_device", abs_path), + ) + + guard_rails = CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000)) + + with pytest.raises(CompatibilityInsufficientMetadataError, match="packaged_with='ctk'"): + guard_rails.find_bitcode_lib("nvshmem_device") + + +def test_driver_libs_do_not_lock_ctk_anchor(monkeypatch, tmp_path): + driver_lib_path = _touch(tmp_path / "driver-root" / "libnvidia-ml.so.1") + ctk_root = tmp_path / "cuda-12.9" + _write_cuda_h(ctk_root, "12.9.20250531") + ctk_lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12") + + def fake_load_nvidia_dynamic_lib(libname: str) -> LoadedDL: + if libname == "nvml": + return _loaded_dl(driver_lib_path, found_via="system-search") + if libname == "nvrtc": + return _loaded_dl(ctk_lib_path) + raise AssertionError(f"Unexpected libname: {libname!r}") + + monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", fake_load_nvidia_dynamic_lib) + + guard_rails = CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000)) + + driver_loaded = guard_rails.load_nvidia_dynamic_lib("nvml") + ctk_loaded = guard_rails.load_nvidia_dynamic_lib("nvrtc") + + assert driver_loaded.abs_path == driver_lib_path + assert ctk_loaded.abs_path == ctk_lib_path + + +def test_driver_libs_do_not_mask_later_ctk_mismatch(monkeypatch, tmp_path): + driver_lib_path = _touch(tmp_path / "driver-root" / "libnvidia-ml.so.1") + lib_root = tmp_path / "cuda-12.8" + hdr_root = tmp_path / "cuda-12.9" + _write_cuda_h(lib_root, 
"12.8.20250303") + _write_cuda_h(hdr_root, "12.9.20250531") + + lib_path = _touch(lib_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12") + hdr_dir = hdr_root / "targets" / "x86_64-linux" / "include" + _touch(hdr_dir / "nvrtc.h") + + def fake_load_nvidia_dynamic_lib(libname: str) -> LoadedDL: + if libname == "nvml": + return _loaded_dl(driver_lib_path, found_via="system-search") + if libname == "nvrtc": + return _loaded_dl(lib_path) + raise AssertionError(f"Unexpected libname: {libname!r}") + + monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", fake_load_nvidia_dynamic_lib) + monkeypatch.setattr( + compatibility_module, + "_locate_nvidia_header_directory", + lambda _libname: LocatedHeaderDir(abs_path=str(hdr_dir), found_via="CUDA_PATH"), + ) + + guard_rails = CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000)) + guard_rails.load_nvidia_dynamic_lib("nvml") + guard_rails.load_nvidia_dynamic_lib("nvrtc") + + with pytest.raises(CompatibilityCheckError, match="exact CTK major.minor match"): + guard_rails.find_nvidia_header_directory("nvrtc") + + +@pytest.mark.parametrize( + "requirement", + ( + "nvidia-nvjitlink == 13.2.78.*; extra == 'nvjitlink'", + "nvidia-nvjitlink<14,>=13.2.78; extra == 'nvjitlink'", + ), +) +def test_wheel_metadata_accepts_exact_and_range_requirements(monkeypatch, tmp_path, requirement): + site_packages = tmp_path / "site-packages" + lib_path = _touch(site_packages / "nvidia" / "cu13" / "lib" / "libnvJitLink.so.13") + owner_dist = _FakeDistribution( + name="nvidia-nvjitlink", + version="13.2.78", + root=site_packages, + files=("nvidia/cu13/lib/libnvJitLink.so.13",), + ) + cuda_toolkit_dist = _FakeDistribution( + name="cuda-toolkit", + version="13.2.1", + root=site_packages, + requires=(requirement,), + ) + + compatibility_module._owned_distribution_candidates.cache_clear() + compatibility_module._cuda_toolkit_requirement_maps.cache_clear() + try: + monkeypatch.setattr( + 
compatibility_module.importlib.metadata, + "distributions", + lambda: (owner_dist, cuda_toolkit_dist), + ) + + metadata = compatibility_module._wheel_metadata_for_abs_path(lib_path) + finally: + compatibility_module._owned_distribution_candidates.cache_clear() + compatibility_module._cuda_toolkit_requirement_maps.cache_clear() + + assert metadata is not None + assert metadata.ctk_version.major == 13 + assert metadata.ctk_version.minor == 2 + assert metadata.source == "wheel metadata via nvidia-nvjitlink==13.2.78 pinned by cuda-toolkit==13.2.1" + + +def test_constraints_accept_string_and_tuple_forms(monkeypatch, tmp_path): + ctk_root = tmp_path / "cuda-12.9" + _write_cuda_h(ctk_root, "12.9.20250531") + lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12") + + monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path)) + + guard_rails = CompatibilityGuardRails( + ctk_major=(">=", 12), + ctk_minor=">=9", + driver_cuda_version=_driver_cuda_version(13000), + ) + + loaded = guard_rails.load_nvidia_dynamic_lib("nvrtc") + + assert loaded.abs_path == lib_path + + +def test_constraint_failure_raises(monkeypatch, tmp_path): + ctk_root = tmp_path / "cuda-12.9" + _write_cuda_h(ctk_root, "12.9.20250531") + lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12") + + monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path)) + + guard_rails = CompatibilityGuardRails( + ctk_major=12, + ctk_minor="<9", + driver_cuda_version=_driver_cuda_version(13000), + ) + + with pytest.raises(CompatibilityCheckError, match="ctk_minor<9"): + guard_rails.load_nvidia_dynamic_lib("nvrtc") + + +def test_static_bitcode_and_binary_methods_participate_in_checks(monkeypatch, tmp_path): + ctk_root = tmp_path / "cuda-12.9" + _write_cuda_h(ctk_root, "12.9.20250531") + + lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12") + 
static_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libcudadevrt.a") + bitcode_path = _touch(ctk_root / "nvvm" / "libdevice" / "libdevice.10.bc") + binary_path = _touch(ctk_root / "bin" / "nvcc") + + monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path)) + monkeypatch.setattr( + compatibility_module, + "_locate_static_lib", + lambda _name: _located_static_lib("cudadevrt", static_path), + ) + monkeypatch.setattr( + compatibility_module, + "_locate_bitcode_lib", + lambda _name: _located_bitcode_lib("device", bitcode_path), + ) + monkeypatch.setattr( + compatibility_module, + "_find_nvidia_binary_utility", + lambda _utility_name: binary_path, + ) + + guard_rails = CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000)) + + guard_rails.load_nvidia_dynamic_lib("nvrtc") + assert guard_rails.find_static_lib("cudadevrt") == static_path + assert guard_rails.find_bitcode_lib("device") == bitcode_path + assert guard_rails.find_nvidia_binary_utility("nvcc") == binary_path + + +def test_guard_rails_query_driver_cuda_version_by_default(monkeypatch, tmp_path): + ctk_root = tmp_path / "cuda-12.9" + _write_cuda_h(ctk_root, "12.9.20250531") + lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12") + + query_calls: list[int] = [] + + monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path)) + + def fake_query_driver_cuda_version() -> DriverCudaVersion: + query_calls.append(1) + return _driver_cuda_version(13000) + + monkeypatch.setattr(compatibility_module, "query_driver_cuda_version", fake_query_driver_cuda_version) + + guard_rails = CompatibilityGuardRails() + + guard_rails.load_nvidia_dynamic_lib("nvrtc") + guard_rails.load_nvidia_dynamic_lib("nvrtc") + + assert len(query_calls) == 1 + + +def test_guard_rails_wrap_driver_query_failures(monkeypatch, tmp_path): + ctk_root = tmp_path / "cuda-12.9" + 
_write_cuda_h(ctk_root, "12.9.20250531") + lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12") + + monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path)) + + def fail_query_driver_cuda_version() -> DriverCudaVersion: + raise QueryDriverCudaVersionError("driver query failed") + + monkeypatch.setattr(compatibility_module, "query_driver_cuda_version", fail_query_driver_cuda_version) + + guard_rails = CompatibilityGuardRails() + + with pytest.raises( + CompatibilityCheckError, + match="Failed to query the CUDA driver version needed for compatibility checks", + ) as exc_info: + guard_rails.load_nvidia_dynamic_lib("nvrtc") + + assert isinstance(exc_info.value.__cause__, QueryDriverCudaVersionError) + + +def test_find_nvidia_header_directory_returns_none_when_unresolved(monkeypatch): + monkeypatch.setattr( + compatibility_module, + "_locate_nvidia_header_directory", + lambda _libname: None, + ) + + guard_rails = CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000)) + + assert guard_rails.find_nvidia_header_directory("nvrtc") is None + + +@pytest.mark.usefixtures("clear_real_host_probe_caches") +def test_real_driver(info_summary_append): + real_driver = require_real_driver_cuda_version() + info_summary_append( + f"real driver CUDA version={real_driver.major}.{real_driver.minor} (encoded={real_driver.encoded})" + ) + + +@pytest.mark.usefixtures("clear_real_host_probe_caches") +def test_real_ctk(info_summary_append): + real_ctk = require_real_cuda_toolkit_version_from_cuda_h() + info_summary_append( + f"real cuda.h CTK version={real_ctk.version.major}.{real_ctk.version.minor} " + f"via {real_ctk.found_via} at {real_ctk.cuda_h_path!r}" + ) + + +@pytest.mark.usefixtures("clear_real_host_probe_caches") +def test_real_wheel_ctk_items_are_compatible(info_summary_append): + real_ctk = require_real_cuda_toolkit_version_from_cuda_h() + real_driver = 
require_real_driver_cuda_version() + guard_rails = CompatibilityGuardRails( + ctk_major=real_ctk.version.major, + ctk_minor=real_ctk.version.minor, + driver_cuda_version=real_driver, + ) + + try: + loaded = guard_rails.load_nvidia_dynamic_lib("nvrtc") + header_dir = guard_rails.find_nvidia_header_directory("nvrtc") + static_lib = guard_rails.find_static_lib("cudadevrt") + bitcode_lib = guard_rails.find_bitcode_lib("device") + nvcc = guard_rails.find_nvidia_binary_utility("nvcc") + except ( + CompatibilityCheckError, + CompatibilityInsufficientMetadataError, + DynamicLibNotFoundError, + StaticLibNotFoundError, + BitcodeLibNotFoundError, + ) as exc: + if STRICTNESS == "all_must_work": + raise + pytest.skip(f"real CTK check unavailable: {exc.__class__.__name__}: {exc}") + + assert isinstance(loaded.abs_path, str) + assert header_dir is not None + for path in (loaded.abs_path, header_dir, static_lib, bitcode_lib): + _assert_real_ctk_backed_path(path) + if have_distribution(r"^nvidia-cuda-nvcc-cu12$"): + # For CUDA 12, NVIDIA publishes a PyPI package named nvidia-cuda-nvcc-cu12, + # but the wheels only contain nvcc-adjacent compiler components such as + # ptxas, CRT headers, libnvvm, and libdevice; the nvcc executable itself + # is not included. + if nvcc is not None: + # nvcc found elsewhere, e.g. /usr/local or Conda. 
+ _assert_real_ctk_backed_path(nvcc) + else: + assert nvcc is not None + _assert_real_ctk_backed_path(nvcc) + + +@pytest.mark.usefixtures("clear_real_host_probe_caches") +def test_real_wheel_component_version_does_not_override_ctk_line(info_summary_append): + real_ctk = require_real_cuda_toolkit_version_from_cuda_h() + real_driver = require_real_driver_cuda_version() + guard_rails = CompatibilityGuardRails( + ctk_major=real_ctk.version.major, + ctk_minor=real_ctk.version.minor, + driver_cuda_version=real_driver, + ) + + try: + header_dir = guard_rails.find_nvidia_header_directory("cufft") + except (CompatibilityCheckError, CompatibilityInsufficientMetadataError) as exc: + if STRICTNESS == "all_must_work": + raise + pytest.skip(f"real cufft CTK check unavailable: {exc.__class__.__name__}: {exc}") + + if header_dir is None: + if STRICTNESS == "all_must_work": + raise AssertionError("Expected CTK-backed cufft headers to be discoverable.") + pytest.skip("real cufft CTK check unavailable: cufft headers not found") + + _assert_real_ctk_backed_path(header_dir) diff --git a/cuda_pathfinder/tests/test_driver_lib_loading.py b/cuda_pathfinder/tests/test_driver_lib_loading.py index b97453c9b5..e47edd9001 100644 --- a/cuda_pathfinder/tests/test_driver_lib_loading.py +++ b/cuda_pathfinder/tests/test_driver_lib_loading.py @@ -30,6 +30,7 @@ STRICTNESS = os.environ.get("CUDA_PATHFINDER_TEST_LOAD_NVIDIA_DYNAMIC_LIB_STRICTNESS", "see_what_works") assert STRICTNESS in ("see_what_works", "all_must_work") +COMPATIBILITY_GUARD_RAILS_ENV_VAR = "CUDA_PATHFINDER_COMPATIBILITY_GUARD_RAILS" _MODULE = "cuda.pathfinder._dynamic_libs.load_nvidia_dynamic_lib" _LOADER_MODULE = "cuda.pathfinder._dynamic_libs.load_nvidia_dynamic_lib.LOADER" @@ -38,6 +39,11 @@ _NVML_DESC = LIB_DESCRIPTORS["nvml"] +@pytest.fixture(autouse=True) +def _disable_process_wide_compatibility_guard_rails(monkeypatch): + monkeypatch.setenv(COMPATIBILITY_GUARD_RAILS_ENV_VAR, "off") + + def _make_loaded_dl(path, found_via): 
return LoadedDL(path, False, 0xDEAD, found_via) @@ -164,7 +170,7 @@ def test_real_query_driver_cuda_version(info_summary_append): driver_info._load_nvidia_dynamic_lib.cache_clear() driver_info.query_driver_cuda_version.cache_clear() try: - version = driver_info.query_driver_cuda_version() + driver_cuda_version = driver_info.query_driver_cuda_version() except driver_info.QueryDriverCudaVersionError as exc: if STRICTNESS == "all_must_work": raise @@ -174,7 +180,11 @@ def test_real_query_driver_cuda_version(info_summary_append): driver_info._load_nvidia_dynamic_lib.cache_clear() driver_info.query_driver_cuda_version.cache_clear() - info_summary_append(f"driver_version={version.major}.{version.minor} (encoded={version.encoded})") - assert version.encoded > 0 - assert version.major == version.encoded // 1000 - assert version.minor == (version.encoded % 1000) // 10 + info_summary_append( + "driver_cuda_version=" + f"{driver_cuda_version.major}.{driver_cuda_version.minor} " + f"(encoded={driver_cuda_version.encoded})" + ) + assert driver_cuda_version.encoded > 0 + assert driver_cuda_version.major == driver_cuda_version.encoded // 1000 + assert driver_cuda_version.minor == (driver_cuda_version.encoded % 1000) // 10 diff --git a/cuda_pathfinder/tests/test_find_nvidia_binaries.py b/cuda_pathfinder/tests/test_find_nvidia_binaries.py index ec9740cd85..dbdbf5b61f 100644 --- a/cuda_pathfinder/tests/test_find_nvidia_binaries.py +++ b/cuda_pathfinder/tests/test_find_nvidia_binaries.py @@ -14,6 +14,13 @@ SUPPORTED_BINARIES_ALL, ) +COMPATIBILITY_GUARD_RAILS_ENV_VAR = "CUDA_PATHFINDER_COMPATIBILITY_GUARD_RAILS" + + +@pytest.fixture(autouse=True) +def _disable_process_wide_compatibility_guard_rails(monkeypatch): + monkeypatch.setenv(COMPATIBILITY_GUARD_RAILS_ENV_VAR, "off") + def test_unknown_utility_name(): with pytest.raises(UnsupportedBinaryError, match=r"'unknown-utility' is not supported"): diff --git a/cuda_pathfinder/tests/test_find_nvidia_headers.py 
b/cuda_pathfinder/tests/test_find_nvidia_headers.py index e28f64d352..596d0d2b29 100644 --- a/cuda_pathfinder/tests/test_find_nvidia_headers.py +++ b/cuda_pathfinder/tests/test_find_nvidia_headers.py @@ -39,6 +39,13 @@ STRICTNESS = os.environ.get("CUDA_PATHFINDER_TEST_FIND_NVIDIA_HEADERS_STRICTNESS", "see_what_works") assert STRICTNESS in ("see_what_works", "all_must_work") +COMPATIBILITY_GUARD_RAILS_ENV_VAR = "CUDA_PATHFINDER_COMPATIBILITY_GUARD_RAILS" + + +@pytest.fixture(autouse=True) +def _disable_process_wide_compatibility_guard_rails(monkeypatch): + monkeypatch.setenv(COMPATIBILITY_GUARD_RAILS_ENV_VAR, "off") + NON_CTK_IMPORTLIB_METADATA_DISTRIBUTIONS_NAMES = { "cusolverMp": r"^nvidia-cusolvermp-.*$", diff --git a/cuda_pathfinder/tests/test_find_static_lib.py b/cuda_pathfinder/tests/test_find_static_lib.py index 2b30aa1201..e5560dcabb 100644 --- a/cuda_pathfinder/tests/test_find_static_lib.py +++ b/cuda_pathfinder/tests/test_find_static_lib.py @@ -78,7 +78,7 @@ def test_locate_static_lib(info_summary_append, libname): @pytest.mark.usefixtures("clear_find_static_lib_cache") def test_locate_static_lib_search_order(monkeypatch, tmp_path): filename = CUDADEVRT_INFO["filename"] - conda_rel_path = CUDADEVRT_INFO["conda_rel_path"] + conda_rel_path = CUDADEVRT_INFO["conda_rel_paths"][0] site_pkg_rel = CUDADEVRT_INFO["site_packages_dirs"][0] site_packages_lib_dir = tmp_path / "site-packages" / Path(site_pkg_rel.replace("/", os.sep)) @@ -117,6 +117,32 @@ def test_locate_static_lib_search_order(monkeypatch, tmp_path): assert located_lib.found_via == "CUDA_PATH" +@pytest.mark.usefixtures("clear_find_static_lib_cache") +def test_locate_static_lib_conda_rel_path_fallback(monkeypatch, tmp_path): + filename = CUDADEVRT_INFO["filename"] + conda_rel_paths = CUDADEVRT_INFO["conda_rel_paths"] + if len(conda_rel_paths) == 1: + monkeypatch.setitem(CUDADEVRT_INFO, "conda_rel_paths", ("missing-first", conda_rel_paths[0])) + conda_rel_paths = CUDADEVRT_INFO["conda_rel_paths"] + + 
conda_prefix = tmp_path / "conda-prefix" + conda_lib_dir = _conda_anchor(conda_prefix) / Path(conda_rel_paths[1]) + conda_path = _make_static_lib_file(conda_lib_dir, filename) + + monkeypatch.setattr( + find_static_lib_module, + "find_sub_dirs_all_sitepackages", + lambda _sub_dir: [], + ) + monkeypatch.setenv("CONDA_PREFIX", str(conda_prefix)) + monkeypatch.delenv("CUDA_HOME", raising=False) + monkeypatch.delenv("CUDA_PATH", raising=False) + + located_lib = locate_static_lib("cudadevrt") + assert located_lib.abs_path == conda_path + assert located_lib.found_via == "conda" + + @pytest.mark.usefixtures("clear_find_static_lib_cache") def test_find_static_lib_not_found_error_includes_cuda_home_directory_listing(monkeypatch, tmp_path): filename = CUDADEVRT_INFO["filename"] diff --git a/cuda_pathfinder/tests/test_load_nvidia_dynamic_lib.py b/cuda_pathfinder/tests/test_load_nvidia_dynamic_lib.py index 401e7dc13f..c43a8f1741 100644 --- a/cuda_pathfinder/tests/test_load_nvidia_dynamic_lib.py +++ b/cuda_pathfinder/tests/test_load_nvidia_dynamic_lib.py @@ -23,6 +23,12 @@ STRICTNESS = os.environ.get("CUDA_PATHFINDER_TEST_LOAD_NVIDIA_DYNAMIC_LIB_STRICTNESS", "see_what_works") assert STRICTNESS in ("see_what_works", "all_must_work") +COMPATIBILITY_GUARD_RAILS_ENV_VAR = "CUDA_PATHFINDER_COMPATIBILITY_GUARD_RAILS" + + +@pytest.fixture(autouse=True) +def _disable_process_wide_compatibility_guard_rails(monkeypatch): + monkeypatch.setenv(COMPATIBILITY_GUARD_RAILS_ENV_VAR, "off") def test_supported_libnames_linux_sonames_consistency(): diff --git a/cuda_pathfinder/tests/test_utils_driver_info.py b/cuda_pathfinder/tests/test_utils_driver_info.py index 21948dadaf..99af76a69b 100644 --- a/cuda_pathfinder/tests/test_utils_driver_info.py +++ b/cuda_pathfinder/tests/test_utils_driver_info.py @@ -73,6 +73,17 @@ def test_query_driver_cuda_version_returns_parsed_dataclass(monkeypatch): ) +def test_driver_cuda_version_from_encoded_returns_subclass_instance(): + version = 
driver_info.DriverCudaVersion.from_encoded(12080) + + assert version == driver_info.DriverCudaVersion( + encoded=12080, + major=12, + minor=8, + ) + assert type(version) is driver_info.DriverCudaVersion + + def test_query_driver_cuda_version_wraps_internal_failures(monkeypatch): root_cause = RuntimeError("low-level query failed") diff --git a/cuda_pathfinder/tests/test_utils_toolkit_info.py b/cuda_pathfinder/tests/test_utils_toolkit_info.py new file mode 100644 index 0000000000..a62db6b960 --- /dev/null +++ b/cuda_pathfinder/tests/test_utils_toolkit_info.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from cuda.pathfinder._utils import toolkit_info + + +@pytest.fixture(autouse=True) +def _clear_cuda_header_version_cache(): + toolkit_info.read_cuda_header_version.cache_clear() + yield + toolkit_info.read_cuda_header_version.cache_clear() + + +def test_encoded_cuda_version_from_encoded_decodes_major_minor(): + assert toolkit_info.EncodedCudaVersion.from_encoded(13020) == toolkit_info.EncodedCudaVersion( + encoded=13020, + major=13, + minor=2, + ) + + +def test_encoded_cuda_version_from_encoded_accepts_decimal_string(): + assert toolkit_info.EncodedCudaVersion.from_encoded("13020") == toolkit_info.EncodedCudaVersion( + encoded=13020, + major=13, + minor=2, + ) + + +def test_encoded_cuda_version_from_encoded_raises_helpful_error_for_invalid_string(): + with pytest.raises( + ValueError, + match=r"EncodedCudaVersion\.from_encoded\(\) expected an integer or decimal string, got '13\.2'", + ): + toolkit_info.EncodedCudaVersion.from_encoded("13.2") + + +@pytest.mark.parametrize("encoded", [-1, "-1"]) +def test_encoded_cuda_version_from_encoded_rejects_negative_values(encoded): + with pytest.raises( + ValueError, + match=r"EncodedCudaVersion\.from_encoded\(\) expected a non-negative encoded CUDA version, got -1", + ): + 
toolkit_info.EncodedCudaVersion.from_encoded(encoded) + + +def test_parse_cuda_header_version_returns_parsed_dataclass(): + header_text = """ + #ifndef CUDA_H + #define CUDA_H + #define CUDA_VERSION 13020 + #endif + """ + + assert toolkit_info.parse_cuda_header_version(header_text) == toolkit_info.CudaToolkitVersion( + encoded=13020, + major=13, + minor=2, + ) + + +def test_cuda_toolkit_version_from_encoded_returns_subclass_instance(): + version = toolkit_info.CudaToolkitVersion.from_encoded(12090) + + assert version == toolkit_info.CudaToolkitVersion( + encoded=12090, + major=12, + minor=9, + ) + assert type(version) is toolkit_info.CudaToolkitVersion + + +def test_parse_cuda_header_version_returns_none_when_macro_is_missing(): + header_text = """ + #ifndef CUDA_H + #define CUDA_H + #define CUDA_API_PER_THREAD_DEFAULT_STREAM 1 + #endif + """ + + assert toolkit_info.parse_cuda_header_version(header_text) is None + + +def test_read_cuda_header_version_reads_file_and_returns_parsed_dataclass(tmp_path): + cuda_h_path = tmp_path / "cuda.h" + cuda_h_path.write_text( + """ + #ifndef CUDA_H + #define CUDA_H + #define CUDA_VERSION 12090 /* CUDA 12.9 */ + #endif + """, + encoding="utf-8", + ) + + assert toolkit_info.read_cuda_header_version(str(cuda_h_path)) == toolkit_info.CudaToolkitVersion( + encoded=12090, + major=12, + minor=9, + ) + + +def test_read_cuda_header_version_tolerates_non_utf8_bytes(tmp_path): + cuda_h_path = tmp_path / "cuda.h" + cuda_h_path.write_bytes( + b"#ifndef CUDA_H\n" + b"#define CUDA_H\n" + b"\xff\xfe invalid bytes in comment or banner\n" + b"#define CUDA_VERSION 12080\n" + b"#endif\n" + ) + + assert toolkit_info.read_cuda_header_version(str(cuda_h_path)) == toolkit_info.CudaToolkitVersion( + encoded=12080, + major=12, + minor=8, + ) + + +def test_read_cuda_header_version_wraps_parse_failures(tmp_path): + cuda_h_path = tmp_path / "cuda.h" + cuda_h_path.write_text( + """ + #ifndef CUDA_H + #define CUDA_H + #endif + """, + encoding="utf-8", + ) + + 
with pytest.raises( + toolkit_info.ReadCudaHeaderVersionError, + match="Failed to read the CUDA Toolkit version from cuda.h", + ) as exc_info: + toolkit_info.read_cuda_header_version(str(cuda_h_path)) + + assert isinstance(exc_info.value.__cause__, RuntimeError) + assert "does not define CUDA_VERSION" in str(exc_info.value.__cause__) diff --git a/toolshed/conda_create_for_pathfinder_testing.ps1 b/toolshed/conda_create_for_pathfinder_testing.ps1 index 115720f6e5..1c0b2999ff 100644 --- a/toolshed/conda_create_for_pathfinder_testing.ps1 +++ b/toolshed/conda_create_for_pathfinder_testing.ps1 @@ -7,22 +7,30 @@ param( ) $ErrorActionPreference = "Stop" +Set-StrictMode -Version Latest + +$cudaMajor = $CudaVersion.Split(".", 2)[0] +switch ($cudaMajor) { + "12" { $pythonVersion = "3.12" } + "13" { $pythonVersion = "3.14" } + default { + throw "Unsupported CUDA major version for this helper: $cudaMajor. Expected a 12.x or 13.x toolkit version." + } +} & "$env:CONDA_EXE" "shell.powershell" "hook" | Out-String | Invoke-Expression -conda create --yes -n "pathfinder_testing_cu$CudaVersion" python=3.13 "cuda-toolkit=$CudaVersion" +conda create --yes -n "pathfinder_testing_cu$CudaVersion" "python=$pythonVersion" "cuda-toolkit=$CudaVersion" conda activate "pathfinder_testing_cu$CudaVersion" +# Keep this list aligned with the Windows-installable subset of +# cuda_pathfinder/pyproject.toml. $cpkgs = @( "cusparselt-dev", "cutensor", - "libcublasmp-dev", + "cutlass", "libcudss-dev", - "libcufftmp-dev", - "libmathdx-dev", - "libnvshmem3", - "libnvshmem-dev", - "libnvpl-fft-dev" + "libmathdx-dev" ) foreach ($cpkg in $cpkgs) { diff --git a/toolshed/conda_create_for_pathfinder_testing.sh b/toolshed/conda_create_for_pathfinder_testing.sh index 1ed57e6765..8674bb1ed0 100755 --- a/toolshed/conda_create_for_pathfinder_testing.sh +++ b/toolshed/conda_create_for_pathfinder_testing.sh @@ -3,26 +3,63 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. # SPDX-License-Identifier: Apache-2.0 +set -euo pipefail + if [[ $# -ne 1 ]]; then echo "Usage: $(basename "$0") ctk-major-minor-patch" 1>&2 exit 1 fi +cuda_version="$1" +cuda_major="${cuda_version%%.*}" +uname_m="$(uname -m)" +case "$cuda_major" in + 12) + python_version=3.12 + ;; + 13) + python_version=3.14 + ;; + *) + echo "Unsupported CUDA major version for this helper: $cuda_major" 1>&2 + echo "Expected a 12.x or 13.x toolkit version." 1>&2 + exit 1 + ;; +esac + eval "$(conda shell.bash hook)" -conda create --yes -n "pathfinder_testing_cu$1" python=3.13 cuda-toolkit="$1" -conda activate "pathfinder_testing_cu$1" - -for cpkg in \ - cusparselt-dev \ - cutensor \ - libcublasmp-dev \ - libcudss-dev \ - libcufftmp-dev \ - libmathdx-dev \ - libnvshmem3 \ - libnvshmem-dev \ - libnvpl-fft-dev; do +conda create --yes -n "pathfinder_testing_cu$cuda_version" "python=$python_version" cuda-toolkit="$cuda_version" +set +u +conda activate "pathfinder_testing_cu$cuda_version" +set -u + +# Keep this list aligned with the Linux-installable subset of +# cuda_pathfinder/pyproject.toml. +cpkgs=( + "cusparselt-dev" + "cutensor" + "cutlass" + "libcublasmp-dev" + "libcudss-dev" + "libcufftmp-dev" + "libcusolvermp-dev" + "libmathdx-dev" + "libnvshmem3" + "libnvshmem-dev" +) + +# Keep the conda environment aligned with platform-scoped pyproject groups. +if [[ "$uname_m" == "aarch64" ]]; then + cpkgs+=("libnvpl-fft-dev") + if [[ "$cuda_major" == "13" ]]; then + cpkgs+=("libcudla-dev") + fi +fi + +for cpkg in "${cpkgs[@]}"; do echo "CONDA INSTALL: $cpkg" + set +u conda install -y -c conda-forge "$cpkg" + set -u done