Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions tests/pipeline/test_inference.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, cast

import pytest

from autointent import Pipeline
from autointent.configs import LoggingConfig, TokenizerConfig, get_default_embedder_config
from autointent.custom_types import NodeType
from tests.conftest import apply_test_models, get_search_space, setup_environment

if TYPE_CHECKING:
from pathlib import Path

from autointent import Dataset
from autointent.generation import Generator
from tests.conftest import TaskType


@pytest.fixture
def project_dir(task_type):
def project_dir(task_type: TaskType) -> Path:
return setup_environment() / "test_inference" / task_type


Expand All @@ -21,7 +32,12 @@ def project_dir(task_type):
"description_with_llm",
],
)
def test_inference_from_config(dataset, task_type, project_dir, patch_llm_scorer_generator):
def test_inference_from_config(
dataset: Dataset,
task_type: TaskType,
project_dir: Path,
patch_llm_scorer_generator: Generator,
) -> None:
search_space = get_search_space(task_type)

pipeline_optimizer = Pipeline.from_search_space(search_space)
Expand Down Expand Up @@ -69,7 +85,12 @@ def test_inference_from_config(dataset, task_type, project_dir, patch_llm_scorer
"description_with_llm",
],
)
def test_inference_on_the_fly(dataset, task_type, project_dir, patch_llm_scorer_generator):
def test_inference_on_the_fly(
dataset: Dataset,
task_type: TaskType,
project_dir: Path,
patch_llm_scorer_generator: Generator,
) -> None:
search_space = get_search_space(task_type)

pipeline = Pipeline.from_search_space(search_space)
Expand Down Expand Up @@ -103,7 +124,7 @@ def test_inference_on_the_fly(dataset, task_type, project_dir, patch_llm_scorer_
assert prediction == prediction_v2


def test_load_with_overrided_params(dataset):
def test_load_with_overrided_params(dataset: Dataset) -> None:
project_dir = setup_environment() / "test_inference" / "override"
search_space = get_search_space("light")

Expand All @@ -128,7 +149,11 @@ def test_load_with_overrided_params(dataset):
# case 2: rich inference from file system
rich_outputs = inference_pipeline.predict_with_metadata(utterances)
assert len(rich_outputs.predictions) == len(utterances)
assert inference_pipeline.nodes[NodeType.scoring].module._embedder.config.tokenizer_config.max_length == 8
# The scoring module is concretely a LinearScorer (or sibling) here with a
# private _embedder attribute; reach through Any since BaseModule does not
# declare this internal field.
inference_scoring_module: Any = cast("Any", inference_pipeline.nodes[NodeType.scoring]).module
assert inference_scoring_module._embedder.config.tokenizer_config.max_length == 8
del inference_pipeline

# case 3: dump and then load pipeline
Expand All @@ -141,10 +166,11 @@ def test_load_with_overrided_params(dataset):
)
prediction_v2 = loaded_pipe.predict(utterances)
assert prediction == prediction_v2
assert loaded_pipe.nodes[NodeType.scoring].module._embedder.config.tokenizer_config.max_length == 8
loaded_scoring_module: Any = cast("Any", loaded_pipe.nodes[NodeType.scoring]).module
assert loaded_scoring_module._embedder.config.tokenizer_config.max_length == 8


def test_no_saving(dataset):
def test_no_saving(dataset: Dataset) -> None:
project_dir = setup_environment() / "test_inference" / "no_saving"
search_space = get_search_space("light")

Expand Down
32 changes: 23 additions & 9 deletions tests/pipeline/test_optimization.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from __future__ import annotations

import importlib.resources as ires
from typing import TYPE_CHECKING, cast

import pytest

from autointent import Pipeline
from autointent.configs import DataConfig, HPOConfig, LoggingConfig
from tests.conftest import apply_test_models, get_search_space, setup_environment

if TYPE_CHECKING:
from pathlib import Path

from autointent import Dataset
from autointent.custom_types import SamplerType
from autointent.generation import Generator
from tests.conftest import TaskType


@pytest.mark.parametrize(
("data_config", "refit_after"),
Expand All @@ -20,7 +31,7 @@
(DataConfig(scheme="cv", separation_ratio=0.5), True),
],
)
def test_with_regex(dataset, data_config, refit_after):
def test_with_regex(dataset: Dataset, data_config: DataConfig, refit_after: bool) -> None:
project_dir = setup_environment()
search_space = get_search_space("regex")

Expand All @@ -33,7 +44,7 @@ def test_with_regex(dataset, data_config, refit_after):
pipeline_optimizer.fit(dataset, refit_after=refit_after)


def test_no_node_separation(dataset_no_oos):
def test_no_node_separation(dataset_no_oos: Dataset) -> None:
project_dir = setup_environment()
search_space = get_search_space("light")

Expand All @@ -46,8 +57,11 @@ def test_no_node_separation(dataset_no_oos):
pipeline_optimizer.fit(dataset_no_oos, refit_after=False)


def test_full_config(dataset_no_oos):
config_path = ires.files("tests.assets.configs").joinpath("full_training.yaml")
def test_full_config(dataset_no_oos: Dataset) -> None:
# tests.assets.configs is a regular package, so importlib.resources.files
# returns a concrete Path; cast asserts that to mypy without changing
# behavior (matches the pattern used in tests/conftest.py).
config_path = cast("Path", ires.files("tests.assets.configs").joinpath("full_training.yaml"))
pipeline_optimizer = Pipeline.from_optimization_config(config_path)
apply_test_models(pipeline_optimizer)
pipeline_optimizer.fit(dataset_no_oos, refit_after=False)
Expand All @@ -57,7 +71,7 @@ def test_full_config(dataset_no_oos):
"sampler",
["tpe", "random"],
)
def test_bayes(dataset, sampler):
def test_bayes(dataset: Dataset, sampler: SamplerType) -> None:
project_dir = setup_environment()
search_space = get_search_space("optuna")

Expand All @@ -80,7 +94,7 @@ def test_bayes(dataset, sampler):
"description_with_llm",
],
)
def test_cv(dataset, task_type, patch_llm_scorer_generator):
def test_cv(dataset: Dataset, task_type: TaskType, patch_llm_scorer_generator: Generator) -> None:
project_dir = setup_environment()
search_space = get_search_space(task_type)

Expand Down Expand Up @@ -108,7 +122,7 @@ def test_cv(dataset, task_type, patch_llm_scorer_generator):
"description_with_llm",
],
)
def test_no_context_optimization(dataset, task_type, patch_llm_scorer_generator):
def test_no_context_optimization(dataset: Dataset, task_type: TaskType, patch_llm_scorer_generator: Generator) -> None:
project_dir = setup_environment()
search_space = get_search_space(task_type)

Expand All @@ -134,7 +148,7 @@ def test_no_context_optimization(dataset, task_type, patch_llm_scorer_generator)
"description_with_llm",
],
)
def test_dump_modules(dataset, task_type, patch_llm_scorer_generator):
def test_dump_modules(dataset: Dataset, task_type: TaskType, patch_llm_scorer_generator: Generator) -> None:
project_dir = setup_environment()
search_space = get_search_space(task_type)

Expand All @@ -156,7 +170,7 @@ def test_dump_modules(dataset, task_type, patch_llm_scorer_generator):
"task_type",
["multiclass", "multilabel"],
)
def test_optimization_validation_metric_names(dataset, task_type):
def test_optimization_validation_metric_names(dataset: Dataset, task_type: TaskType) -> None:
search_space = get_search_space(task_type)

pipeline_optimizer = Pipeline.from_search_space(search_space)
Expand Down
45 changes: 31 additions & 14 deletions tests/pipeline/test_pipeline_interruption.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import logging
import sqlite3
from typing import TYPE_CHECKING, Any, cast
from unittest.mock import patch

import pytest
Expand All @@ -9,6 +12,14 @@
from autointent.custom_types import NodeType
from tests.conftest import get_search_space

if TYPE_CHECKING:
from pathlib import Path

from optuna.trial import Trial

from autointent import Dataset
from autointent.nodes import NodeOptimizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

Expand All @@ -17,27 +28,27 @@ class InterruptAfterNCallsError(Exception):
"""Exception to simulate interruption."""


def count_trials_in_database(db_path):
def count_trials_in_database(db_path: Path) -> int:
"""Count the number of trials in an Optuna SQLite database."""
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM trials")
return cursor.fetchone()[0]
return cast("int", cursor.fetchone()[0])


def get_completed_trial_numbers(db_path):
def get_completed_trial_numbers(db_path: Path) -> set[int]:
"""Get the trial numbers that have been completed."""
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute("SELECT number FROM trials WHERE state = 'COMPLETE'")
return {row[0] for row in cursor.fetchall()}


def test_pipeline_with_exception_resume(dataset_no_oos, tmp_path):
def test_pipeline_with_exception_resume(dataset_no_oos: Dataset, tmp_path: Path) -> None:
"""Test that pipeline can resume after an exception and continues from where it left off."""
project_dir = tmp_path
search_space = get_search_space("optuna")
first_run_trials = set()
first_run_trials: set[int] = set()

call_count = 0
max_calls_before_exception = 2
Expand All @@ -49,9 +60,12 @@ def test_pipeline_with_exception_resume(dataset_no_oos, tmp_path):
pipeline_optimizer.set_config(logging_config)
pipeline_optimizer.set_config(DataConfig(scheme="ho", separation_ratio=None))
pipeline_optimizer.set_config(HPOConfig(sampler="random"))
original_objective = pipeline_optimizer.nodes[NodeType.scoring].objective
# Pipeline.from_search_space() builds NodeOptimizer instances; the union
# with InferenceNode does not expose `objective`. Cast each accessed node.
scoring_optimizer = cast("NodeOptimizer", pipeline_optimizer.nodes[NodeType.scoring])
original_objective = scoring_optimizer.objective

def exception_raising_objective(trial, *args, **kwargs):
def exception_raising_objective(trial: Trial, *args: Any, **kwargs: Any) -> Any:
nonlocal call_count
call_count += 1

Expand All @@ -66,7 +80,7 @@ def exception_raising_objective(trial, *args, **kwargs):
# Replace the objective with our exception-raising version
# pipeline_optimizer.nodes[NodeType.scoring].objective = exception_raising_objective
with (
patch.object(pipeline_optimizer.nodes[NodeType.scoring], "objective", side_effect=exception_raising_objective),
patch.object(scoring_optimizer, "objective", side_effect=exception_raising_objective),
pytest.raises(InterruptAfterNCallsError),
):
pipeline_optimizer.fit(dataset_no_oos, refit_after=False)
Expand All @@ -78,7 +92,7 @@ def exception_raising_objective(trial, *args, **kwargs):
assert optuna_storage_dir.exists(), "Optuna storage directory not created in first run"

db_files_first_run = list(optuna_storage_dir.glob("*.db"))
trials_after_first_run = {}
trials_after_first_run: dict[str, int] = {}
for db_file in db_files_first_run:
trials_after_first_run[db_file.name] = count_trials_in_database(db_file)

Expand All @@ -89,18 +103,19 @@ def exception_raising_objective(trial, *args, **kwargs):
pipeline_optimizer.set_config(HPOConfig(sampler="random"))

# Add tracking for second run to see which trials are executed
second_run_trials = set()
original_objective2 = pipeline_optimizer.nodes[NodeType.scoring].objective
second_run_trials: set[int] = set()
scoring_optimizer2 = cast("NodeOptimizer", pipeline_optimizer.nodes[NodeType.scoring])
original_objective2 = scoring_optimizer2.objective

def tracking_objective2(trial, *args, **kwargs):
def tracking_objective2(trial: Trial, *args: Any, **kwargs: Any) -> Any:
msg = f"Second run: Processing trial #{trial.number}"
logger.info(msg)
second_run_trials.add(trial.number)
return original_objective2(trial, *args, **kwargs)

pipeline_optimizer.set_config(logging_config)
pipeline_optimizer.set_config(DataConfig(scheme="ho", separation_ratio=None))
with patch.object(pipeline_optimizer.nodes[NodeType.scoring], "objective", side_effect=tracking_objective2):
with patch.object(scoring_optimizer2, "objective", side_effect=tracking_objective2):
# This run should complete without exceptions
pipeline_optimizer.fit(dataset_no_oos, refit_after=False)

Expand All @@ -124,7 +139,9 @@ def tracking_objective2(trial, *args, **kwargs):
)


def test_resuming_with_memory_storage_warning(dataset_no_oos, tmp_path, caplog):
def test_resuming_with_memory_storage_warning(
dataset_no_oos: Dataset, tmp_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that a warning is issued when trying to resume with memory storage."""
project_dir = tmp_path
search_space = get_search_space("optuna")
Expand Down
Loading
Loading