diff --git a/.pyrit_conf_example b/.pyrit_conf_example index c45bb390ce..9d9e66305d 100644 --- a/.pyrit_conf_example +++ b/.pyrit_conf_example @@ -111,6 +111,12 @@ operation: op_trash_panda # - /path/to/.env # - /path/to/.env.local +# Max Concurrent Scenario Runs +# ---------------------------- +# Maximum number of scenario runs that can execute concurrently in the backend. +# Applies only to the pyrit_backend server. +max_concurrent_scenario_runs: 3 + # Silent Mode # ----------- # If true, suppresses print statements during initialization. diff --git a/pyrit/backend/models/__init__.py b/pyrit/backend/models/__init__.py index d606d89eb0..aeebe087fd 100644 --- a/pyrit/backend/models/__init__.py +++ b/pyrit/backend/models/__init__.py @@ -48,8 +48,8 @@ PreviewStep, ) from pyrit.backend.models.scenarios import ( - ScenarioListResponse, - ScenarioSummary, + ListRegisteredScenarioResponse, + RegisteredScenario, ) from pyrit.backend.models.targets import ( CreateTargetRequest, @@ -96,8 +96,8 @@ "CreateConverterResponse", "PreviewStep", # Scenarios - "ScenarioListResponse", - "ScenarioSummary", + "ListRegisteredScenarioResponse", + "RegisteredScenario", # Targets "CreateTargetRequest", "TargetInstance", diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 95b98a8f49..64d6269e44 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -95,18 +95,24 @@ class AttackSummary(BaseModel): attack_result_id: str = Field(..., description="Database-assigned unique ID for this AttackResult") conversation_id: str = Field(..., description="Primary conversation of this attack result") - attack_type: str = Field(..., description="Attack class name (e.g., 'CrescendoAttack', 'ManualAttack')") + attack_type: str = Field("", description="Attack class name (e.g., 'CrescendoAttack', 'ManualAttack')") attack_specific_params: Optional[dict[str, Any]] = Field(None, description="Additional attack-specific parameters") target: 
Optional[TargetInfo] = Field(None, description="Target information from the stored identifier") converters: list[str] = Field( default_factory=list, description="Request converter class names applied in this attack" ) + objective: str = Field("", description="Natural-language description of the attacker's objective") outcome: Optional[Literal["undetermined", "success", "failure"]] = Field( None, description="Attack outcome (null if not yet determined)" ) + outcome_reason: str | None = Field(None, description="Reason for the outcome") + last_response: str | None = Field(None, description="Model response from the final turn") last_message_preview: Optional[str] = Field( None, description="Preview of the last message (truncated to ~100 chars)" ) + score_value: str | None = Field(None, description="Score value from the objective scorer") + executed_turns: int = Field(0, ge=0, description="Number of turns executed") + execution_time_ms: int = Field(0, ge=0, description="Execution time in milliseconds") message_count: int = Field(0, description="Total number of messages in the attack") related_conversation_ids: list[str] = Field( default_factory=list, description="IDs of related conversations within this attack" diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index a47e431805..5a8fdf252a 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -5,17 +5,20 @@ Scenario API response models. Scenarios are multi-attack security testing campaigns. These models represent -the metadata about available scenarios (listing), not scenario execution results. +the metadata about available scenarios (listing) and scenario execution (runs). 
""" -from typing import Optional +from datetime import datetime +from enum import Enum +from typing import Any, Optional from pydantic import BaseModel, Field +from pyrit.backend.models.attacks import AttackSummary from pyrit.backend.models.common import PaginationInfo -class ScenarioSummary(BaseModel): +class RegisteredScenario(BaseModel): """Summary of a registered scenario.""" scenario_name: str = Field(..., description="Registry key (e.g., 'foundry.red_team_agent')") @@ -30,8 +33,103 @@ class ScenarioSummary(BaseModel): max_dataset_size: Optional[int] = Field(None, description="Maximum items per dataset (None means unlimited)") -class ScenarioListResponse(BaseModel): +class ListRegisteredScenarioResponse(BaseModel): """Response for listing scenarios.""" - items: list[ScenarioSummary] = Field(..., description="List of scenario summaries") + items: list[RegisteredScenario] = Field(..., description="List of scenario summaries") pagination: PaginationInfo = Field(..., description="Pagination metadata") + + +# ============================================================================ +# Scenario Run Models +# ============================================================================ + + +class ScenarioRunStatus(str, Enum): + """Status of a scenario run.""" + + PENDING = "pending" + INITIALIZING = "initializing" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class RunScenarioRequest(BaseModel): + """Request body for starting a scenario run.""" + + scenario_name: str = Field(..., description="Registry key of the scenario to run") + target_name: str = Field(..., description="Name of a registered target from the TargetRegistry") + initializers: list[str] | None = Field( + None, description="Initializer names to run before scenario (e.g., ['target', 'load_default_datasets'])" + ) + strategies: list[str] | None = Field(None, description="Strategy names to use (uses scenario default if omitted)") + dataset_names: 
list[str] | None = Field(None, description="Dataset names to use (uses scenario default if omitted)") + max_dataset_size: int | None = Field(None, ge=1, description="Maximum items per dataset") + max_concurrency: int = Field(10, ge=1, le=100, description="Maximum concurrent operations") + max_retries: int = Field(0, ge=0, le=20, description="Maximum retry attempts on failure") + labels: dict[str, str] | None = Field(None, description="Labels to attach to memory entries") + scenario_params: dict[str, Any] | None = Field( + None, + description="Custom parameters for the scenario (passed to scenario.set_params_from_args). " + "Keys are parameter names declared by the scenario's supported_parameters().", + ) + initializer_args: dict[str, dict[str, Any]] | None = Field( + None, + description="Per-initializer arguments keyed by initializer name. " + "Each value is a dict of args passed to that initializer's set_params_from_args(). " + "Example: {'target': {'endpoint': 'https://...'}}.", + ) + scenario_result_id: str | None = Field( + None, + description="Optional ID of an existing ScenarioResult to resume. 
" + "If provided, the scenario will resume from prior progress instead of starting fresh.", + ) + + +class ScenarioRunSummary(BaseModel): + """Response for a scenario run (status + result details).""" + + scenario_result_id: str = Field(..., description="UUID of the ScenarioResult in memory") + scenario_name: str = Field(..., description="Registry key of the scenario being run") + scenario_version: int = Field(0, ge=0, description="Version of the scenario") + status: ScenarioRunStatus = Field(..., description="Current run status") + created_at: datetime = Field(..., description="When the run was created") + updated_at: datetime = Field(..., description="When the run status last changed") + error: str | None = Field(None, description="Error message if status is FAILED") + strategies_used: list[str] = Field(default_factory=list, description="Strategy names that were executed") + total_attacks: int = Field(0, ge=0, description="Total number of atomic attacks") + completed_attacks: int = Field(0, ge=0, description="Number of attacks that completed") + objective_achieved_rate: int = Field(0, ge=0, le=100, description="Success rate as percentage (0-100)") + labels: dict[str, str] = Field(default_factory=dict, description="Labels attached to this run") + completed_at: datetime | None = Field(None, description="When the scenario finished") + + +class ScenarioRunListResponse(BaseModel): + """Response for listing scenario runs.""" + + items: list[ScenarioRunSummary] = Field(..., description="List of scenario runs") + + +# ============================================================================ +# Scenario Results Detail Models +# ============================================================================ + + +class AtomicAttackResults(BaseModel): + """Results grouped by atomic attack name.""" + + atomic_attack_name: str = Field(..., description="Name of the atomic attack (strategy)") + display_group: str | None = Field(None, description="Display group label for UI 
grouping") + results: list[AttackSummary] = Field(..., description="Individual attack results") + success_count: int = Field(0, ge=0, description="Number of successful attacks") + failure_count: int = Field(0, ge=0, description="Number of failed attacks") + total_count: int = Field(0, ge=0, description="Total number of attack results") + + +class ScenarioRunDetail(BaseModel): + """Full detailed results of a scenario run.""" + + run: ScenarioRunSummary = Field(..., description="The scenario run summary") + attacks: list[AtomicAttackResults] = Field(..., description="Results grouped by atomic attack") diff --git a/pyrit/backend/routes/scenarios.py b/pyrit/backend/routes/scenarios.py index 9cd3e2ef43..a9e6e885fe 100644 --- a/pyrit/backend/routes/scenarios.py +++ b/pyrit/backend/routes/scenarios.py @@ -4,7 +4,12 @@ """ Scenario API routes. -Provides endpoints for listing available scenarios and their metadata. +Provides endpoints for listing available scenarios, their metadata, +and managing scenario runs. 
+ +Route structure: + /api/scenarios/catalog — scenario catalog (list + detail) + /api/scenarios/runs — scenario execution lifecycle """ from typing import Optional @@ -12,25 +17,38 @@ from fastapi import APIRouter, HTTPException, Query, status from pyrit.backend.models.common import ProblemDetail -from pyrit.backend.models.scenarios import ScenarioListResponse, ScenarioSummary +from pyrit.backend.models.scenarios import ( + ListRegisteredScenarioResponse, + RegisteredScenario, + RunScenarioRequest, + ScenarioRunDetail, + ScenarioRunListResponse, + ScenarioRunSummary, +) +from pyrit.backend.services.scenario_run_service import get_scenario_run_service from pyrit.backend.services.scenario_service import get_scenario_service router = APIRouter(prefix="/scenarios", tags=["scenarios"]) +# ============================================================================ +# Scenario Catalog +# ============================================================================ + + @router.get( - "", - response_model=ScenarioListResponse, + "/catalog", + response_model=ListRegisteredScenarioResponse, ) async def list_scenarios( limit: int = Query(50, ge=1, le=200, description="Maximum items per page"), cursor: Optional[str] = Query(None, description="Pagination cursor (scenario_name to start after)"), -) -> ScenarioListResponse: +) -> ListRegisteredScenarioResponse: """ List all available scenarios. Returns scenario metadata including strategies, datasets, and defaults. - Use GET /api/scenarios/{scenario_name} for full details on a specific scenario. + Use GET /api/scenarios/catalog/{scenario_name} for full details on a specific scenario. Returns: ScenarioListResponse: Paginated list of scenario summaries. 
@@ -40,13 +58,13 @@ async def list_scenarios( @router.get( - "/{scenario_name:path}", - response_model=ScenarioSummary, + "/catalog/{scenario_name:path}", + response_model=RegisteredScenario, responses={ 404: {"model": ProblemDetail, "description": "Scenario not found"}, }, ) -async def get_scenario(scenario_name: str) -> ScenarioSummary: +async def get_scenario(scenario_name: str) -> RegisteredScenario: """ Get details for a specific scenario. @@ -66,3 +84,147 @@ async def get_scenario(scenario_name: str) -> ScenarioSummary: ) return scenario + + +# ============================================================================ +# Scenario Runs +# ============================================================================ + + +@router.post( + "/runs", + response_model=ScenarioRunSummary, + status_code=status.HTTP_202_ACCEPTED, + responses={ + 400: {"model": ProblemDetail, "description": "Invalid request (bad scenario/target/strategy)"}, + }, +) +async def start_scenario_run(request: RunScenarioRequest) -> ScenarioRunSummary: + """ + Start a new scenario run as a background task. + + Returns immediately with a scenario_result_id that can be polled for status. + + Args: + request: Scenario run configuration. + + Returns: + ScenarioRunSummary: Metadata and current status of the newly started run. + """ + service = get_scenario_run_service() + try: + return await service.start_run_async(request=request) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from None + + +@router.get( + "/runs", + response_model=ScenarioRunListResponse, +) +async def list_scenario_runs(limit: int = Query(100, ge=1)) -> ScenarioRunListResponse: + """ + List tracked scenario runs (most recent first). + + Args: + limit (int): Maximum number of runs to return. Defaults to 100. + + Returns: + ScenarioRunListResponse: Runs, most recent first. 
+ """ + service = get_scenario_run_service() + return service.list_runs(limit=limit) + + +@router.get( + "/runs/{scenario_result_id}", + response_model=ScenarioRunSummary, + responses={ + 404: {"model": ProblemDetail, "description": "Run not found"}, + }, +) +async def get_scenario_run(scenario_result_id: str) -> ScenarioRunSummary: + """ + Get the current status and result of a scenario run. + + Args: + scenario_result_id: The scenario_result_id returned by POST /runs. + + Returns: + ScenarioRunSummary: Current run status (and result if completed). + """ + service = get_scenario_run_service() + run = service.get_run(scenario_result_id=scenario_result_id) + if run is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Scenario run '{scenario_result_id}' not found", + ) + return run + + +@router.post( + "/runs/{scenario_result_id}/cancel", + response_model=ScenarioRunSummary, + responses={ + 404: {"model": ProblemDetail, "description": "Run not found"}, + 409: {"model": ProblemDetail, "description": "Run already in terminal state"}, + }, +) +async def cancel_scenario_run(scenario_result_id: str) -> ScenarioRunSummary: + """ + Cancel a running scenario. + + Args: + scenario_result_id: The scenario_result_id to cancel. + + Returns: + ScenarioRunSummary: Updated run with CANCELLED status. 
+ """ + service = get_scenario_run_service() + try: + result = await service.cancel_run_async(scenario_result_id=scenario_result_id) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from None + + if result is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Scenario run '{scenario_result_id}' not found", + ) + return result + + +@router.get( + "/runs/{scenario_result_id}/results", + response_model=ScenarioRunDetail, + responses={ + 404: {"model": ProblemDetail, "description": "Run not found"}, + 409: {"model": ProblemDetail, "description": "Run not yet completed"}, + }, +) +async def get_scenario_run_results(scenario_result_id: str) -> ScenarioRunDetail: + """ + Get detailed results for a completed scenario run. + + Returns per-attack outcomes including objectives, responses, scores, + and success/failure counts. + + Args: + scenario_result_id: The scenario_result_id. + + Returns: + ScenarioRunDetail: Full attack-level results. 
+ """ + service = get_scenario_run_service() + try: + result = service.get_run_results(scenario_result_id=scenario_result_id) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from None + + if result is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Scenario run '{scenario_result_id}' not found", + ) + return result diff --git a/pyrit/backend/services/__init__.py b/pyrit/backend/services/__init__.py index 29807150ae..d36f69a830 100644 --- a/pyrit/backend/services/__init__.py +++ b/pyrit/backend/services/__init__.py @@ -15,6 +15,10 @@ ConverterService, get_converter_service, ) +from pyrit.backend.services.scenario_run_service import ( + ScenarioRunService, + get_scenario_run_service, +) from pyrit.backend.services.scenario_service import ( ScenarioService, get_scenario_service, @@ -31,6 +35,8 @@ "get_converter_service", "ScenarioService", "get_scenario_service", + "ScenarioRunService", + "get_scenario_run_service", "TargetService", "get_target_service", ] diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py new file mode 100644 index 0000000000..d00f9b0caa --- /dev/null +++ b/pyrit/backend/services/scenario_run_service.py @@ -0,0 +1,496 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scenario run service for executing scenarios as background tasks. + +Manages the lifecycle of scenario runs: starting, tracking status, +retrieving results, and cancellation. 
+""" + +import asyncio +import contextlib +import logging +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any + +from pyrit.backend.models.scenarios import ( + AtomicAttackResults, + AttackSummary, + RunScenarioRequest, + ScenarioRunDetail, + ScenarioRunListResponse, + ScenarioRunStatus, + ScenarioRunSummary, +) +from pyrit.memory import CentralMemory +from pyrit.models import AttackOutcome, ScenarioResult +from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry +from pyrit.scenario import Scenario +from pyrit.scenario.core import DatasetConfiguration + +logger = logging.getLogger(__name__) + +_DEFAULT_MAX_CONCURRENT_RUNS = 3 + +# Maps DB ScenarioRunState values to API ScenarioRunStatus +_STATE_TO_STATUS = { + "CREATED": ScenarioRunStatus.INITIALIZING, + "IN_PROGRESS": ScenarioRunStatus.RUNNING, + "COMPLETED": ScenarioRunStatus.COMPLETED, + "FAILED": ScenarioRunStatus.FAILED, + "CANCELLED": ScenarioRunStatus.CANCELLED, +} + + +@dataclass +class _ActiveTask: + """Tracks an in-flight scenario run's asyncio task.""" + + scenario_result_id: str + task: asyncio.Task[None] | None = None + scenario: Scenario | None = None + error: str | None = None + + +class ScenarioRunService: + """ + Service for managing scenario run lifecycle. + + Uses CentralMemory (database) as the source of truth for run state. + Keeps an in-memory dict only for active asyncio tasks (cancellation support). + """ + + def __init__(self, *, max_concurrent_runs: int = _DEFAULT_MAX_CONCURRENT_RUNS) -> None: + """Initialize the scenario run service.""" + self._max_concurrent_runs = max_concurrent_runs + self._active_tasks: dict[str, _ActiveTask] = {} + self._run_semaphore = asyncio.Semaphore(max_concurrent_runs) + + async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunSummary: + """ + Start a new scenario run as a background task. 
+ + Performs all validation and initialization eagerly (initializers, target + resolution, strategy validation, scenario.initialize_async) so errors are + returned immediately. On success, spawns a background task that only + executes scenario.run_async. + + Args: + request: The run request with scenario name, target, and options. + + Returns: + ScenarioRunSummary for the newly started run. + + Raises: + ValueError: If scenario, target, initializer, or strategy cannot be found, + or concurrent limit exceeded. + """ + if self._run_semaphore.locked(): + raise ValueError( + f"Maximum concurrent runs ({self._max_concurrent_runs}) reached. " + "Wait for an existing run to complete or cancel one." + ) + + await self._run_semaphore.acquire() + + # Perform all initialization eagerly — errors propagate to caller + try: + scenario = await self._initialize_run_async(request=request) + # scenario_result_id is set during initialize_async + scenario_result_id = scenario._scenario_result_id + if scenario_result_id is None: + raise ValueError("Scenario did not produce a scenario_result_id during initialization.") + except BaseException: + # Free the slot on every failure path (incl. request cancellation) so it is never leaked + self._run_semaphore.release() + raise + + # Track active task + active = _ActiveTask(scenario_result_id=scenario_result_id, scenario=scenario) + self._active_tasks[scenario_result_id] = active + + # Spawn background task (only runs scenario.run_async) + task = asyncio.create_task(self._execute_run_async(scenario_result_id=scenario_result_id)) + active.task = task + + response = self._build_response(scenario_result_id=scenario_result_id) + assert response is not None # guaranteed: we just inserted into DB via initialize_async + return response + + def get_run(self, *, scenario_result_id: str) -> ScenarioRunSummary | None: + """ + Get the current status of a scenario run by querying the database. + + Args: + scenario_result_id: The scenario result ID. + + Returns: + ScenarioRunSummary if found, None otherwise. 
+ """ + return self._build_response(scenario_result_id=scenario_result_id) + + def list_runs(self, *, limit: int = 100) -> ScenarioRunListResponse: + """ + List scenario runs by querying the database (most recent first). + + Args: + limit (int): Maximum number of runs to return. Defaults to 100. + + Returns: + ScenarioRunListResponse with runs. + """ + memory = CentralMemory.get_memory_instance() + + # This is expensive, and we don't need all the data. At some point we may + # want a lightweight "list" query in the DB layer that only loads summary fields. + results = memory.get_scenario_results(limit=limit) + items = [self._build_response_from_db(scenario_result=sr) for sr in results] + return ScenarioRunListResponse(items=items) + + async def cancel_run_async(self, *, scenario_result_id: str) -> ScenarioRunSummary | None: + """ + Cancel a running scenario. + + Args: + scenario_result_id: The scenario result ID. + + Returns: + Updated ScenarioRunSummary if found, None if not found. + + Raises: + ValueError: If the run is already in a terminal state or not active. 
+ """ + # Verify run exists in DB + memory = CentralMemory.get_memory_instance() + results = memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + if not results: + return None + + scenario_result = results[0] + db_status = _STATE_TO_STATUS.get(scenario_result.scenario_run_state, ScenarioRunStatus.FAILED) + + if db_status in (ScenarioRunStatus.COMPLETED, ScenarioRunStatus.FAILED, ScenarioRunStatus.CANCELLED): + raise ValueError(f"Cannot cancel run in '{db_status}' state.") + + # Cancel the asyncio task if active and wait for it to finish + active = self._active_tasks.get(scenario_result_id) + if active is not None and active.task is not None and not active.task.done(): + active.task.cancel() + with contextlib.suppress(asyncio.CancelledError, asyncio.TimeoutError): + await asyncio.wait_for(active.task, timeout=5.0) + + # Persist cancelled state to DB + memory.update_scenario_run_state(scenario_result_id=scenario_result_id, scenario_run_state="CANCELLED") + + return self._build_response(scenario_result_id=scenario_result_id) + + async def _initialize_run_async(self, *, request: RunScenarioRequest) -> Scenario: + """ + Validate inputs and initialize the scenario eagerly. + + Performs all validation (scenario, initializers, target, strategies) and + calls scenario.initialize_async so that any errors are raised immediately + to the caller. Running initialization on creation simplifies error handling and ensures + that the scenario is fully ready to run when we spawn the background task. + + Args: + request: The run request with scenario name, target, and options. + + Returns: + The fully initialized Scenario instance ready for run_async. + + Raises: + ValueError: If any validation fails (bad scenario name, missing target, + invalid strategy, unknown initializer, etc.). 
+ """ + # Validate scenario exists + scenario_registry = ScenarioRegistry.get_registry_singleton() + try: + scenario_class = scenario_registry.get_class(request.scenario_name) + except KeyError as e: + raise ValueError(str(e)) from None + + # Validate and run initializers + if request.initializers: + initializer_registry = InitializerRegistry.get_registry_singleton() + for initializer_name in request.initializers: + try: + initializer_class = initializer_registry.get_class(initializer_name) + except KeyError as e: + raise ValueError(f"Initializer not found: {e}") from None + instance = initializer_class() + if request.initializer_args and initializer_name in request.initializer_args: + instance.set_params_from_args(args=request.initializer_args[initializer_name]) + await instance.initialize_async() + + # Resolve target + target_registry = TargetRegistry.get_registry_singleton() + objective_target = target_registry.get_instance_by_name(request.target_name) + if objective_target is None: + available_names = target_registry.get_names() + if not available_names: + raise ValueError( + f"Target '{request.target_name}' not found. The target registry is empty. " + "Make sure to include an initializer that registers targets " + "(e.g., initializers: ['target'])." + ) + raise ValueError( + f"Target '{request.target_name}' not found in registry. 
Available targets: {', '.join(available_names)}" + ) + + # Build init kwargs + init_kwargs: dict[str, Any] = { + "objective_target": objective_target, + "max_concurrency": request.max_concurrency, + "max_retries": request.max_retries, + } + + if request.labels: + init_kwargs["memory_labels"] = request.labels + + # Validate and resolve strategies + if request.strategies: + strategy_class = scenario_class.get_strategy_class() + strategy_enums = [] + for name in request.strategies: + try: + strategy_enums.append(strategy_class(name)) + except ValueError: + available_strategies = [s.value for s in strategy_class] + raise ValueError( + f"Strategy '{name}' not found for scenario '{request.scenario_name}'. " + f"Available: {', '.join(available_strategies)}" + ) from None + init_kwargs["scenario_strategies"] = strategy_enums + + # Build dataset config + if request.dataset_names: + init_kwargs["dataset_config"] = DatasetConfiguration( + dataset_names=request.dataset_names, + max_dataset_size=request.max_dataset_size, + ) + elif request.max_dataset_size is not None: + default_config = scenario_class.default_dataset_config() + default_config.max_dataset_size = request.max_dataset_size + init_kwargs["dataset_config"] = default_config + + # Instantiate and initialize scenario + constructor_kwargs: dict[str, Any] = {} + if request.scenario_result_id: + constructor_kwargs["scenario_result_id"] = request.scenario_result_id + scenario = scenario_class(**constructor_kwargs) # type: ignore[call-arg] + scenario.set_params_from_args(args=request.scenario_params or {}) + await scenario.initialize_async(**init_kwargs) + return scenario + + async def _execute_run_async(self, *, scenario_result_id: str) -> None: + """ + Execute a scenario run (background task entry point). + + Only calls scenario.run_async on the already-initialized scenario. + + Note: this method intentionally does NOT remove the entry from + ``_active_tasks`` on completion. 
The entry must stay so that + ``_build_response_from_db`` can read ``active.error`` when the + caller next polls the run status. Cleanup happens lazily there + once the error has been surfaced. + + Args: + scenario_result_id: The scenario result ID for this run. + """ + active = self._active_tasks[scenario_result_id] + assert active.scenario is not None + + try: + await active.scenario.run_async() + + except asyncio.CancelledError: + logger.info(f"Scenario run {scenario_result_id} was cancelled.") + + except Exception as e: + active.error = str(e) + logger.exception(f"Scenario run {scenario_result_id} failed: {e}") + + finally: + self._run_semaphore.release() + + def _build_response(self, *, scenario_result_id: str) -> ScenarioRunSummary | None: + """ + Build a ScenarioRunResponse by querying the database and merging active task state. + + Args: + scenario_result_id: The scenario result ID. + + Returns: + ScenarioRunResponse if found in the database, None otherwise. + """ + memory = CentralMemory.get_memory_instance() + results = memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + if not results: + return None + return self._build_response_from_db(scenario_result=results[0]) + + def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> ScenarioRunSummary: + """ + Build a ScenarioRunResponse from a database ScenarioResult, merged with active task info. + + Args: + scenario_result: A ScenarioResult retrieved from CentralMemory. + + Returns: + The API response model. 
+ """ + scenario_result_id = str(scenario_result.id) + active = self._active_tasks.get(scenario_result_id) + + # Clean up finished active tasks after reading the error + error = None + if active is not None: + error = active.error + if active.task is not None and active.task.done(): + del self._active_tasks[scenario_result_id] + + status = _STATE_TO_STATUS.get(scenario_result.scenario_run_state, ScenarioRunStatus.FAILED) + + # Build result fields for completed runs + strategies_used: list[str] = [] + total_attacks = 0 + completed_attacks = 0 + if status == ScenarioRunStatus.COMPLETED: + completed_attacks = sum( + 1 + for results in scenario_result.attack_results.values() + for ar in results + if ar.outcome in (AttackOutcome.SUCCESS, AttackOutcome.FAILURE) + ) + total_attacks = sum(len(results) for results in scenario_result.attack_results.values()) + strategies_used = scenario_result.get_strategies_used() + + return ScenarioRunSummary( + scenario_result_id=scenario_result_id, + scenario_name=scenario_result.scenario_identifier.name, + scenario_version=scenario_result.scenario_identifier.version, + status=status, + created_at=scenario_result.created_at, + updated_at=scenario_result.completion_time, + error=error, + strategies_used=strategies_used, + total_attacks=total_attacks, + completed_attacks=completed_attacks, + objective_achieved_rate=scenario_result.objective_achieved_rate(), + labels=scenario_result.labels, + completed_at=scenario_result.completion_time, + ) + + def get_run_results(self, *, scenario_result_id: str) -> ScenarioRunDetail | None: + """ + Get detailed results for a completed scenario run. + + Retrieves the full ScenarioResult from CentralMemory and maps it + to a detailed response model with per-attack outcomes. + + Args: + scenario_result_id: The scenario result ID. + + Returns: + ScenarioRunDetail if the run is completed and results exist, None if not found. + + Raises: + ValueError: If the run is not in a completed state. 
+ """ + memory = CentralMemory.get_memory_instance() + results = memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + if not results: + return None + + scenario_result = results[0] + run_response = self._build_response_from_db(scenario_result=scenario_result) + + if run_response.status != ScenarioRunStatus.COMPLETED: + raise ValueError(f"Results are only available for completed runs. Current status: '{run_response.status}'.") + + # Build per-attack detail + attacks: list[AtomicAttackResults] = [] + display_group_map = scenario_result.display_group_map + for attack_name, attack_results in scenario_result.attack_results.items(): + details: list[AttackSummary] = [] + success_count = 0 + failure_count = 0 + + for ar in attack_results: + score_value = None + if ar.last_score is not None: + score_value = str(ar.last_score.get_value()) + + last_response_text = None + if ar.last_response is not None: + last_response_text = str(ar.last_response) + + timestamp = ar.timestamp or datetime.now(timezone.utc) + details.append( + AttackSummary( + attack_result_id=ar.attack_result_id, + conversation_id=ar.conversation_id, + objective=ar.objective, + outcome=ar.outcome.value, + outcome_reason=ar.outcome_reason, + last_response=last_response_text, + score_value=score_value, + executed_turns=ar.executed_turns, + execution_time_ms=ar.execution_time_ms, + created_at=timestamp, + updated_at=timestamp, + ) + ) + + if ar.outcome == AttackOutcome.SUCCESS: + success_count += 1 + elif ar.outcome == AttackOutcome.FAILURE: + failure_count += 1 + + attacks.append( + AtomicAttackResults( + atomic_attack_name=attack_name, + display_group=display_group_map.get(attack_name), + results=details, + success_count=success_count, + failure_count=failure_count, + total_count=len(details), + ) + ) + + return ScenarioRunDetail( + run=run_response, + attacks=attacks, + ) + + +_service_instance: ScenarioRunService | None = None + + +def get_scenario_run_service() -> ScenarioRunService: + """ 
+ Get the global scenario run service instance. + + On first call, reads ``max_concurrent_scenario_runs`` from ``app.state`` + (set by ``pyrit_backend`` CLI) if available, otherwise uses the default. + + Returns: + The singleton ScenarioRunService instance. + """ + global _service_instance + if _service_instance is not None: + return _service_instance + + max_runs = _DEFAULT_MAX_CONCURRENT_RUNS + try: + from pyrit.backend.main import app + + max_runs = getattr(app.state, "max_concurrent_scenario_runs", _DEFAULT_MAX_CONCURRENT_RUNS) + except Exception: + pass + + _service_instance = ScenarioRunService(max_concurrent_runs=max_runs) + return _service_instance diff --git a/pyrit/backend/services/scenario_service.py b/pyrit/backend/services/scenario_service.py index 52df32fe61..77b4739020 100644 --- a/pyrit/backend/services/scenario_service.py +++ b/pyrit/backend/services/scenario_service.py @@ -12,11 +12,11 @@ from typing import Optional from pyrit.backend.models.common import PaginationInfo -from pyrit.backend.models.scenarios import ScenarioListResponse, ScenarioSummary +from pyrit.backend.models.scenarios import ListRegisteredScenarioResponse, RegisteredScenario from pyrit.registry import ScenarioMetadata, ScenarioRegistry -def _metadata_to_summary(metadata: ScenarioMetadata) -> ScenarioSummary: +def _metadata_to_registered_scenario(metadata: ScenarioMetadata) -> RegisteredScenario: """ Convert a ScenarioMetadata dataclass to a ScenarioSummary Pydantic model. @@ -26,7 +26,7 @@ def _metadata_to_summary(metadata: ScenarioMetadata) -> ScenarioSummary: Returns: ScenarioSummary Pydantic model. 
""" - return ScenarioSummary( + return RegisteredScenario( scenario_name=metadata.registry_name, scenario_type=metadata.class_name, description=metadata.class_description, @@ -54,7 +54,7 @@ async def list_scenarios_async( *, limit: int = 50, cursor: Optional[str] = None, - ) -> ScenarioListResponse: + ) -> ListRegisteredScenarioResponse: """ List all available scenarios with pagination. @@ -66,17 +66,17 @@ async def list_scenarios_async( ScenarioListResponse with paginated scenario summaries. """ all_metadata = self._registry.list_metadata() - all_summaries = [_metadata_to_summary(m) for m in all_metadata] + all_summaries = [_metadata_to_registered_scenario(m) for m in all_metadata] page, has_more = self._paginate(items=all_summaries, cursor=cursor, limit=limit) next_cursor = page[-1].scenario_name if has_more and page else None - return ScenarioListResponse( + return ListRegisteredScenarioResponse( items=page, pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=cursor), ) - async def get_scenario_async(self, *, scenario_name: str) -> Optional[ScenarioSummary]: + async def get_scenario_async(self, *, scenario_name: str) -> Optional[RegisteredScenario]: """ Get a single scenario by registry name. @@ -89,16 +89,16 @@ async def get_scenario_async(self, *, scenario_name: str) -> Optional[ScenarioSu all_metadata = self._registry.list_metadata() for metadata in all_metadata: if metadata.registry_name == scenario_name: - return _metadata_to_summary(metadata) + return _metadata_to_registered_scenario(metadata) return None @staticmethod def _paginate( *, - items: list[ScenarioSummary], + items: list[RegisteredScenario], cursor: Optional[str], limit: int, - ) -> tuple[list[ScenarioSummary], bool]: + ) -> tuple[list[RegisteredScenario], bool]: """ Apply cursor-based pagination. 
diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index b75634891a..c17eb83b54 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -146,6 +146,7 @@ def __init__( self._env_files = config._resolve_env_files() self._operator = config.operator self._operation = config.operation + self._max_concurrent_scenario_runs = config.max_concurrent_scenario_runs # Lazy-loaded registries self._scenario_registry: Optional[ScenarioRegistry] = None @@ -221,6 +222,7 @@ def with_overrides( derived._env_files = self._env_files derived._operator = self._operator derived._operation = self._operation + derived._max_concurrent_scenario_runs = self._max_concurrent_scenario_runs derived._scenario_config = self._scenario_config # Apply overrides or inherit diff --git a/pyrit/cli/pyrit_backend.py b/pyrit/cli/pyrit_backend.py index f45cc0c448..8eed2cc929 100644 --- a/pyrit/cli/pyrit_backend.py +++ b/pyrit/cli/pyrit_backend.py @@ -198,6 +198,7 @@ async def initialize_and_run_async(*, parsed_args: Namespace) -> int: if context._operation: default_labels["operation"] = context._operation app.state.default_labels = default_labels + app.state.max_concurrent_scenario_runs = context._max_concurrent_scenario_runs display_host = parsed_args.host print(f"🚀 Starting PyRIT backend on http://{display_host}:{parsed_args.port}") diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 83cad99de6..fb05426455 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -1354,7 +1354,7 @@ class TreeOfAttacksWithPruningAttack(AttackStrategy[TAPAttackContext, TAPAttackR def __init__( self, *, - objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-assignment, ty:invalid-parameter-default] + objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] attack_adversarial_config: 
AttackAdversarialConfig, attack_converter_config: AttackConverterConfig | None = None, attack_scoring_config: TAPAttackScoringConfig | None = None, diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index e04d56f3dd..753236968d 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -786,6 +786,8 @@ def _query_entries( conditions: Optional[Any] = None, distinct: bool = False, join_scores: bool = False, + order_by: Optional[Any] = None, + limit: int | None = None, ) -> MutableSequence[Model]: """ Fetch data from the specified table model with optional conditions. @@ -795,6 +797,8 @@ def _query_entries( conditions: SQLAlchemy filter conditions (Optional). distinct: Flag to return distinct rows (defaults to False). join_scores: Flag to join the scores table with entries (defaults to False). + order_by: SQLAlchemy order_by clause (Optional). + limit (int | None): Maximum number of rows to return. Defaults to None (no limit). Returns: List of model instances representing the rows fetched from the table. 
@@ -814,8 +818,12 @@ def _query_entries( ) if conditions is not None: query = query.filter(conditions) + if order_by is not None: + query = query.order_by(order_by) if distinct: - return query.distinct().all() + query = query.distinct() + if limit is not None: + query = query.limit(limit) return query.all() except SQLAlchemyError as e: logger.exception(f"Error fetching data from table {model_class.__tablename__}: {e}") # type: ignore[ty:unresolved-attribute] diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 0c3310c0ee..b28d05976e 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -354,6 +354,8 @@ def _query_entries( conditions: Optional[Any] = None, distinct: bool = False, join_scores: bool = False, + order_by: Optional[Any] = None, + limit: int | None = None, ) -> MutableSequence[Model]: """ Fetch data from the specified table model with optional conditions. @@ -363,6 +365,8 @@ def _query_entries( conditions: SQLAlchemy filter conditions (Optional). distinct: Whether to return distinct rows only. Defaults to False. join_scores: Whether to join the scores table. Defaults to False. + order_by: SQLAlchemy order_by clause (Optional). + limit (int | None): Maximum number of rows to return. Defaults to None (no limit). Returns: List of model instances representing the rows fetched from the table. @@ -378,6 +382,8 @@ def _execute_batched_query( distinct: bool = False, join_scores: bool = False, batch_size: int | None = None, + order_by: Optional[Any] = None, + limit: int | None = None, ) -> MutableSequence[Model]: """ Execute queries in batches to avoid exceeding database bind variable limits. @@ -394,6 +400,8 @@ def _execute_batched_query( join_scores: Whether to join the scores table. batch_size: Override for the number of values per batch. Defaults to ``_MAX_BIND_VARS`` when not specified. + order_by: SQLAlchemy order_by clause (Optional). + limit (int | None): Maximum number of rows to return. 
Defaults to None (no limit). Returns: MutableSequence[Model]: Merged and deduplicated results from all batched queries. @@ -411,6 +419,8 @@ def _execute_batched_query( conditions=and_(*conditions) if conditions else None, distinct=distinct, join_scores=join_scores, + order_by=order_by, + limit=limit, ) # Execute multiple separate queries and merge results @@ -426,6 +436,7 @@ def _execute_batched_query( conditions=and_(*conditions) if conditions else None, distinct=distinct, join_scores=join_scores, + order_by=order_by, ) # Deduplicate by primary key (id) @@ -2062,10 +2073,13 @@ def get_scenario_results( objective_target_endpoint: Optional[str] = None, objective_target_model_name: Optional[str] = None, identifier_filters: Optional[Sequence[IdentifierFilter]] = None, + limit: int | None = None, ) -> Sequence[ScenarioResult]: """ Retrieve a list of ScenarioResult objects based on the specified filters. + Results are always ordered by completion_time descending (most recent first). + Args: scenario_result_ids (Optional[Sequence[str]], optional): A list of scenario result IDs. Defaults to None. @@ -2088,9 +2102,11 @@ def get_scenario_results( identifier_filters (Optional[Sequence[IdentifierFilter]], optional): A sequence of IdentifierFilter objects that allows filtering by identifier JSON properties. Defaults to None. + limit (int | None): Maximum number of results to return. Defaults to None (no limit). Returns: - Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters. + Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters, + ordered by completion_time descending. 
""" if scenario_result_ids is not None and len(scenario_result_ids) == 0: return [] @@ -2149,6 +2165,8 @@ def get_scenario_results( ) try: + order_by_clause = ScenarioResultEntry.completion_time.desc() + # Handle scenario_result_ids with batched queries if needed if scenario_result_ids: entries = self._execute_batched_query( @@ -2156,9 +2174,16 @@ def get_scenario_results( batch_column=ScenarioResultEntry.id, batch_values=list(scenario_result_ids), other_conditions=conditions, + order_by=order_by_clause, + limit=limit, ) else: - entries = self._query_entries(ScenarioResultEntry, conditions=and_(*conditions) if conditions else None) + entries = self._query_entries( + ScenarioResultEntry, + conditions=and_(*conditions) if conditions else None, + order_by=order_by_clause, + limit=limit, + ) # Convert entries to ScenarioResults and populate attack_results efficiently scenario_results = [] diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 6b89313ba3..e9b6b6659e 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -949,7 +949,7 @@ class ScenarioResultEntry(Base): scenario_init_data: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON, nullable=True) objective_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON, nullable=False) objective_scorer_identifier: Mapped[Optional[dict[str, str]]] = mapped_column(JSON, nullable=True) - scenario_run_state: Mapped[Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED"]] = mapped_column( + scenario_run_state: Mapped[Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED"]] = mapped_column( String, nullable=False, default="CREATED" ) attack_results_json: Mapped[str] = mapped_column(Unicode, nullable=False) @@ -1053,6 +1053,7 @@ def get_scenario_result(self) -> ScenarioResult: objective_scorer_identifier=scorer_identifier, # type: ignore[ty:invalid-argument-type] scenario_run_state=self.scenario_run_state, labels=self.labels, + 
created_at=self.timestamp, number_tries=self.number_tries, completion_time=self.completion_time, display_group_map=display_group_map, diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 53f6ce9134..0874428878 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -326,6 +326,8 @@ def _query_entries( conditions: Optional[Any] = None, distinct: bool = False, join_scores: bool = False, + order_by: Optional[Any] = None, + limit: int | None = None, ) -> MutableSequence[Model]: """ Fetch data from the specified table model with optional conditions. @@ -335,6 +337,8 @@ def _query_entries( conditions: SQLAlchemy filter conditions (Optional). distinct: Flag to return distinct rows (default is False). join_scores: Flag to join the scores table (default is False). + order_by: SQLAlchemy order_by clause (Optional). + limit (int | None): Maximum number of rows to return. Defaults to None (no limit). Returns: List of model instances representing the rows fetched from the table. 
@@ -354,8 +358,12 @@ def _query_entries( ) if conditions is not None: query = query.filter(conditions) + if order_by is not None: + query = query.order_by(order_by) if distinct: - return query.distinct().all() + query = query.distinct() + if limit is not None: + query = query.limit(limit) return query.all() except SQLAlchemyError as e: logger.exception(f"Error fetching data from table {model_class.__tablename__}: {e}") # type: ignore[ty:unresolved-attribute] diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index 3b159846d2..7f237e8c9f 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -47,7 +47,7 @@ def __init__( self.init_data = init_data -ScenarioRunState = Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED"] +ScenarioRunState = Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED"] class ScenarioResult: @@ -64,6 +64,7 @@ def __init__( objective_scorer_identifier: "ComponentIdentifier", scenario_run_state: ScenarioRunState = "CREATED", labels: Optional[dict[str, str]] = None, + created_at: Optional[datetime] = None, completion_time: Optional[datetime] = None, number_tries: int = 0, id: Optional[uuid.UUID] = None, # noqa: A002 @@ -79,6 +80,7 @@ def __init__( objective_scorer_identifier (ComponentIdentifier): Objective scorer identifier. scenario_run_state (ScenarioRunState): Current scenario run state. labels (Optional[dict[str, str]]): Optional labels. + created_at (Optional[datetime]): When the scenario result was created. completion_time (Optional[datetime]): Optional completion timestamp. number_tries (int): Number of run attempts. id (Optional[uuid.UUID]): Optional scenario result ID. 
@@ -97,10 +99,16 @@ def __init__( self.scenario_run_state = scenario_run_state self.attack_results = attack_results self.labels = labels if labels is not None else {} + self.created_at = created_at if created_at is not None else datetime.now(timezone.utc) self.completion_time = completion_time if completion_time is not None else datetime.now(timezone.utc) self.number_tries = number_tries self._display_group_map = display_group_map or {} + @property + def display_group_map(self) -> dict[str, str]: + """Mapping of atomic_attack_name → display group label.""" + return self._display_group_map + def get_strategies_used(self) -> list[str]: """ Get the list of strategies used in this scenario. diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 184a235f65..e2e18e2350 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -132,6 +132,7 @@ class ConfigurationLoader(YamlLoadable): operator: Optional[str] = None operation: Optional[str] = None scenario: Optional[Union[str, dict[str, Any]]] = None + max_concurrent_scenario_runs: int = 3 def __post_init__(self) -> None: """Validate and normalize the configuration after loading.""" diff --git a/tests/unit/backend/test_scenario_run_routes.py b/tests/unit/backend/test_scenario_run_routes.py new file mode 100644 index 0000000000..9eceff77e7 --- /dev/null +++ b/tests/unit/backend/test_scenario_run_routes.py @@ -0,0 +1,313 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for scenario run API routes. 
+""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import status +from fastapi.testclient import TestClient + +import pyrit.backend.services.scenario_run_service as _svc_mod +from pyrit.backend.main import app +from pyrit.backend.models.attacks import AttackSummary +from pyrit.backend.models.scenarios import ( + AtomicAttackResults, + ScenarioRunDetail, + ScenarioRunListResponse, + ScenarioRunStatus, + ScenarioRunSummary, +) + + +@pytest.fixture +def client() -> TestClient: + """Create a test client for the FastAPI app.""" + return TestClient(app) + + +@pytest.fixture(autouse=True) +def clear_service_cache(): + """Clear the service singleton between tests.""" + _svc_mod._service_instance = None + yield + _svc_mod._service_instance = None + + +def _mock_run_response( + *, + run_id: str = "test-run-id", + scenario_name: str = "foundry.red_team_agent", + run_status: ScenarioRunStatus = ScenarioRunStatus.PENDING, +) -> ScenarioRunSummary: + """Create a mock ScenarioRunResponse.""" + return ScenarioRunSummary( + scenario_result_id=run_id, + scenario_name=scenario_name, + status=run_status, + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + error=None, + ) + + +class TestStartScenarioRunRoute: + """Tests for POST /api/scenarios/runs.""" + + def test_start_run_returns_202(self, client: TestClient) -> None: + """Test that a valid request returns 202 Accepted.""" + mock_response = _mock_run_response() + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.start_run_async = AsyncMock(return_value=mock_response) + mock_get.return_value = mock_service + + response = client.post( + "/api/scenarios/runs", + json={"scenario_name": "foundry.red_team_agent", "target_name": "my_target"}, + ) + + assert response.status_code == status.HTTP_202_ACCEPTED + data = 
response.json() + assert data["scenario_result_id"] == "test-run-id" + assert data["status"] == "pending" + + def test_start_run_invalid_scenario_returns_400(self, client: TestClient) -> None: + """Test that an invalid scenario returns 400.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.start_run_async = AsyncMock(side_effect=ValueError("'bad.scenario' not found in registry.")) + mock_get.return_value = mock_service + + response = client.post( + "/api/scenarios/runs", + json={"scenario_name": "bad.scenario", "target_name": "my_target"}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "not found" in response.json()["detail"] + + def test_start_run_missing_required_fields_returns_422(self, client: TestClient) -> None: + """Test that missing required fields returns 422.""" + response = client.post("/api/scenarios/runs", json={}) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_start_run_with_all_options(self, client: TestClient) -> None: + """Test that all optional fields are accepted.""" + mock_response = _mock_run_response() + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.start_run_async = AsyncMock(return_value=mock_response) + mock_get.return_value = mock_service + + response = client.post( + "/api/scenarios/runs", + json={ + "scenario_name": "foundry.red_team_agent", + "target_name": "my_target", + "initializers": ["target", "load_default_datasets"], + "strategies": ["base64", "rot13"], + "dataset_names": ["harmful_content"], + "max_dataset_size": 50, + "max_concurrency": 5, + "max_retries": 2, + "memory_labels": {"team": "red"}, + "scenario_params": {"max_turns": 10, "threshold": 0.8}, + "initializer_args": {"target": {"endpoint": "https://example.com"}}, + }, + ) + + assert response.status_code == status.HTTP_202_ACCEPTED + + +class 
TestListScenarioRunsRoute: + """Tests for GET /api/scenarios/runs.""" + + def test_list_runs_returns_200(self, client: TestClient) -> None: + """Test that list runs returns 200 with empty list.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.list_runs.return_value = ScenarioRunListResponse(items=[]) + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs") + + assert response.status_code == status.HTTP_200_OK + assert response.json()["items"] == [] + + def test_list_runs_returns_multiple_runs(self, client: TestClient) -> None: + """Test that list runs returns all tracked runs.""" + runs = [ + _mock_run_response(run_id="run-1"), + _mock_run_response(run_id="run-2", run_status=ScenarioRunStatus.RUNNING), + ] + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.list_runs.return_value = ScenarioRunListResponse(items=runs) + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs") + + assert response.status_code == status.HTTP_200_OK + assert len(response.json()["items"]) == 2 + + +class TestGetScenarioRunRoute: + """Tests for GET /api/scenarios/runs/{id}.""" + + def test_get_run_returns_200(self, client: TestClient) -> None: + """Test that getting an existing run returns 200.""" + mock_response = _mock_run_response(run_status=ScenarioRunStatus.RUNNING) + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.get_run.return_value = mock_response + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs/test-run-id") + + assert response.status_code == status.HTTP_200_OK + assert response.json()["status"] == "running" + + def test_get_run_not_found_returns_404(self, client: TestClient) -> None: + """Test that getting a non-existent run returns 404.""" + 
with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.get_run.return_value = None + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs/nonexistent") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestCancelScenarioRunRoute: + """Tests for POST /api/scenarios/runs/{id}/cancel.""" + + def test_cancel_run_returns_200(self, client: TestClient) -> None: + """Test that cancelling a running scenario returns 200.""" + mock_response = _mock_run_response(run_status=ScenarioRunStatus.CANCELLED) + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.cancel_run_async = AsyncMock(return_value=mock_response) + mock_get.return_value = mock_service + + response = client.post("/api/scenarios/runs/test-run-id/cancel") + + assert response.status_code == status.HTTP_200_OK + assert response.json()["status"] == "cancelled" + + def test_cancel_run_not_found_returns_404(self, client: TestClient) -> None: + """Test that cancelling a non-existent run returns 404.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.cancel_run_async = AsyncMock(return_value=None) + mock_get.return_value = mock_service + + response = client.post("/api/scenarios/runs/nonexistent/cancel") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_cancel_completed_run_returns_409(self, client: TestClient) -> None: + """Test that cancelling a completed run returns 409 Conflict.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.cancel_run_async = AsyncMock(side_effect=ValueError("Cannot cancel run in 'completed' state.")) + mock_get.return_value = mock_service + + response = client.post("/api/scenarios/runs/test-run-id/cancel") + + assert 
response.status_code == status.HTTP_409_CONFLICT + assert "Cannot cancel" in response.json()["detail"] + + +class TestGetScenarioRunResultsRoute: + """Tests for GET /api/scenarios/runs/{id}/results.""" + + def test_get_results_returns_200(self, client: TestClient) -> None: + """Test that getting results of a completed run returns 200.""" + mock_result = ScenarioRunDetail( + run=ScenarioRunSummary( + scenario_result_id="result-uuid", + scenario_name="foundry.red_team_agent", + scenario_version=1, + status=ScenarioRunStatus.COMPLETED, + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + objective_achieved_rate=50, + labels={"team": "red"}, + completed_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ), + attacks=[ + AtomicAttackResults( + atomic_attack_name="base64_attack", + display_group="encoding", + results=[ + AttackSummary( + attack_result_id="ar-1", + conversation_id="conv-1", + objective="Extract sensitive info", + outcome="success", + outcome_reason="Model revealed data", + last_response="Here is the data...", + score_value="1.0", + executed_turns=3, + execution_time_ms=1500, + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ), + ], + success_count=1, + failure_count=0, + total_count=1, + ), + ], + ) + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.get_run_results.return_value = mock_result + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs/test-run-id/results") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["run"]["scenario_result_id"] == "result-uuid" + assert data["run"]["objective_achieved_rate"] == 50 + assert len(data["attacks"]) == 1 + assert data["attacks"][0]["atomic_attack_name"] == "base64_attack" + assert data["attacks"][0]["results"][0]["outcome"] == "success" + + 
def test_get_results_not_found_returns_404(self, client: TestClient) -> None: + """Test that getting results of a non-existent run returns 404.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.get_run_results.return_value = None + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs/nonexistent/results") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_get_results_not_completed_returns_409(self, client: TestClient) -> None: + """Test that getting results of a non-completed run returns 409.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.get_run_results.side_effect = ValueError( + "Results are only available for completed runs. Current status: 'running'." + ) + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs/test-run-id/results") + + assert response.status_code == status.HTTP_409_CONFLICT + assert "only available for completed runs" in response.json()["detail"] diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py new file mode 100644 index 0000000000..524e63918c --- /dev/null +++ b/tests/unit/backend/test_scenario_run_service.py @@ -0,0 +1,492 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for ScenarioRunService. 
+""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +import pyrit.backend.services.scenario_run_service as _svc_mod +from pyrit.backend.models.scenarios import ( + RunScenarioRequest, + ScenarioRunStatus, +) +from pyrit.backend.services.scenario_run_service import ( + _DEFAULT_MAX_CONCURRENT_RUNS, + ScenarioRunService, +) + +_REGISTRY_PATCH_BASE = "pyrit.registry" +_MEMORY_PATCH = "pyrit.memory.CentralMemory.get_memory_instance" + + +@pytest.fixture(autouse=True) +def clear_service_cache(): + """Clear the singleton instance between tests.""" + _svc_mod._service_instance = None + yield + _svc_mod._service_instance = None + + +def _make_request( + *, + scenario_name: str = "foundry.red_team_agent", + target_name: str = "my_target", + initializers: list[str] | None = None, + strategies: list[str] | None = None, + scenario_result_id: str | None = None, +) -> RunScenarioRequest: + """Create a RunScenarioRequest for testing.""" + return RunScenarioRequest( + scenario_name=scenario_name, + target_name=target_name, + initializers=initializers, + strategies=strategies, + scenario_result_id=scenario_result_id, + ) + + +def _make_db_scenario_result( + *, + result_id: str = "sr-uuid-1", + scenario_name: str = "foundry.red_team_agent", + run_state: str = "IN_PROGRESS", + attack_results: dict | None = None, +) -> MagicMock: + """Create a mock ScenarioResult as returned by CentralMemory.""" + sr = MagicMock() + sr.id = result_id + sr.scenario_identifier.name = scenario_name + sr.scenario_identifier.version = 1 + sr.scenario_run_state = run_state + sr.get_strategies_used.return_value = [] + sr.attack_results = attack_results or {} + sr.number_tries = 1 + sr.created_at = datetime(2025, 1, 1, tzinfo=timezone.utc) + sr.completion_time = datetime(2025, 1, 1, 0, 5, tzinfo=timezone.utc) + sr.labels = {} + sr.objective_achieved_rate.return_value = 0 + sr.get_display_groups.return_value = {} + sr.display_group_map = {} 
+ return sr + + +@pytest.fixture +def mock_memory(): + """Patch CentralMemory.get_memory_instance to return a mock.""" + mock = MagicMock() + mock.get_scenario_results.return_value = [] + with patch(_MEMORY_PATCH, return_value=mock): + yield mock + + +@pytest.fixture +def mock_all_registries(mock_memory): + """Patch all registries and CentralMemory with valid defaults.""" + mock_scenario_instance = MagicMock() + mock_scenario_instance.initialize_async = AsyncMock() + mock_scenario_instance.run_async = AsyncMock() + mock_scenario_instance._scenario_result_id = "sr-uuid-1" + + mock_scenario_class = MagicMock(return_value=mock_scenario_instance) + mock_scenario_class.get_strategy_class.return_value = MagicMock() + mock_scenario_class.default_dataset_config.return_value = MagicMock() + + mock_sr = MagicMock() + mock_sr.get_class.return_value = mock_scenario_class + + mock_tr = MagicMock() + mock_tr.get_instance_by_name.return_value = MagicMock() + mock_tr.get_names.return_value = ["my_target"] + + mock_ir = MagicMock() + mock_ir.get_class.return_value = MagicMock(return_value=MagicMock(initialize_async=AsyncMock())) + + # By default, return a matching DB result for get_run / list_runs queries + db_result = _make_db_scenario_result() + mock_memory.get_scenario_results.return_value = [db_result] + + with ( + patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), + patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton", return_value=mock_tr), + patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton", return_value=mock_ir), + ): + yield { + "scenario_registry": mock_sr, + "target_registry": mock_tr, + "initializer_registry": mock_ir, + "scenario_class": mock_scenario_class, + "scenario_instance": mock_scenario_instance, + "memory": mock_memory, + "db_result": db_result, + } + + +class TestScenarioRunServiceStartRun: + """Tests for ScenarioRunService.start_run_async.""" + + async def 
test_start_run_returns_running_status(self, mock_all_registries) -> None: + """Test that starting a run returns RUNNING status with run_id = scenario_result_id.""" + service = ScenarioRunService() + response = await service.start_run_async(request=_make_request()) + + assert response.scenario_result_id == "sr-uuid-1" + assert response.status == ScenarioRunStatus.RUNNING + assert response.scenario_name == "foundry.red_team_agent" + assert response.error is None + + async def test_start_run_invalid_scenario_raises_value_error(self, mock_memory) -> None: + """Test that an invalid scenario name raises ValueError immediately.""" + service = ScenarioRunService() + + mock_sr = MagicMock() + mock_sr.get_class.side_effect = KeyError("'bad.scenario' not found in registry. Available: foo") + with ( + patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), + patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton"), + patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton"), + ): + with pytest.raises(ValueError, match="not found in registry"): + await service.start_run_async(request=_make_request(scenario_name="bad.scenario")) + + async def test_start_run_invalid_target_raises_value_error(self, mock_memory) -> None: + """Test that an invalid target name raises ValueError immediately.""" + service = ScenarioRunService() + + mock_sr = MagicMock() + mock_sr.get_class.return_value = MagicMock() + + mock_tr = MagicMock() + mock_tr.get_instance_by_name.return_value = None + mock_tr.get_names.return_value = ["other_target"] + + with ( + patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), + patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton", return_value=mock_tr), + patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton"), + ): + with pytest.raises(ValueError, match="my_target.*not found in registry"): + await 
service.start_run_async(request=_make_request()) + + async def test_start_run_invalid_initializer_raises_value_error(self, mock_memory) -> None: + """Test that an invalid initializer name raises ValueError immediately.""" + service = ScenarioRunService() + + mock_sr = MagicMock() + mock_sr.get_class.return_value = MagicMock() + + mock_ir = MagicMock() + mock_ir.get_class.side_effect = KeyError("'bad_init' not found") + + with ( + patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), + patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton"), + patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton", return_value=mock_ir), + ): + with pytest.raises(ValueError, match="Initializer not found"): + await service.start_run_async(request=_make_request(initializers=["bad_init"])) + + async def test_start_run_invalid_strategy_raises_value_error(self, mock_memory) -> None: + """Test that an invalid strategy name raises ValueError immediately.""" + service = ScenarioRunService() + + mock_strategy_class = MagicMock(side_effect=ValueError("not a valid strategy")) + mock_strategy_class.__iter__ = MagicMock(return_value=iter([MagicMock(value="valid_strat")])) + + mock_scenario_class = MagicMock() + mock_scenario_class.get_strategy_class.return_value = mock_strategy_class + + mock_sr = MagicMock() + mock_sr.get_class.return_value = mock_scenario_class + + mock_tr = MagicMock() + mock_tr.get_instance_by_name.return_value = MagicMock() + + with ( + patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), + patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton", return_value=mock_tr), + patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton"), + ): + with pytest.raises(ValueError, match="Strategy.*not found for scenario"): + await service.start_run_async(request=_make_request(strategies=["bad_strategy"])) + + async def 
test_start_run_exceeds_concurrent_limit(self, mock_all_registries) -> None: + """Test that exceeding concurrent run limit raises ValueError.""" + service = ScenarioRunService() + scenario_instance = mock_all_registries["scenario_instance"] + + # Each call needs a unique scenario_result_id + call_count = 0 + original_init = scenario_instance.initialize_async + + async def _set_unique_id(**kwargs: object) -> None: + nonlocal call_count + call_count += 1 + scenario_instance._scenario_result_id = f"sr-uuid-{call_count}" + + scenario_instance.initialize_async = AsyncMock(side_effect=_set_unique_id) + + # Fill up to the limit + for _ in range(_DEFAULT_MAX_CONCURRENT_RUNS): + await service.start_run_async(request=_make_request()) + + # Next one should fail + with pytest.raises(ValueError, match="Maximum concurrent runs"): + await service.start_run_async(request=_make_request()) + + async def test_start_run_runs_initializers(self, mock_all_registries) -> None: + """Test that initializers are run during start_run_async.""" + service = ScenarioRunService() + mock_ir = mock_all_registries["initializer_registry"] + mock_init_instance = mock_ir.get_class.return_value.return_value + + response = await service.start_run_async( + request=_make_request(initializers=["target", "load_default_datasets"]) + ) + + assert response.status == ScenarioRunStatus.RUNNING + assert mock_init_instance.initialize_async.await_count == 2 + + async def test_start_run_passes_scenario_result_id_for_resume(self, mock_all_registries) -> None: + """Test that scenario_result_id is passed to the scenario constructor for resumption.""" + service = ScenarioRunService() + mock_scenario_class = mock_all_registries["scenario_class"] + + response = await service.start_run_async(request=_make_request(scenario_result_id="existing-result-uuid")) + + assert response.status == ScenarioRunStatus.RUNNING + mock_scenario_class.assert_called_once_with(scenario_result_id="existing-result-uuid") + + async def 
test_start_run_omits_scenario_result_id_when_none(self, mock_all_registries) -> None: + """Test that scenario_result_id is not passed to constructor when not provided.""" + service = ScenarioRunService() + mock_scenario_class = mock_all_registries["scenario_class"] + + await service.start_run_async(request=_make_request()) + + mock_scenario_class.assert_called_once_with() + + +class TestScenarioRunServiceGetRun: + """Tests for ScenarioRunService.get_run.""" + + def test_get_run_returns_none_for_unknown_id(self, mock_memory) -> None: + """Test that get_run returns None for non-existent run.""" + mock_memory.get_scenario_results.return_value = [] + service = ScenarioRunService() + result = service.get_run(scenario_result_id="nonexistent-id") + assert result is None + + def test_get_run_returns_existing_run(self, mock_memory) -> None: + """Test that get_run returns a run from the database.""" + db_result = _make_db_scenario_result(result_id="sr-123", run_state="IN_PROGRESS") + mock_memory.get_scenario_results.return_value = [db_result] + + service = ScenarioRunService() + fetched = service.get_run(scenario_result_id="sr-123") + + assert fetched is not None + assert fetched.scenario_result_id == "sr-123" + assert fetched.scenario_name == "foundry.red_team_agent" + assert fetched.status == ScenarioRunStatus.RUNNING + + +class TestScenarioRunServiceListRuns: + """Tests for ScenarioRunService.list_runs.""" + + def test_list_runs_empty(self, mock_memory) -> None: + """Test that list_runs returns empty list when DB has no results.""" + mock_memory.get_scenario_results.return_value = [] + service = ScenarioRunService() + result = service.list_runs() + assert result.items == [] + mock_memory.get_scenario_results.assert_called_once_with(limit=100) + + def test_list_runs_returns_all_runs(self, mock_memory) -> None: + """Test that list_runs returns all runs from the database.""" + db_results = [ + _make_db_scenario_result(result_id="sr-1", run_state="COMPLETED"), + 
_make_db_scenario_result(result_id="sr-2", run_state="IN_PROGRESS"), + ] + mock_memory.get_scenario_results.return_value = db_results + + service = ScenarioRunService() + result = service.list_runs() + assert len(result.items) == 2 + mock_memory.get_scenario_results.assert_called_once_with(limit=100) + + def test_list_runs_passes_custom_limit(self, mock_memory) -> None: + """Test that list_runs passes a custom limit to the memory query.""" + mock_memory.get_scenario_results.return_value = [] + service = ScenarioRunService() + service.list_runs(limit=10) + mock_memory.get_scenario_results.assert_called_once_with(limit=10) + + +class TestScenarioRunServiceCancelRun: + """Tests for ScenarioRunService.cancel_run_async.""" + + async def test_cancel_run_returns_none_for_unknown_id(self, mock_memory) -> None: + """Test that cancel returns None for non-existent run.""" + mock_memory.get_scenario_results.return_value = [] + service = ScenarioRunService() + result = await service.cancel_run_async(scenario_result_id="nonexistent-id") + assert result is None + + async def test_cancel_run_sets_cancelled_status(self, mock_all_registries) -> None: + """Test that cancelling a running scenario persists CANCELLED to DB.""" + service = ScenarioRunService() + mock_memory = mock_all_registries["memory"] + response = await service.start_run_async(request=_make_request()) + + # After update_scenario_run_state, the next DB query should return CANCELLED + running_result = mock_all_registries["db_result"] + cancelled_result = _make_db_scenario_result(result_id=response.scenario_result_id, run_state="CANCELLED") + mock_memory.get_scenario_results.side_effect = [[running_result], [cancelled_result]] + + result = await service.cancel_run_async(scenario_result_id=response.scenario_result_id) + + mock_memory.update_scenario_run_state.assert_called_once_with( + scenario_result_id=response.scenario_result_id, scenario_run_state="CANCELLED" + ) + assert result is not None + assert result.status == 
ScenarioRunStatus.CANCELLED + + async def test_cancel_completed_run_raises_value_error(self, mock_memory) -> None: + """Test that cancelling a completed run raises ValueError.""" + db_result = _make_db_scenario_result(result_id="sr-done", run_state="COMPLETED") + mock_memory.get_scenario_results.return_value = [db_result] + + service = ScenarioRunService() + with pytest.raises(ValueError, match="Cannot cancel run"): + await service.cancel_run_async(scenario_result_id="sr-done") + + async def test_cancel_already_cancelled_run_raises_value_error(self, mock_memory) -> None: + """Test that cancelling an already-cancelled run raises ValueError.""" + db_result = _make_db_scenario_result(result_id="sr-cancelled", run_state="CANCELLED") + mock_memory.get_scenario_results.return_value = [db_result] + + service = ScenarioRunService() + with pytest.raises(ValueError, match="Cannot cancel run"): + await service.cancel_run_async(scenario_result_id="sr-cancelled") + + +class TestScenarioRunServiceExecution: + """Tests for the background execution logic.""" + + async def test_execute_run_completes_successfully(self, mock_all_registries) -> None: + """Test that a successful execution removes active task and DB reflects COMPLETED.""" + service = ScenarioRunService() + mock_instance = mock_all_registries["scenario_instance"] + mock_memory = mock_all_registries["memory"] + + mock_scenario_result = MagicMock() + mock_scenario_result.id = "sr-uuid-1" + mock_scenario_result.scenario_run_state = "COMPLETED" + mock_scenario_result.get_strategies_used.return_value = ["base64"] + mock_scenario_result.attack_results = {"attack1": []} + mock_scenario_result.number_tries = 1 + mock_scenario_result.created_at = datetime(2025, 1, 1, tzinfo=timezone.utc) + mock_scenario_result.completion_time = datetime(2025, 1, 1, 0, 5, tzinfo=timezone.utc) + + mock_instance.run_async = AsyncMock(return_value=mock_scenario_result) + + response = await service.start_run_async(request=_make_request()) + + # Wait 
for the background task to complete + active = service._active_tasks.get(response.scenario_result_id) + assert active is not None + assert active.task is not None + await active.task + + # Active task is cleaned up on next get_run (deferred cleanup) + assert response.scenario_result_id in service._active_tasks + fetched = service.get_run(scenario_result_id=response.scenario_result_id) + assert fetched is not None + assert response.scenario_result_id not in service._active_tasks + + async def test_execute_run_fails_with_error(self, mock_all_registries) -> None: + """Test that a run_async failure stores error and surfaces it via get_run.""" + service = ScenarioRunService() + mock_instance = mock_all_registries["scenario_instance"] + + mock_instance.run_async = AsyncMock(side_effect=RuntimeError("scenario exploded")) + + response = await service.start_run_async(request=_make_request()) + + # Wait for the background task + active = service._active_tasks.get(response.scenario_result_id) + assert active is not None + assert active.task is not None + await active.task + + # Error is stored on the active task until get_run reads it + assert active.error == "scenario exploded" + assert response.scenario_result_id in service._active_tasks + + # get_run should surface the error and clean up + fetched = service.get_run(scenario_result_id=response.scenario_result_id) + assert fetched is not None + assert fetched.error == "scenario exploded" + assert response.scenario_result_id not in service._active_tasks + + +class TestScenarioRunServiceGetResults: + """Tests for ScenarioRunService.get_run_results.""" + + def test_get_results_returns_none_for_unknown_id(self, mock_memory) -> None: + """Test that get_run_results returns None for non-existent run.""" + mock_memory.get_scenario_results.return_value = [] + service = ScenarioRunService() + result = service.get_run_results(scenario_result_id="nonexistent-id") + assert result is None + + def 
test_get_results_raises_if_not_completed(self, mock_memory) -> None: + """Test that get_run_results raises ValueError if run is not completed.""" + db_result = _make_db_scenario_result(result_id="sr-running", run_state="IN_PROGRESS") + mock_memory.get_scenario_results.return_value = [db_result] + + service = ScenarioRunService() + with pytest.raises(ValueError, match="only available for completed runs"): + service.get_run_results(scenario_result_id="sr-running") + + def test_get_results_returns_details_for_completed_run(self, mock_memory) -> None: + """Test that get_run_results returns full details for a completed run.""" + from pyrit.models import AttackOutcome + + mock_attack_result = MagicMock() + mock_attack_result.attack_result_id = "ar-1" + mock_attack_result.conversation_id = "conv-1" + mock_attack_result.objective = "Extract info" + mock_attack_result.outcome = AttackOutcome.SUCCESS + mock_attack_result.outcome_reason = "Model complied" + mock_attack_result.last_response = MagicMock(value="Here is the data") + mock_attack_result.last_score = MagicMock() + mock_attack_result.last_score.get_value.return_value = "1.0" + mock_attack_result.executed_turns = 3 + mock_attack_result.execution_time_ms = 1500 + mock_attack_result.timestamp = None + + db_result = _make_db_scenario_result( + result_id="sr-123", + run_state="COMPLETED", + attack_results={"base64_attack": [mock_attack_result]}, + ) + db_result.objective_achieved_rate.return_value = 100 + mock_memory.get_scenario_results.return_value = [db_result] + + service = ScenarioRunService() + detail = service.get_run_results(scenario_result_id="sr-123") + + assert detail is not None + assert detail.run.scenario_result_id == "sr-123" + assert detail.run.objective_achieved_rate == 100 + assert len(detail.attacks) == 1 + assert detail.attacks[0].atomic_attack_name == "base64_attack" + assert detail.attacks[0].success_count == 1 + assert detail.attacks[0].results[0].objective == "Extract info" + assert 
detail.attacks[0].results[0].outcome == "success" diff --git a/tests/unit/backend/test_scenario_service.py b/tests/unit/backend/test_scenario_service.py index 7f435d76a5..cffff85583 100644 --- a/tests/unit/backend/test_scenario_service.py +++ b/tests/unit/backend/test_scenario_service.py @@ -13,7 +13,7 @@ from pyrit.backend.main import app from pyrit.backend.models.common import PaginationInfo -from pyrit.backend.models.scenarios import ScenarioListResponse, ScenarioSummary +from pyrit.backend.models.scenarios import ListRegisteredScenarioResponse, RegisteredScenario from pyrit.backend.services.scenario_service import ScenarioService, get_scenario_service from pyrit.registry import ScenarioMetadata @@ -65,7 +65,6 @@ def _make_scenario_metadata( class TestScenarioServiceListScenarios: """Tests for ScenarioService.list_scenarios_async.""" - @pytest.mark.asyncio async def test_list_scenarios_returns_empty_when_no_scenarios(self) -> None: """Test that list returns empty list when no scenarios are registered.""" with patch.object(ScenarioService, "__init__", lambda self: None): @@ -78,7 +77,6 @@ async def test_list_scenarios_returns_empty_when_no_scenarios(self) -> None: assert result.items == [] assert result.pagination.has_more is False - @pytest.mark.asyncio async def test_list_scenarios_returns_scenarios_from_registry(self) -> None: """Test that list returns scenarios from registry.""" metadata = _make_scenario_metadata() @@ -100,7 +98,6 @@ async def test_list_scenarios_returns_scenarios_from_registry(self) -> None: assert result.items[0].default_datasets == ["test_dataset"] assert result.items[0].max_dataset_size is None - @pytest.mark.asyncio async def test_list_scenarios_paginates_with_limit(self) -> None: """Test that list respects the limit parameter.""" metadata_list = [ @@ -118,7 +115,6 @@ async def test_list_scenarios_paginates_with_limit(self) -> None: assert result.pagination.has_more is True assert result.pagination.next_cursor == "test.scenario_2" - 
@pytest.mark.asyncio async def test_list_scenarios_paginates_with_cursor(self) -> None: """Test that list uses cursor for pagination.""" metadata_list = [ @@ -137,7 +133,6 @@ async def test_list_scenarios_paginates_with_cursor(self) -> None: assert result.items[1].scenario_name == "test.scenario_3" assert result.pagination.has_more is True - @pytest.mark.asyncio async def test_list_scenarios_last_page_has_more_false(self) -> None: """Test that last page shows has_more=False.""" metadata_list = [ @@ -155,7 +150,6 @@ async def test_list_scenarios_last_page_has_more_false(self) -> None: assert result.pagination.has_more is False assert result.pagination.next_cursor is None - @pytest.mark.asyncio async def test_list_scenarios_includes_max_dataset_size(self) -> None: """Test that max_dataset_size is included in response.""" metadata = _make_scenario_metadata(max_dataset_size=10) @@ -173,7 +167,6 @@ async def test_list_scenarios_includes_max_dataset_size(self) -> None: class TestScenarioServiceGetScenario: """Tests for ScenarioService.get_scenario_async.""" - @pytest.mark.asyncio async def test_get_scenario_returns_matching_scenario(self) -> None: """Test that get returns the matching scenario.""" metadata = _make_scenario_metadata(registry_name="foundry.red_team_agent") @@ -188,7 +181,6 @@ async def test_get_scenario_returns_matching_scenario(self) -> None: assert result is not None assert result.scenario_name == "foundry.red_team_agent" - @pytest.mark.asyncio async def test_get_scenario_returns_none_for_missing(self) -> None: """Test that get returns None when scenario not found.""" with patch.object(ScenarioService, "__init__", lambda self: None): @@ -210,18 +202,18 @@ class TestScenarioRoutes: """Tests for scenario API routes.""" def test_list_scenarios_returns_200(self, client: TestClient) -> None: - """Test that GET /api/scenarios returns 200.""" + """Test that GET /api/scenarios/catalog returns 200.""" with 
patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: mock_service = MagicMock() mock_service.list_scenarios_async = AsyncMock( - return_value=ScenarioListResponse( + return_value=ListRegisteredScenarioResponse( items=[], pagination=PaginationInfo(limit=50, has_more=False, next_cursor=None, prev_cursor=None), ) ) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios") + response = client.get("/api/scenarios/catalog") assert response.status_code == status.HTTP_200_OK data = response.json() @@ -229,8 +221,8 @@ def test_list_scenarios_returns_200(self, client: TestClient) -> None: assert data["pagination"]["has_more"] is False def test_list_scenarios_with_items(self, client: TestClient) -> None: - """Test that GET /api/scenarios returns scenario data.""" - summary = ScenarioSummary( + """Test that GET /api/scenarios/catalog returns scenario data.""" + summary = RegisteredScenario( scenario_name="foundry.red_team_agent", scenario_type="RedTeamAgentScenario", description="Red team agent testing", @@ -244,14 +236,14 @@ def test_list_scenarios_with_items(self, client: TestClient) -> None: with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: mock_service = MagicMock() mock_service.list_scenarios_async = AsyncMock( - return_value=ScenarioListResponse( + return_value=ListRegisteredScenarioResponse( items=[summary], pagination=PaginationInfo(limit=50, has_more=False, next_cursor=None, prev_cursor=None), ) ) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios") + response = client.get("/api/scenarios/catalog") assert response.status_code == status.HTTP_200_OK data = response.json() @@ -270,21 +262,21 @@ def test_list_scenarios_passes_pagination_params(self, client: TestClient) -> No with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: mock_service = MagicMock() mock_service.list_scenarios_async = AsyncMock( - 
return_value=ScenarioListResponse( + return_value=ListRegisteredScenarioResponse( items=[], pagination=PaginationInfo(limit=10, has_more=False, next_cursor=None, prev_cursor=None), ) ) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios?limit=10&cursor=test.scenario_1") + response = client.get("/api/scenarios/catalog?limit=10&cursor=test.scenario_1") assert response.status_code == status.HTTP_200_OK mock_service.list_scenarios_async.assert_called_once_with(limit=10, cursor="test.scenario_1") def test_get_scenario_returns_200(self, client: TestClient) -> None: - """Test that GET /api/scenarios/{name} returns 200 when found.""" - summary = ScenarioSummary( + """Test that GET /api/scenarios/catalog/{name} returns 200 when found.""" + summary = RegisteredScenario( scenario_name="foundry.red_team_agent", scenario_type="RedTeamAgentScenario", description="Red team agent testing", @@ -300,26 +292,26 @@ def test_get_scenario_returns_200(self, client: TestClient) -> None: mock_service.get_scenario_async = AsyncMock(return_value=summary) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios/foundry.red_team_agent") + response = client.get("/api/scenarios/catalog/foundry.red_team_agent") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["scenario_name"] == "foundry.red_team_agent" def test_get_scenario_returns_404_when_not_found(self, client: TestClient) -> None: - """Test that GET /api/scenarios/{name} returns 404 when not found.""" + """Test that GET /api/scenarios/catalog/{name} returns 404 when not found.""" with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: mock_service = MagicMock() mock_service.get_scenario_async = AsyncMock(return_value=None) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios/nonexistent") + response = client.get("/api/scenarios/catalog/nonexistent") assert response.status_code == 
status.HTTP_404_NOT_FOUND def test_get_scenario_with_dotted_name(self, client: TestClient) -> None: """Test that dotted scenario names (e.g., 'foundry.red_team_agent') work in path.""" - summary = ScenarioSummary( + summary = RegisteredScenario( scenario_name="garak.encoding", scenario_type="EncodingScenario", description="Encoding scenario", @@ -335,7 +327,7 @@ def test_get_scenario_with_dotted_name(self, client: TestClient) -> None: mock_service.get_scenario_async = AsyncMock(return_value=summary) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios/garak.encoding") + response = client.get("/api/scenarios/catalog/garak.encoding") assert response.status_code == status.HTTP_200_OK mock_service.get_scenario_async.assert_called_once_with(scenario_name="garak.encoding")