Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions cuda_pathfinder/cuda/pathfinder/_utils/driver_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import ctypes
import functools
from collections.abc import Callable
from dataclasses import dataclass

from cuda.pathfinder._dynamic_libs.load_nvidia_dynamic_lib import (
load_nvidia_dynamic_lib as _load_nvidia_dynamic_lib,
)
from cuda.pathfinder._utils.platform_aware import IS_WINDOWS


class QueryDriverCudaVersionError(RuntimeError):
"""Raised when ``query_driver_cuda_version()`` cannot determine the CUDA driver version."""


@dataclass(frozen=True, slots=True)
class DriverCudaVersion:
"""
CUDA-facing driver version reported by ``cuDriverGetVersion()``.

The name ``DriverCudaVersion`` is intentionally specific: this dataclass
models the version shown as ``CUDA Version`` in ``nvidia-smi``, not the
graphics driver release shown as ``Driver Version``. More specifically,
it reflects the CUDA user-mode driver (UMD) interface version reported by
``cuDriverGetVersion()``, not the kernel-mode driver (KMD) package
version.

Example ``nvidia-smi`` output::

+---------------------------------------------------------------------+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
+---------------------------------------------------------------------+

For the example above, ``DriverCudaVersion(encoded=13020, major=13,
minor=2)`` corresponds to ``CUDA Version: 13.2``. It does not correspond
to ``Driver Version: 595.58.03``.
"""

encoded: int
major: int
minor: int


@functools.cache
def query_driver_cuda_version() -> DriverCudaVersion:
"""Return the CUDA driver version parsed into its major/minor components."""
try:
encoded = _query_driver_cuda_version_int()
return DriverCudaVersion(
encoded=encoded,
major=encoded // 1000,
minor=(encoded % 1000) // 10,
)
except Exception as exc:
raise QueryDriverCudaVersionError("Failed to query the CUDA driver version.") from exc


def _query_driver_cuda_version_int() -> int:
"""Return the encoded CUDA driver version from ``cuDriverGetVersion()``."""
loaded_cuda = _load_nvidia_dynamic_lib("cuda")
if IS_WINDOWS:
# `ctypes.WinDLL` exists on Windows at runtime. The ignore is only for
# Linux mypy runs, where the platform stubs do not define that attribute.
loader_cls: Callable[[str], ctypes.CDLL] = ctypes.WinDLL # type: ignore[attr-defined]
else:
loader_cls = ctypes.CDLL
driver_lib = loader_cls(loaded_cuda.abs_path)
cu_driver_get_version = driver_lib.cuDriverGetVersion
cu_driver_get_version.argtypes = [ctypes.POINTER(ctypes.c_int)]
cu_driver_get_version.restype = ctypes.c_int
version = ctypes.c_int()
status = cu_driver_get_version(ctypes.byref(version))
if status != 0:
raise RuntimeError(f"Failed to query CUDA driver version via cuDriverGetVersion() (status={status}).")
return version.value
21 changes: 21 additions & 0 deletions cuda_pathfinder/tests/test_driver_lib_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
_load_lib_no_cache,
)
from cuda.pathfinder._dynamic_libs.subprocess_protocol import STATUS_NOT_FOUND, parse_dynamic_lib_subprocess_payload
from cuda.pathfinder._utils import driver_info
from cuda.pathfinder._utils.platform_aware import IS_WINDOWS, quote_for_shell

STRICTNESS = os.environ.get("CUDA_PATHFINDER_TEST_LOAD_NVIDIA_DYNAMIC_LIB_STRICTNESS", "see_what_works")
Expand Down Expand Up @@ -157,3 +158,23 @@ def raise_child_process_failed():
assert abs_path is not None
info_summary_append(f"abs_path={quote_for_shell(abs_path)}")
assert os.path.isfile(abs_path)


def test_real_query_driver_cuda_version(info_summary_append):
driver_info._load_nvidia_dynamic_lib.cache_clear()
driver_info.query_driver_cuda_version.cache_clear()
try:
version = driver_info.query_driver_cuda_version()
except driver_info.QueryDriverCudaVersionError as exc:
if STRICTNESS == "all_must_work":
raise
info_summary_append(f"driver version unavailable: {exc.__class__.__name__}: {exc}")
return
finally:
driver_info._load_nvidia_dynamic_lib.cache_clear()
driver_info.query_driver_cuda_version.cache_clear()

info_summary_append(f"driver_version={version.major}.{version.minor} (encoded={version.encoded})")
assert version.encoded > 0
assert version.major == version.encoded // 1000
assert version.minor == (version.encoded % 1000) // 10
101 changes: 101 additions & 0 deletions cuda_pathfinder/tests/test_utils_driver_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import ctypes

import pytest

from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL
from cuda.pathfinder._utils import driver_info


@pytest.fixture(autouse=True)
def _clear_driver_cuda_version_query_cache():
driver_info.query_driver_cuda_version.cache_clear()
yield
driver_info.query_driver_cuda_version.cache_clear()


class _FakeCuDriverGetVersion:
def __init__(self, *, status: int, version: int):
self.argtypes = None
self.restype = None
self._status = status
self._version = version

def __call__(self, version_ptr) -> int:
ctypes.cast(version_ptr, ctypes.POINTER(ctypes.c_int)).contents.value = self._version
return self._status


class _FakeDriverLib:
def __init__(self, *, status: int, version: int):
self.cuDriverGetVersion = _FakeCuDriverGetVersion(status=status, version=version)


def _loaded_cuda(abs_path: str) -> LoadedDL:
return LoadedDL(
abs_path=abs_path,
was_already_loaded_from_elsewhere=False,
_handle_uint=0xBEEF,
found_via="system-search",
)


def test_query_driver_cuda_version_uses_windll_on_windows(monkeypatch):
fake_driver_lib = _FakeDriverLib(status=0, version=12080)
loaded_paths: list[str] = []

monkeypatch.setattr(driver_info, "IS_WINDOWS", True)
monkeypatch.setattr(
driver_info,
"_load_nvidia_dynamic_lib",
lambda _libname: _loaded_cuda(r"C:\Windows\System32\nvcuda.dll"),
)

def fake_windll(abs_path: str):
loaded_paths.append(abs_path)
return fake_driver_lib

monkeypatch.setattr(driver_info.ctypes, "WinDLL", fake_windll, raising=False)

assert driver_info._query_driver_cuda_version_int() == 12080
assert loaded_paths == [r"C:\Windows\System32\nvcuda.dll"]


def test_query_driver_cuda_version_returns_parsed_dataclass(monkeypatch):
monkeypatch.setattr(driver_info, "_query_driver_cuda_version_int", lambda: 12080)

assert driver_info.query_driver_cuda_version() == driver_info.DriverCudaVersion(
encoded=12080,
major=12,
minor=8,
)


def test_query_driver_cuda_version_wraps_internal_failures(monkeypatch):
root_cause = RuntimeError("low-level query failed")

def fail_query_driver_cuda_version_int() -> int:
raise root_cause

monkeypatch.setattr(driver_info, "_query_driver_cuda_version_int", fail_query_driver_cuda_version_int)

with pytest.raises(
driver_info.QueryDriverCudaVersionError,
match="Failed to query the CUDA driver version",
) as exc_info:
driver_info.query_driver_cuda_version()

assert exc_info.value.__cause__ is root_cause


def test_query_driver_cuda_version_int_raises_when_cuda_call_fails(monkeypatch):
fake_driver_lib = _FakeDriverLib(status=1, version=0)

monkeypatch.setattr(driver_info, "IS_WINDOWS", False)
monkeypatch.setattr(driver_info, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_cuda("/usr/lib/libcuda.so.1"))
monkeypatch.setattr(driver_info.ctypes, "CDLL", lambda _abs_path: fake_driver_lib)

with pytest.raises(RuntimeError, match=r"cuDriverGetVersion\(\) \(status=1\)"):
driver_info._query_driver_cuda_version_int()
Loading