diff --git a/pyproject.toml b/pyproject.toml index dd7bfd9..d886093 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,6 @@ disallow_untyped_defs = true exclude = [".venv/", "docs/"] [[tool.mypy.overrides]] -module = "tests.*" disallow_untyped_defs = false [tool.pytest.ini_options] diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index a3c5152..0000000 --- a/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Unit tests for muse2_data_analysis.""" diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 0000000..dd4ce39 --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,69 @@ +"""Tests for helpers.py.""" + +from pathlib import Path +from unittest.mock import Mock, patch + +from muse2_data_analysis import helpers + + +@patch("muse2_data_analysis.helpers.run_muse2") +def test_get_example_input_dir_returns_existing_directory( + run_muse2_mock: Mock, tmp_path: Path +) -> None: + """Check that get_example_input_dir returns an existing input directory.""" + input_dir = tmp_path / "input" + input_dir.mkdir() + + with patch.object(helpers, "_INPUT_DIR", input_dir): + assert helpers.get_example_input_dir() == input_dir + + run_muse2_mock.assert_not_called() + + +@patch("muse2_data_analysis.helpers.run_muse2") +def test_get_example_input_dir_extracts_example_when_missing( + run_muse2_mock: Mock, tmp_path: Path +) -> None: + """Check that get_example_input_dir extracts the example when missing.""" + input_dir = tmp_path / "input" + + with patch.object(helpers, "_INPUT_DIR", input_dir): + assert helpers.get_example_input_dir() == input_dir + + run_muse2_mock.assert_called_once_with( + "example", "extract", helpers.EXAMPLE_NAME, str(input_dir) + ) + + +@patch("muse2_data_analysis.helpers.run_muse2") +def test_get_example_output_dir_returns_existing_directory( + run_muse2_mock: Mock, tmp_path: Path +) -> None: + """Check that get_example_output_dir returns an existing output directory.""" + output_dir = tmp_path / "output" + output_dir.mkdir() + + with patch.object(helpers, "_OUTPUT_DIR", output_dir): + assert helpers.get_example_output_dir() == output_dir + + run_muse2_mock.assert_not_called() + + +@patch("muse2_data_analysis.helpers.run_muse2") +@patch("muse2_data_analysis.helpers.get_example_input_dir") +def test_get_example_output_dir_runs_example_when_missing( + get_example_input_dir_mock: Mock, run_muse2_mock: Mock, tmp_path: Path +) -> None: + """Check that get_example_output_dir ensures input exists and runs the model.""" + input_dir = tmp_path / "input" + output_dir = tmp_path / "output" + get_example_input_dir_mock.return_value = input_dir + + with patch.object(helpers, "_INPUT_DIR", input_dir): + with patch.object(helpers, "_OUTPUT_DIR", output_dir): + assert helpers.get_example_output_dir() == output_dir + + get_example_input_dir_mock.assert_called_once_with() + run_muse2_mock.assert_called_once_with( + "run", str(input_dir), "--output-dir", str(output_dir) + ) diff --git a/tests/test_muse2.py b/tests/test_muse2.py new file mode 100644 index 0000000..3d89527 --- /dev/null +++ b/tests/test_muse2.py @@ -0,0 +1,101 @@ +"""Tests for muse2.py.""" + +import subprocess as sp +from collections.abc import Generator +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest + +from muse2_data_analysis import muse2 + + +@pytest.fixture(autouse=True) +def reset_muse2_path() -> Generator[None]: + """Reset the cached MUSE2 path before and after each test. + + Note this won't work if tests are run in parallel. + """ + previous = muse2._muse2_path + muse2._muse2_path = None + yield + muse2._muse2_path = previous + + +@patch("muse2_data_analysis.muse2.shutil.which") +@patch("muse2_data_analysis.muse2.os.getenv") +def test_find_muse2_uses_environment_variable( + getenv_mock: Mock, which_mock: Mock +) -> None: + """Check that find_muse2 prefers MUSE2_PATH over PATH lookup.""" + getenv_mock.return_value = "/tmp/muse2-from-env" + + assert muse2.find_muse2() == Path("/tmp/muse2-from-env") + which_mock.assert_not_called() + + +@patch("muse2_data_analysis.muse2.shutil.which") +@patch("muse2_data_analysis.muse2.os.getenv") +def test_find_muse2_falls_back_to_path_lookup( + getenv_mock: Mock, which_mock: Mock +) -> None: + """Check that find_muse2 uses PATH lookup when no env var is set.""" + getenv_mock.return_value = None + which_mock.return_value = "/tmp/muse2-from-path" + + assert muse2.find_muse2() == Path("/tmp/muse2-from-path") + + +@patch("muse2_data_analysis.muse2.shutil.which") +@patch("muse2_data_analysis.muse2.os.getenv") +def test_find_muse2_raises_when_binary_is_not_available( + getenv_mock: Mock, which_mock: Mock +) -> None: + """Check that find_muse2 raises if no muse2 binary can be found.""" + getenv_mock.return_value = None + which_mock.return_value = None + + with pytest.raises(RuntimeError, match="Could not find path to muse2"): + muse2.find_muse2() + + +@patch("muse2_data_analysis.muse2.sp.run") +@patch("muse2_data_analysis.muse2.find_muse2") +def test_run_muse2_returns_stdout(find_muse2_mock: Mock, run_mock: Mock) -> None: + """Check that run_muse2 returns combined stdout/stderr output.""" + find_muse2_mock.return_value = Path("/tmp/muse2") + run_mock.return_value = sp.CompletedProcess( + args=(Path("/tmp/muse2"), "--version"), + returncode=0, + stdout="muse2 version output", + ) + + assert muse2.run_muse2("--version") == "muse2 version output" + run_mock.assert_called_once_with( + (Path("/tmp/muse2"), "--version"), + text=True, + stdout=sp.PIPE, + stderr=sp.STDOUT, + env={ + **muse2.os.environ, + "MUSE2_USE_DEFAULT_SETTINGS": "1", + }, + check=True, + ) + + +@patch("muse2_data_analysis.muse2.sp.run") +@patch("muse2_data_analysis.muse2.find_muse2") +def test_run_muse2_raises_runtime_error_on_subprocess_failure( + find_muse2_mock: Mock, run_mock: Mock +) -> None: + """Check that run_muse2 re-raises subprocess failures as RuntimeError.""" + find_muse2_mock.return_value = Path("/tmp/muse2") + run_mock.side_effect = sp.CalledProcessError( + 1, + (Path("/tmp/muse2"), "bad-arg"), + output="combined output", + ) + + with pytest.raises(RuntimeError, match="Error running muse2: combined output"): + muse2.run_muse2("bad-arg") diff --git a/tests/test_muse2_data_analysis.py b/tests/test_muse2_data_analysis.py deleted file mode 100644 index de7943b..0000000 --- a/tests/test_muse2_data_analysis.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Tests for the main module.""" - -from muse2_data_analysis import __version__ - - -def test_version(): - """Check that the version is acceptable.""" - assert isinstance(__version__, str)