From 27436f60540dda9d4c92c74346afe90dd2da573f Mon Sep 17 00:00:00 2001 From: voorhs Date: Mon, 8 Jun 2026 12:25:01 +0300 Subject: [PATCH] =?UTF-8?q?test(types):=20annotate=20tests/pipeline=20(41?= =?UTF-8?q?=E2=86=920=20mypy=20errors)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.7 --- tests/pipeline/test_inference.py | 40 ++++++++++++++--- tests/pipeline/test_optimization.py | 32 ++++++++++---- tests/pipeline/test_pipeline_interruption.py | 45 ++++++++++++++------ tests/pipeline/test_presets.py | 36 ++++++++++++---- tests/pipeline/test_validation.py | 11 ++++- 5 files changed, 124 insertions(+), 40 deletions(-) diff --git a/tests/pipeline/test_inference.py b/tests/pipeline/test_inference.py index ad283480..51f1a9d4 100644 --- a/tests/pipeline/test_inference.py +++ b/tests/pipeline/test_inference.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, cast + import pytest from autointent import Pipeline @@ -5,9 +9,16 @@ from autointent.custom_types import NodeType from tests.conftest import apply_test_models, get_search_space, setup_environment +if TYPE_CHECKING: + from pathlib import Path + + from autointent import Dataset + from autointent.generation import Generator + from tests.conftest import TaskType + @pytest.fixture -def project_dir(task_type): +def project_dir(task_type: TaskType) -> Path: return setup_environment() / "test_inference" / task_type @@ -21,7 +32,12 @@ def project_dir(task_type): "description_with_llm", ], ) -def test_inference_from_config(dataset, task_type, project_dir, patch_llm_scorer_generator): +def test_inference_from_config( + dataset: Dataset, + task_type: TaskType, + project_dir: Path, + patch_llm_scorer_generator: Generator, +) -> None: search_space = get_search_space(task_type) pipeline_optimizer = Pipeline.from_search_space(search_space) @@ -69,7 +85,12 @@ def test_inference_from_config(dataset, task_type, project_dir, patch_llm_scorer "description_with_llm", ], ) -def test_inference_on_the_fly(dataset, task_type, project_dir, patch_llm_scorer_generator): +def test_inference_on_the_fly( + dataset: Dataset, + task_type: TaskType, + project_dir: Path, + patch_llm_scorer_generator: Generator, +) -> None: search_space = get_search_space(task_type) pipeline = Pipeline.from_search_space(search_space) @@ -103,7 +124,7 @@ def test_inference_on_the_fly(dataset, task_type, project_dir, patch_llm_scorer_ assert prediction == prediction_v2 -def test_load_with_overrided_params(dataset): +def test_load_with_overrided_params(dataset: Dataset) -> None: project_dir = setup_environment() / "test_inference" / "override" search_space = get_search_space("light") @@ -128,7 +149,11 @@ def test_load_with_overrided_params(dataset): # case 2: rich inference from file system rich_outputs = inference_pipeline.predict_with_metadata(utterances) assert len(rich_outputs.predictions) == len(utterances) - assert inference_pipeline.nodes[NodeType.scoring].module._embedder.config.tokenizer_config.max_length == 8 + # The scoring module is concretely a LinearScorer (or sibling) here with a + # private _embedder attribute; reach through Any since BaseModule does not + # declare this internal field. + inference_scoring_module: Any = cast("Any", inference_pipeline.nodes[NodeType.scoring]).module + assert inference_scoring_module._embedder.config.tokenizer_config.max_length == 8 del inference_pipeline # case 3: dump and then load pipeline @@ -141,10 +166,11 @@ def test_load_with_overrided_params(dataset): ) prediction_v2 = loaded_pipe.predict(utterances) assert prediction == prediction_v2 - assert loaded_pipe.nodes[NodeType.scoring].module._embedder.config.tokenizer_config.max_length == 8 + loaded_scoring_module: Any = cast("Any", loaded_pipe.nodes[NodeType.scoring]).module + assert loaded_scoring_module._embedder.config.tokenizer_config.max_length == 8 -def test_no_saving(dataset): +def test_no_saving(dataset: Dataset) -> None: project_dir = setup_environment() / "test_inference" / "no_saving" search_space = get_search_space("light") diff --git a/tests/pipeline/test_optimization.py b/tests/pipeline/test_optimization.py index 652e71cf..85159276 100644 --- a/tests/pipeline/test_optimization.py +++ b/tests/pipeline/test_optimization.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import importlib.resources as ires +from typing import TYPE_CHECKING, cast import pytest @@ -6,6 +9,14 @@ from autointent.configs import DataConfig, HPOConfig, LoggingConfig from tests.conftest import apply_test_models, get_search_space, setup_environment +if TYPE_CHECKING: + from pathlib import Path + + from autointent import Dataset + from autointent.custom_types import SamplerType + from autointent.generation import Generator + from tests.conftest import TaskType + @pytest.mark.parametrize( ("data_config", "refit_after"), @@ -20,7 +31,7 @@ (DataConfig(scheme="cv", separation_ratio=0.5), True), ], ) -def test_with_regex(dataset, data_config, refit_after): +def test_with_regex(dataset: Dataset, data_config: DataConfig, refit_after: bool) -> None: project_dir = setup_environment() search_space = get_search_space("regex") @@ -33,7 +44,7 @@ def test_with_regex(dataset, data_config, refit_after): pipeline_optimizer.fit(dataset, refit_after=refit_after) -def test_no_node_separation(dataset_no_oos): +def test_no_node_separation(dataset_no_oos: Dataset) -> None: project_dir = setup_environment() search_space = get_search_space("light") @@ -46,8 +57,11 @@ def test_no_node_separation(dataset_no_oos): pipeline_optimizer.fit(dataset_no_oos, refit_after=False) -def test_full_config(dataset_no_oos): - config_path = ires.files("tests.assets.configs").joinpath("full_training.yaml") +def test_full_config(dataset_no_oos: Dataset) -> None: + # tests.assets.configs is a regular package, so importlib.resources.files + # returns a concrete Path; cast asserts that to mypy without changing + # behavior (matches the pattern used in tests/conftest.py). + config_path = cast("Path", ires.files("tests.assets.configs").joinpath("full_training.yaml")) pipeline_optimizer = Pipeline.from_optimization_config(config_path) apply_test_models(pipeline_optimizer) pipeline_optimizer.fit(dataset_no_oos, refit_after=False) @@ -57,7 +71,7 @@ def test_full_config(dataset_no_oos): "sampler", ["tpe", "random"], ) -def test_bayes(dataset, sampler): +def test_bayes(dataset: Dataset, sampler: SamplerType) -> None: project_dir = setup_environment() search_space = get_search_space("optuna") @@ -80,7 +94,7 @@ def test_bayes(dataset, sampler): "description_with_llm", ], ) -def test_cv(dataset, task_type, patch_llm_scorer_generator): +def test_cv(dataset: Dataset, task_type: TaskType, patch_llm_scorer_generator: Generator) -> None: project_dir = setup_environment() search_space = get_search_space(task_type) @@ -108,7 +122,7 @@ def test_cv(dataset, task_type, patch_llm_scorer_generator): "description_with_llm", ], ) -def test_no_context_optimization(dataset, task_type, patch_llm_scorer_generator): +def test_no_context_optimization(dataset: Dataset, task_type: TaskType, patch_llm_scorer_generator: Generator) -> None: project_dir = setup_environment() search_space = get_search_space(task_type) @@ -134,7 +148,7 @@ def test_no_context_optimization(dataset, task_type, patch_llm_scorer_generator) "description_with_llm", ], ) -def test_dump_modules(dataset, task_type, patch_llm_scorer_generator): +def test_dump_modules(dataset: Dataset, task_type: TaskType, patch_llm_scorer_generator: Generator) -> None: project_dir = setup_environment() search_space = get_search_space(task_type) @@ -156,7 +170,7 @@ def test_dump_modules(dataset, task_type, patch_llm_scorer_generator): "task_type", ["multiclass", "multilabel"], ) -def test_optimization_validation_metric_names(dataset, task_type): +def test_optimization_validation_metric_names(dataset: Dataset, task_type: TaskType) -> None: search_space = get_search_space(task_type) pipeline_optimizer = Pipeline.from_search_space(search_space) diff --git a/tests/pipeline/test_pipeline_interruption.py b/tests/pipeline/test_pipeline_interruption.py index c210a751..6906ce4d 100644 --- a/tests/pipeline/test_pipeline_interruption.py +++ b/tests/pipeline/test_pipeline_interruption.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import logging import sqlite3 +from typing import TYPE_CHECKING, Any, cast from unittest.mock import patch import pytest @@ -9,6 +12,14 @@ from autointent.custom_types import NodeType from tests.conftest import get_search_space +if TYPE_CHECKING: + from pathlib import Path + + from optuna.trial import Trial + + from autointent import Dataset + from autointent.nodes import NodeOptimizer + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -17,15 +28,15 @@ class InterruptAfterNCallsError(Exception): """Exception to simulate interruption.""" -def count_trials_in_database(db_path): +def count_trials_in_database(db_path: Path) -> int: """Count the number of trials in an Optuna SQLite database.""" with sqlite3.connect(db_path) as conn: cursor = conn.cursor() cursor.execute("SELECT COUNT(*) FROM trials") - return cursor.fetchone()[0] + return cast("int", cursor.fetchone()[0]) -def get_completed_trial_numbers(db_path): +def get_completed_trial_numbers(db_path: Path) -> set[int]: """Get the trial numbers that have been completed.""" with sqlite3.connect(db_path) as conn: cursor = conn.cursor() @@ -33,11 +44,11 @@ def get_completed_trial_numbers(db_path): return {row[0] for row in cursor.fetchall()} -def test_pipeline_with_exception_resume(dataset_no_oos, tmp_path): +def test_pipeline_with_exception_resume(dataset_no_oos: Dataset, tmp_path: Path) -> None: """Test that pipeline can resume after an exception and continues from where it left off.""" project_dir = tmp_path search_space = get_search_space("optuna") - first_run_trials = set() + first_run_trials: set[int] = set() call_count = 0 max_calls_before_exception = 2 @@ -49,9 +60,12 @@ def test_pipeline_with_exception_resume(dataset_no_oos, tmp_path): pipeline_optimizer.set_config(logging_config) pipeline_optimizer.set_config(DataConfig(scheme="ho", separation_ratio=None)) pipeline_optimizer.set_config(HPOConfig(sampler="random")) - original_objective = pipeline_optimizer.nodes[NodeType.scoring].objective + # Pipeline.from_search_space() builds NodeOptimizer instances; the union + # with InferenceNode does not expose `objective`. Cast each accessed node. + scoring_optimizer = cast("NodeOptimizer", pipeline_optimizer.nodes[NodeType.scoring]) + original_objective = scoring_optimizer.objective - def exception_raising_objective(trial, *args, **kwargs): + def exception_raising_objective(trial: Trial, *args: Any, **kwargs: Any) -> Any: nonlocal call_count call_count += 1 @@ -66,7 +80,7 @@ def exception_raising_objective(trial, *args, **kwargs): # Replace the objective with our exception-raising version # pipeline_optimizer.nodes[NodeType.scoring].objective = exception_raising_objective with ( - patch.object(pipeline_optimizer.nodes[NodeType.scoring], "objective", side_effect=exception_raising_objective), + patch.object(scoring_optimizer, "objective", side_effect=exception_raising_objective), pytest.raises(InterruptAfterNCallsError), ): pipeline_optimizer.fit(dataset_no_oos, refit_after=False) @@ -78,7 +92,7 @@ def exception_raising_objective(trial, *args, **kwargs): assert optuna_storage_dir.exists(), "Optuna storage directory not created in first run" db_files_first_run = list(optuna_storage_dir.glob("*.db")) - trials_after_first_run = {} + trials_after_first_run: dict[str, int] = {} for db_file in db_files_first_run: trials_after_first_run[db_file.name] = count_trials_in_database(db_file) @@ -89,10 +103,11 @@ def exception_raising_objective(trial, *args, **kwargs): pipeline_optimizer.set_config(HPOConfig(sampler="random")) # Add tracking for second run to see which trials are executed - second_run_trials = set() - original_objective2 = pipeline_optimizer.nodes[NodeType.scoring].objective + second_run_trials: set[int] = set() + scoring_optimizer2 = cast("NodeOptimizer", pipeline_optimizer.nodes[NodeType.scoring]) + original_objective2 = scoring_optimizer2.objective - def tracking_objective2(trial, *args, **kwargs): + def tracking_objective2(trial: Trial, *args: Any, **kwargs: Any) -> Any: msg = f"Second run: Processing trial #{trial.number}" logger.info(msg) second_run_trials.add(trial.number) @@ -100,7 +115,7 @@ def tracking_objective2(trial, *args, **kwargs): pipeline_optimizer.set_config(logging_config) pipeline_optimizer.set_config(DataConfig(scheme="ho", separation_ratio=None)) - with patch.object(pipeline_optimizer.nodes[NodeType.scoring], "objective", side_effect=tracking_objective2): + with patch.object(scoring_optimizer2, "objective", side_effect=tracking_objective2): # This run should complete without exceptions pipeline_optimizer.fit(dataset_no_oos, refit_after=False) @@ -124,7 +139,9 @@ def tracking_objective2(trial, *args, **kwargs): ) -def test_resuming_with_memory_storage_warning(dataset_no_oos, tmp_path, caplog): +def test_resuming_with_memory_storage_warning( + dataset_no_oos: Dataset, tmp_path: Path, caplog: pytest.LogCaptureFixture +) -> None: """Test that a warning is issued when trying to resume with memory storage.""" project_dir = tmp_path search_space = get_search_space("optuna") diff --git a/tests/pipeline/test_presets.py b/tests/pipeline/test_presets.py index f0af640a..7cb298aa 100644 --- a/tests/pipeline/test_presets.py +++ b/tests/pipeline/test_presets.py @@ -1,9 +1,19 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + import pytest from autointent import Pipeline from autointent.configs import DataConfig, HPOConfig, LoggingConfig from tests.conftest import apply_test_models, setup_environment +if TYPE_CHECKING: + from autointent import Dataset + from autointent.configs import SentenceTransformerEmbeddingConfig + from autointent.generation import Generator + from autointent.nodes import NodeOptimizer + @pytest.mark.parametrize( "preset", @@ -20,10 +30,10 @@ "zero-shot-encoders", ], ) -def test_presets(dataset, preset, patch_llm_scorer_generator): +def test_presets(dataset: Dataset, preset: str, patch_llm_scorer_generator: Generator) -> None: project_dir = setup_environment() - pipeline_optimizer = Pipeline.from_preset(preset) + pipeline_optimizer = Pipeline.from_preset(preset) # type: ignore[arg-type] # reason: parametrize values are runtime strings; mypy can't narrow to the SearchSpacePreset Literal apply_test_models(pipeline_optimizer) pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True)) @@ -33,7 +43,7 @@ def test_presets(dataset, preset, patch_llm_scorer_generator): pipeline_optimizer.fit(dataset, refit_after=False) -def test_apply_test_models_retargets_pipeline_slots(): +def test_apply_test_models_retargets_pipeline_slots() -> None: from autointent import Pipeline from tests.conftest import ( TINY_BERT, @@ -47,12 +57,17 @@ def test_apply_test_models_retargets_pipeline_slots(): # cross-encoder=bge-reranker-v2-m3. apply_test_models(pipeline) - assert pipeline.embedder_config.model_name == TINY_SENTENCE_TRANSFORMER + # apply_test_models() installs a SentenceTransformerEmbeddingConfig (see + # tests.conftest.tiny_sentence_transformer_config); narrow the EmbedderConfig + # union to that concrete subclass to access model_name (BaseEmbedderConfig + # has no model_name). + embedder_config = cast("SentenceTransformerEmbeddingConfig", pipeline.embedder_config) + assert embedder_config.model_name == TINY_SENTENCE_TRANSFORMER assert pipeline.cross_encoder_config.model_name == TINY_CROSS_ENCODER assert pipeline.transformer_config.model_name == TINY_BERT -def test_apply_test_models_rewrites_search_space_bert_entries(): +def test_apply_test_models_rewrites_search_space_bert_entries() -> None: from autointent import Pipeline from tests.conftest import TINY_BERT, apply_test_models @@ -61,10 +76,13 @@ def test_apply_test_models_rewrites_search_space_bert_entries(): # classification_model_config: [{model_name: 'microsoft/deberta-v3-large'}] apply_test_models(pipeline) + # Pipeline.from_preset() returns an optimization-mode Pipeline whose nodes + # are NodeOptimizer; the typed union with InferenceNode does not expose + # modules_search_spaces. Cast each node down to NodeOptimizer to access it. bert_entries = [ entry for node in pipeline.nodes.values() - for entry in node.modules_search_spaces + for entry in cast("NodeOptimizer", node).modules_search_spaces if entry.get("module_name") == "bert" ] assert bert_entries, "transformers-heavy preset must have a bert module entry" @@ -81,7 +99,7 @@ def test_apply_test_models_rewrites_search_space_bert_entries(): ) -def test_apply_test_models_drops_stale_revision_in_search_space(): +def test_apply_test_models_drops_stale_revision_in_search_space() -> None: """When a search-space entry pins model_name AND revision (e.g. catboost in tests/assets/configs/multiclass.yaml), the walker rewrites the model_name but must also drop the now-wrong revision so the @@ -94,7 +112,9 @@ def test_apply_test_models_drops_stale_revision_in_search_space(): apply_test_models(pipeline) for node in pipeline.nodes.values(): - for entry in node.modules_search_spaces: + # Pipeline.from_search_space() returns an optimization-mode Pipeline + # whose nodes are NodeOptimizer (see test_apply_test_models_rewrites_…). + for entry in cast("NodeOptimizer", node).modules_search_spaces: for field in ("classification_model_config", "embedder_config", "cross_encoder_config"): value = entry.get(field) if isinstance(value, list): diff --git a/tests/pipeline/test_validation.py b/tests/pipeline/test_validation.py index ce6640c5..4c759ff6 100644 --- a/tests/pipeline/test_validation.py +++ b/tests/pipeline/test_validation.py @@ -1,9 +1,16 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import pytest from autointent import Pipeline +if TYPE_CHECKING: + from autointent import Dataset + -def test_validate_search_space_multiclass(dataset): +def test_validate_search_space_multiclass(dataset: Dataset) -> None: search_space = [ { "node_type": "decision", @@ -17,7 +24,7 @@ def test_validate_search_space_multiclass(dataset): pipeline_optimizer.validate_modules(dataset, mode="raise") -def test_validate_search_space_multilabel(dataset): +def test_validate_search_space_multilabel(dataset: Dataset) -> None: dataset = dataset.to_multilabel() search_space = [